{-# LANGUAGE DeriveGeneric #-}
-- | Matrix datatype and operations.
--
--   Every provided example has been tested.
--   Run @cabal test@ for further tests.
module Data.Matrix (
    -- * Matrix type
    Matrix , prettyMatrix
  , nrows , ncols
  , forceMatrix
    -- * Builders
  , matrix
  , rowVector
  , colVector
    -- ** Special matrices
  , zero
  , identity
  , diagonalList
  , diagonal
  , permMatrix
    -- * List conversions
  , fromList , fromLists
  , toList   , toLists
    -- * Accessing
  , getElem , (!) , unsafeGet , safeGet, safeSet
  , getRow  , safeGetRow , getCol , safeGetCol
  , getDiag
  , getMatrixAsVector
    -- * Manipulating matrices
  , setElem
  , unsafeSet
  , transpose , setSize , extendTo
  , inverse, rref
  , mapRow , mapCol, mapPos
    -- * Submatrices
    -- ** Splitting blocks
  , submatrix
  , minorMatrix
  , splitBlocks
   -- ** Joining blocks
  , (<|>) , (<->)
  , joinBlocks
    -- * Matrix operations
  , elementwise, elementwiseUnsafe
    -- * Matrix multiplication
    -- ** About matrix multiplication
    -- $mult

    -- ** Functions
  , multStd
  , multStd2
  , multStrassen
  , multStrassenMixed
    -- * Linear transformations
  , scaleMatrix
  , scaleRow
  , combineRows
  , switchRows
  , switchCols
    -- * Decompositions
  , luDecomp , luDecompUnsafe
  , luDecomp', luDecompUnsafe'
  , cholDecomp
    -- * Properties
  , trace , diagProd
    -- ** Determinants
  , detLaplace
  , detLU
  , flatten
  ) where

import Prelude hiding (foldl1)
-- Classes
import Control.DeepSeq
import Control.Monad (forM_)
import Control.Loop (numLoop,numLoopFold)
import Data.Foldable (Foldable, foldMap, foldl1)
import Data.Maybe
import Data.Monoid
import qualified Data.Semigroup as S
import Data.Traversable
import Control.Applicative(Applicative, (<$>), (<*>), pure)
import GHC.Generics (Generic)
-- Data
import           Control.Monad.Primitive (PrimMonad, PrimState)
import           Data.List               (maximumBy,foldl1',find)
import           Data.Ord                (comparing)
import qualified Data.Vector             as V
import qualified Data.Vector.Mutable     as MV

-------------------------------------------------------
-------------------------------------------------------
---- MATRIX TYPE

-- | Map a 1-based @(row, column)@ position to its 0-based index in a
--   row-major vector whose rows are @cols@ elements wide.
encode :: Int -> (Int,Int) -> Int
{-# INLINE encode #-}
encode cols (row, col) = (row - 1) * cols + (col - 1)

-- | Inverse of 'encode': turn a 0-based row-major index into a 1-based
--   @(row, column)@ position, given the row width @cols@.
decode :: Int -> Int -> (Int,Int)
{-# INLINE decode #-}
decode cols k =
  let (rowIx, colIx) = k `quotRem` cols
  in  (rowIx + 1, colIx + 1)

-- | Type of matrices.
--
--   Elements can be of any type. Rows and columns are indexed starting
--   from 1: if @m :: Matrix a@ and @i,j :: Int@, then @m ! (i,j)@ is the
--   element in the @i@-th row and @j@-th column of @m@.
--
--   Internally a matrix is a window onto a row-major 'V.Vector' whose rows
--   are 'vcols' elements wide. Element @(i,j)@ of the matrix lives at row
--   @i + rowOffset@ and column @j + colOffset@ of that vector.
data Matrix a = M
  { nrows     :: {-# UNPACK #-} !Int  -- ^ Number of rows.
  , ncols     :: {-# UNPACK #-} !Int  -- ^ Number of columns.
  , rowOffset :: {-# UNPACK #-} !Int  -- ^ Rows of the backing vector skipped before the first row.
  , colOffset :: {-# UNPACK #-} !Int  -- ^ Columns of the backing vector skipped before the first column.
  , vcols     :: {-# UNPACK #-} !Int  -- ^ Number of columns of the matrix without offset
  , mvect     :: V.Vector a           -- ^ Content of the matrix as a plain vector.
  } deriving (Generic)

-- | Two matrices are equal when they have the same dimensions and the same
--   element at every position (compared via '!'). Element comparisons only
--   run once both dimensions are known to agree.
instance Eq a => Eq (Matrix a) where
  m1 == m2 = sameShape && all sameEntry positions
    where
      rows = nrows m1
      cols = ncols m1
      sameShape = rows == nrows m2 && cols == ncols m2
      positions = [ (i, j) | i <- [1 .. rows], j <- [1 .. cols] ]
      sameEntry ij = m1 ! ij == m2 ! ij

-- | Render matrix dimensions as rows, an @x@, then columns
--   (for example @sizeStr 2 3 == "2x3"@). Used in error messages.
sizeStr :: Int -> Int -> String
sizeStr n m = concat [show n, "x", show m]

-- | Display a matrix as a 'String' using the 'Show' instance of its elements.
--
--   Every entry is right-aligned to the width of the widest entry shown.
--   Only the entries that are part of the matrix are measured; entries of
--   the backing vector outside a 'submatrix' view are ignored. A matrix
--   with no entries is rendered without failing.
prettyMatrix :: Show a => Matrix a -> String
prettyMatrix m = concat
   [ "┌ ", unwords (replicate (ncols m) blank), " ┐\n"
   , unlines [ "│ " ++ unwords (fmap fill row) ++ " │" | row <- cells ]
   , "└ ", unwords (replicate (ncols m) blank), " ┘"
   ]
 where
   -- Rendered visible entries, one list per row.
   cells = toLists (fmap show m)
   -- Width of the widest visible entry. The leading 0 keeps 'maximum'
   -- total when the matrix has no entries (0 rows or 0 columns).
   widest = maximum (0 : fmap length (concat cells))
   -- Left-pad a rendered entry with spaces up to 'widest'.
   fill str = replicate (widest - length str) ' ' ++ str
   blank = fill ""


-- | Matrices are shown in the boxed layout produced by 'prettyMatrix'.
instance Show a => Show (Matrix a) where
  show m = prettyMatrix m

-- | Fully evaluating a matrix fully evaluates its backing vector. The
--   dimension and offset fields are strict and unpacked, so they need no
--   extra work.
instance NFData a => NFData (Matrix a) where
  rnf m = rnf (mvect m)

-- | /O(rows*cols)/. Similar to 'V.force'. It copies the visible content of
--   the matrix into a fresh, tightly sized backing vector with no offsets,
--   dropping any extra memory.
--
--   Useful when using 'submatrix' from a big matrix.
--
forceMatrix :: Matrix a -> Matrix a
forceMatrix m = matrix (nrows m) (ncols m) at
  where
    at (i, j) = unsafeGet i j m

-------------------------------------------------------
-------------------------------------------------------
---- FUNCTOR INSTANCE

-- | Applies the function to every element of the backing vector. Shape,
--   offsets and row width are kept unchanged.
instance Functor Matrix where
  {-# INLINE fmap #-}
  fmap f m = m { mvect = V.map f (mvect m) }

-------------------------------------------------------
-------------------------------------------------------

-------------------------------------------------------
-------------------------------------------------------
---- MONOID INSTANCE

-- | Entrywise combination using the element monoid. The result is as large
--   as the larger operand in each dimension. At each position the two
--   'safeGet' results are combined with '<>'. A position where both yield
--   'Nothing' becomes 'mempty'.
instance Monoid a => S.Semigroup (Matrix a) where
  m <> m' = matrix rows cols cell
    where
      rows = max (nrows m) (nrows m')
      cols = max (ncols m) (ncols m')
      cell (row, column) =
        fromMaybe mempty (safeGet row column m <> safeGet row column m')

-- | The neutral element is the 1x1 matrix holding the element 'mempty'.
--
--   NOTE(review): because 'mempty' has one row and one column, combining
--   it with a matrix that has zero rows or zero columns grows that matrix
--   to at least 1x1. It is a two-sided identity only for matrices with at
--   least one row and one column.
instance Monoid a => Monoid (Matrix a) where
  mempty = pure mempty
  mappend x y = x <> y


-------------------------------------------------------
-------------------------------------------------------
-------------------------------------------------------
-------------------------------------------------------

-------------------------------------------------------
-------------------------------------------------------
---- APPLICATIVE INSTANCE
---- Works like tensor product but applies a function

-- | 'pure' builds a 1x1 matrix. For '<*>', each function of the left matrix
--   is mapped over the whole right matrix. The resulting grid of matrices
--   is then glued into one block matrix with 'flatten'.
instance Applicative Matrix where
  pure x = fromList 1 1 [x]
  mf <*> mx = flatten (fmap (<$> mx) mf)


-------------------------------------------------------
-------------------------------------------------------



-- | Flatten a matrix of matrices into one block matrix. The blocks of each
--   outer row are joined side by side with '<|>', and the resulting strips
--   are stacked with '<->'. All sub matrices must have the same dimensions;
--   this is not checked.
flatten :: Matrix (Matrix a) -> Matrix a
flatten m = foldl1 (<->) strips
  where
    strips = [ foldl1 (<|>) (getRow i m) | i <- [1 .. nrows m] ]

-- | /O(rows*cols)/. Apply a function to every element of one row, leaving
--   all other rows unchanged.
--   Example:
--
-- >                          ( 1 2 3 )   ( 1 2 3 )
-- >                          ( 4 5 6 )   ( 5 6 7 )
-- > mapRow (\_ x -> x + 1) 2 ( 7 8 9 ) = ( 7 8 9 )
--
mapRow :: (Int -> a -> a) -- ^ Function takes the current column as additional argument.
        -> Int            -- ^ Row to map.
        -> Matrix a -> Matrix a
mapRow f r m = matrix (nrows m) (ncols m) entry
  where
    entry (i, j)
      | i == r    = f j x
      | otherwise = x
      where
        x = unsafeGet i j m

-- | /O(rows*cols)/. Apply a function to every element of one column,
--   leaving all other columns unchanged.
--   Example:
--
-- >                          ( 1 2 3 )   ( 1 3 3 )
-- >                          ( 4 5 6 )   ( 4 6 6 )
-- > mapCol (\_ x -> x + 1) 2 ( 7 8 9 ) = ( 7 9 9 )
--
mapCol :: (Int -> a -> a) -- ^ Function takes the current row as additional argument.
        -> Int            -- ^ Column to map.
        -> Matrix a -> Matrix a
mapCol f c m = matrix (nrows m) (ncols m) entry
  where
    entry (i, j)
      | j == c    = f i x
      | otherwise = x
      where
        x = unsafeGet i j m


-- | /O(rows*cols)/. Map a function over elements, also passing each
--   element's position.
--   Example:
--
-- >                            ( 1 2 3 )   ( 0 -1 -2 )
-- >                            ( 4 5 6 )   ( 1  0 -1 )
-- > mapPos (\(r,c) a -> r - c) ( 7 8 9 ) = ( 2  1  0 )
--
--   Positions are 1-based and relative to the given matrix, so they are
--   also correct for views produced by 'submatrix'. The function is only
--   applied to the matrix's own entries, never to hidden elements of the
--   backing vector.
mapPos :: ((Int, Int) -> a -> b) -- ^ Function takes the current Position as additional argument.
        -> Matrix a
        -> Matrix b
mapPos f m = matrix (nrows m) (ncols m) $ \(i, j) -> f (i, j) (unsafeGet i j m)

-------------------------------------------------------
-------------------------------------------------------
---- FOLDABLE AND TRAVERSABLE INSTANCES

-- | Folds over the visible elements in row-major order. The matrix is first
--   compacted with 'forceMatrix', so hidden elements of a view's backing
--   vector are not visited.
instance Foldable Matrix where
  foldMap f m = foldMap f (mvect (forceMatrix m))

-- | Traverses the visible elements in row-major order. The result is a
--   compact matrix with no offsets and a row width equal to its column count.
instance Traversable Matrix where
  sequenceA m = rebuild <$> sequenceA (mvect (forceMatrix m))
    where
      rebuild = M (nrows m) (ncols m) 0 0 (ncols m)

-------------------------------------------------------
-------------------------------------------------------
---- BUILDERS

-- | /O(rows*cols)/. The matrix of the given size with every entry 0.
--
-- > zero n m =
-- >                 m
-- >   1 ( 0 0 ... 0 0 )
-- >   2 ( 0 0 ... 0 0 )
-- >     (     ...     )
-- >     ( 0 0 ... 0 0 )
-- >   n ( 0 0 ... 0 0 )
zero :: Num a =>
     Int -- ^ Rows
  -> Int -- ^ Columns
  -> Matrix a
{-# INLINE zero #-}
zero rows cols = M rows cols 0 0 cols (V.replicate (rows * cols) 0)

-- | /O(rows*cols)/. Build a matrix by evaluating a generator function at
--   every 1-based position @(row, column)@.
--   Example of usage:
--
-- >                                  (  1  0 -1 -2 )
-- >                                  (  3  2  1  0 )
-- >                                  (  5  4  3  2 )
-- > matrix 4 4 $ \(i,j) -> 2*i - j = (  7  6  5  4 )
--
--   The result has no offsets. Its backing vector has length @rows*cols@
--   and is laid out row by row.
matrix :: Int -- ^ Rows
       -> Int -- ^ Columns
       -> ((Int,Int) -> a) -- ^ Generator function
       -> Matrix a
{-# INLINE matrix #-}
matrix rows cols gen = M rows cols 0 0 cols $ V.create $ do
  buf <- MV.new (rows * cols)
  -- Store the generator's value for (i, j) in its row-major slot.
  numLoop 1 rows $ \i ->
    numLoop 1 cols $ \j ->
      MV.unsafeWrite buf (encode cols (i, j)) (gen (i, j))
  return buf

-- | /O(rows*cols)/. Identity matrix of the given order: ones on the main
--   diagonal, zeros everywhere else.
--
-- > identity n =
-- >                 n
-- >   1 ( 1 0 ... 0 0 )
-- >   2 ( 0 1 ... 0 0 )
-- >     (     ...     )
-- >     ( 0 0 ... 1 0 )
-- >   n ( 0 0 ... 0 1 )
--
identity :: Num a => Int -> Matrix a
identity n = matrix n n kronecker
  where
    kronecker (i, j)
      | i == j    = 1
      | otherwise = 0

-- | Square diagonal matrix whose order is the length of the given vector.
--   Similar to 'diagonalList', but using 'V.Vector', which should be more
--   efficient. Off-diagonal entries are the default element.
diagonal :: a -- ^ Default element
         -> V.Vector a  -- ^ Diagonal vector
         -> Matrix a
diagonal e v = matrix n n pick
  where
    n = V.length v
    pick (i, j)
      | i == j    = V.unsafeIndex v (i - 1)
      | otherwise = e

-- | Create a matrix of the desired size from a list read in row-major order.
--   The list must have at least /rows*cols/ elements; only the first
--   /rows*cols/ are used, and a shorter list raises an 'error'.
--   An example:
--
-- >                       ( 1 2 3 )
-- >                       ( 4 5 6 )
-- > fromList 3 3 [1..] =  ( 7 8 9 )
--
fromList :: Int -- ^ Rows
         -> Int -- ^ Columns
         -> [a] -- ^ List of elements
         -> Matrix a
{-# INLINE fromList #-}
fromList n m xs
  | V.length v < n * m = error $ concat
      [ "List size ", show (V.length v)
      , " is inconsistent with matrix size ", sizeStr n m
      , " in fromList"
      ]
  | otherwise = M n m 0 0 m v
  where
    -- At most n*m elements are taken, so infinite lists are fine.
    v = V.fromListN (n * m) xs

-- | Get the elements of a matrix as a flat list, row by row.
--
-- >        ( 1 2 3 )
-- >        ( 4 5 6 )
-- > toList ( 7 8 9 ) = [1,2,3,4,5,6,7,8,9]
--
toList :: Matrix a -> [a]
toList m = concatMap row [1 .. nrows m]
  where
    row i = map (\j -> unsafeGet i j m) [1 .. ncols m]

-- | Get the elements of a matrix as a list of rows, where each inner list
--   holds the elements of a single row in column order.
--
-- >         ( 1 2 3 )   [ [1,2,3]
-- >         ( 4 5 6 )   , [4,5,6]
-- > toLists ( 7 8 9 ) = , [7,8,9] ]
--
toLists :: Matrix a -> [[a]]
toLists m = map row [1 .. nrows m]
  where
    row i = [ unsafeGet i j m | j <- [1 .. ncols m] ]

-- | Square diagonal matrix of the given order whose diagonal is taken from
--   the front of a list. Non-diagonal entries are the given default element.
--   The list must have at least /order/ elements (a missing element fails
--   only when its diagonal entry is evaluated).
--
-- > diagonalList n 0 [1..] =
-- >                   n
-- >   1 ( 1 0 ... 0   0 )
-- >   2 ( 0 2 ... 0   0 )
-- >     (     ...       )
-- >     ( 0 0 ... n-1 0 )
-- >   n ( 0 0 ... 0   n )
--
diagonalList :: Int -> a -> [a] -> Matrix a
diagonalList n e xs = matrix n n entry
  where
    entry (i, j)
      | i == j    = xs !! (i - 1)
      | otherwise = e

-- | Create a matrix from a non-empty list of rows. The first row fixes the
--   column count; every later row is truncated to that length.
--   /Each list must have at least as many elements as the first list/,
--   otherwise 'fromList' reports a size mismatch.
--   Examples:
--
-- > fromLists [ [1,2,3]      ( 1 2 3 )
-- >           , [4,5,6]      ( 4 5 6 )
-- >           , [7,8,9] ] =  ( 7 8 9 )
--
-- > fromLists [ [1,2,3  ]     ( 1 2 3 )
-- >           , [4,5,6,7]     ( 4 5 6 )
-- >           , [8,9,0  ] ] = ( 8 9 0 )
--
fromLists :: [[a]] -> Matrix a
{-# INLINE fromLists #-}
fromLists [] = error "fromLists: empty list."
fromLists (xs:xss) = fromList n m (xs ++ concatMap (take m) xss)
  where
    m = length xs
    n = length xss + 1

-- | /O(1)/. View a vector as a matrix with a single row. The vector is
--   shared, not copied.
rowVector :: V.Vector a -> Matrix a
rowVector v =
  let len = V.length v
  in  M 1 len 0 0 len v

-- | /O(1)/. View a vector as a matrix with a single column. The vector is
--   shared, not copied.
colVector :: V.Vector a -> Matrix a
colVector v = M len 1 0 0 1 v
  where
    len = V.length v

-- | /O(rows*cols)/. Permutation matrix that swaps two rows of the identity.
--
-- > permMatrix n i j =
-- >               i     j       n
-- >   1 ( 1 0 ... 0 ... 0 ... 0 0 )
-- >   2 ( 0 1 ... 0 ... 0 ... 0 0 )
-- >     (     ...   ...   ...     )
-- >   i ( 0 0 ... 0 ... 1 ... 0 0 )
-- >     (     ...   ...   ...     )
-- >   j ( 0 0 ... 1 ... 0 ... 0 0 )
-- >     (     ...   ...   ...     )
-- >     ( 0 0 ... 0 ... 0 ... 1 0 )
-- >   n ( 0 0 ... 0 ... 0 ... 0 1 )
--
-- When @i == j@ it reduces to 'identity' @n@.
--
permMatrix :: Num a
           => Int -- ^ Size of the matrix.
           -> Int -- ^ Permuted row 1.
           -> Int -- ^ Permuted row 2.
           -> Matrix a -- ^ Permutation matrix.
permMatrix n r1 r2
  | r1 == r2  = identity n
  | otherwise = matrix n n entry
  where
    -- Rows r1 and r2 carry their 1 in each other's column; all other rows
    -- keep the identity pattern.
    entry (i, j)
      | i == r1   = indicator (j == r2)
      | i == r2   = indicator (j == r1)
      | otherwise = indicator (i == j)
    indicator b = if b then 1 else 0

-------------------------------------------------------
-------------------------------------------------------
---- ACCESSING

-- | /O(1)/. Get an element of a matrix. Indices range from /(1,1)/ to /(n,m)/.
--   It returns an 'error' if the requested element is outside of range;
--   'safeGet' is the 'Maybe'-returning variant.
getElem :: Int      -- ^ Row
        -> Int      -- ^ Column
        -> Matrix a -- ^ Matrix
        -> a
{-# INLINE getElem #-}
getElem i j m =
  case safeGet i j m of
    Just x  -> x
    Nothing -> error $
      "getElem: Trying to get the "
        ++ show (i, j)
        ++ " element from a "
        ++ sizeStr (nrows m) (ncols m)
        ++ " matrix."

-- | /O(1)/. Unsafe variant of 'getElem', without bounds checking.
--   Neither the matrix bounds nor the storage vector bounds are checked.
unsafeGet :: Int      -- ^ Row
          -> Int      -- ^ Column
          -> Matrix a -- ^ Matrix
          -> a
{-# INLINE unsafeGet #-}
unsafeGet i j (M _ _ ro co w v) = V.unsafeIndex v idx
  where
    -- Shift the view coordinates by the view's offsets, then map the
    -- storage position to a flat index using the storage row width.
    idx = encode w (i + ro, j + co)

-- | Short alias for 'getElem', taking the position as a /(row, column)/ pair.
(!) :: Matrix a -> (Int,Int) -> a
{-# INLINE (!) #-}
(!) m (i, j) = getElem i j m

-- | Internal alias for 'unsafeGet', taking the position as a
--   /(row, column)/ pair. No bounds checking is performed.
(!.) :: Matrix a -> (Int,Int) -> a
{-# INLINE (!.) #-}
(!.) m (i, j) = unsafeGet i j m

-- | Variant of 'getElem' that returns Maybe instead of an error.
--   Returns 'Nothing' when the row is outside /[1, nrows]/ or the column is
--   outside /[1, ncols]/.
safeGet :: Int -> Int -> Matrix a -> Maybe a
safeGet i j a@(M n m _ _ _ _)
  | rowOk && colOk = Just (unsafeGet i j a)
  | otherwise      = Nothing
  where
    rowOk = i >= 1 && i <= n
    colOk = j >= 1 && j <= m

-- | Variant of 'setElem' that returns Maybe instead of an error.
--   Returns 'Nothing' when the position lies outside the matrix; otherwise
--   the position is known to be valid and 'unsafeSet' performs the update.
safeSet :: a -> (Int, Int) -> Matrix a -> Maybe (Matrix a)
safeSet x p@(i, j) a@(M n m _ _ _ _)
  | inside    = Just (unsafeSet x p a)
  | otherwise = Nothing
  where
    inside = 1 <= i && i <= n && 1 <= j && j <= m

-- | /O(1)/. Get a row of a matrix as a vector.
--   Rows are indexed from 1. It returns an 'error' if the requested row is
--   outside of range; 'safeGetRow' is the 'Maybe'-returning variant.
getRow :: Int -> Matrix a -> V.Vector a
{-# INLINE getRow #-}
getRow i (M n m ro co w v)
  -- The range check must use the view's own row count: for a submatrix
  -- view the storage vector is shared with the parent, so an out-of-range
  -- index could otherwise slice a row that does not belong to this matrix.
  | i < 1 || i > n = error $
      "getRow: Trying to get row "
        ++ show i
        ++ " from a "
        ++ sizeStr n m
        ++ " matrix."
  -- Start of the row in storage (row offset ro, column offset co, storage
  -- row width w), taking exactly m elements.
  | otherwise = V.slice (w * (i - 1 + ro) + co) m v

-- | Variant of 'getRow' that returns 'Nothing' instead of an error when the
--   requested row is outside /[1, nrows]/.
safeGetRow :: Int -> Matrix a -> Maybe (V.Vector a)
safeGetRow r m
  | 1 <= r && r <= nrows m = Just (getRow r m)
  | otherwise              = Nothing

-- | /O(rows)/. Get a column of a matrix as a vector.
--   Columns are indexed from 1. It returns an 'error' if the requested column
--   is outside of range; 'safeGetCol' is the 'Maybe'-returning variant.
getCol :: Int -> Matrix a -> V.Vector a
{-# INLINE getCol #-}
getCol j (M n m ro co w v)
  -- Without this check an out-of-range column wraps around into the next
  -- storage row (or outside a submatrix view) and silently yields wrong data.
  | j < 1 || j > m = error $
      "getCol: Trying to get column "
        ++ show j
        ++ " from a "
        ++ sizeStr n m
        ++ " matrix."
  -- 'V.generate' counts from 0, matrix rows from 1, hence the i+1.
  | otherwise = V.generate n $ \i -> v V.! encode w (i + 1 + ro, j + co)

-- | Variant of 'getCol' that returns 'Nothing' instead of an error when the
--   requested column is outside /[1, ncols]/.
safeGetCol :: Int -> Matrix a -> Maybe (V.Vector a)
safeGetCol c m
  | 1 <= c && c <= ncols m = Just (getCol c m)
  | otherwise              = Nothing

-- | /O(min rows cols)/. Diagonal of a /not necessarily square/ matrix.
--   The result has one element per diagonal position /(k,k)/, for
--   /k/ from 1 up to the smaller of the two dimensions.
getDiag :: Matrix a -> V.Vector a
getDiag m = V.generate (min (nrows m) (ncols m)) diagAt
  where
    -- 'V.generate' indexes from 0, matrix positions from 1.
    diagAt k = m ! (k + 1, k + 1)

-- | /O(rows*cols)/. Transform a 'Matrix' to a 'V.Vector' of size /rows*cols/.
--  This is equivalent to get all the rows of the matrix using 'getRow'
--  and then append them, but far more efficient.
--  The matrix is passed through 'forceMatrix' before its storage vector is
--  taken, because a view (e.g. from 'submatrix') can share a larger
--  storage vector with its parent.
getMatrixAsVector :: Matrix a -> V.Vector a
getMatrixAsVector m = mvect (forceMatrix m)

-------------------------------------------------------
-------------------------------------------------------
---- MANIPULATING MATRICES

-- | Write one element into a mutable matrix storage vector.
--   The write goes through 'MV.write', which checks the flat index against
--   the storage vector (not against the bounds of a matrix view).
msetElem :: PrimMonad m
         => a -- ^ New element
         -> Int -- ^ Number of columns of the matrix (storage row width)
         -> Int -- ^ Row offset
         -> Int -- ^ Column offset
         -> (Int,Int) -- ^ Position to set the new element
         -> MV.MVector (PrimState m) a -- ^ Mutable vector
         -> m ()
{-# INLINE msetElem #-}
msetElem x w ro co (i, j) v = MV.write v flatIx x
  where
    flatIx = encode w (i + ro, j + co)

-- | Unchecked counterpart of 'msetElem': writes with 'MV.unsafeWrite', so
--   no index check is performed at all.
unsafeMset :: PrimMonad m
         => a -- ^ New element
         -> Int -- ^ Number of columns of the matrix (storage row width)
         -> Int -- ^ Row offset
         -> Int -- ^ Column offset
         -> (Int,Int) -- ^ Position to set the new element
         -> MV.MVector (PrimState m) a -- ^ Mutable vector
         -> m ()
{-# INLINE unsafeMset #-}
unsafeMset x w ro co (i, j) v = MV.unsafeWrite v flatIx x
  where
    flatIx = encode w (i + ro, j + co)

-- | Replace the value of a cell in a matrix.
--   Indices range from /(1,1)/ to /(n,m)/; an 'error' is raised when the
--   position is outside of range. See 'safeSet' for a 'Maybe'-returning
--   variant and 'unsafeSet' for an unchecked one.
setElem :: a -- ^ New value.
        -> (Int,Int) -- ^ Position to replace.
        -> Matrix a -- ^ Original matrix.
        -> Matrix a -- ^ Matrix with the given position replaced with the given value.
{-# INLINE setElem #-}
setElem x p@(i, j) (M n m ro co w v)
  -- The storage-level check in 'MV.write' is not enough: e.g. column m+1
  -- encodes to the first cell of the next storage row, which would be
  -- overwritten silently.
  | i < 1 || i > n || j < 1 || j > m = error $
      "setElem: Trying to set the "
        ++ show p
        ++ " element of a "
        ++ sizeStr n m
        ++ " matrix."
  | otherwise = M n m ro co w $ V.modify (msetElem x w ro co p) v

-- | Unsafe variant of 'setElem', without bounds checking.
--   The storage vector is copied by 'V.modify' and the new value is written
--   with 'unsafeMset'.
unsafeSet :: a -- ^ New value.
        -> (Int,Int) -- ^ Position to replace.
        -> Matrix a -- ^ Original matrix.
        -> Matrix a -- ^ Matrix with the given position replaced with the given value.
{-# INLINE unsafeSet #-}
unsafeSet x p (M n m ro co w v) = M n m ro co w updated
  where
    updated = V.modify (unsafeMset x w ro co p) v

-- | /O(rows*cols)/. The transpose of a matrix.
--   Example:
--
-- >           ( 1 2 3 )   ( 1 4 7 )
-- >           ( 4 5 6 )   ( 2 5 8 )
-- > transpose ( 7 8 9 ) = ( 3 6 9 )
transpose :: Matrix a -> Matrix a
transpose m = matrix (ncols m) (nrows m) flipped
  where
    -- Entry (i,j) of the result is entry (j,i) of the input.
    flipped (i, j) = m ! (j, i)

-- | /O(rows*rows*rows*rows) = O(cols*cols*cols*cols)/. The inverse of a square matrix.
--   Uses naive Gaussian elimination formula: the matrix is adjoined with the
--   identity, reduced with 'rref', and the right half is returned.
--
--   Returns 'Left' with a message if the matrix is not square or is not
--   invertible.
inverse :: (Fractional a, Eq a) => Matrix a -> Either String (Matrix a)
inverse m
    | ncols m /= nrows m
        = Left
            $ "Inverting non-square matrix with dimensions "
                -- Same (rows, cols) order as elsewhere in the module; no
                -- 'show', which would wrap the text in quotes.
                ++ sizeStr (nrows m) (ncols m)
    | otherwise =
        let n = nrows m
            adjoinedWId = m <|> identity n
            -- In reduced row echelon form pivots move strictly to the right,
            -- so entry (n,n) of the left block is 1 exactly when every row
            -- has its pivot on the diagonal, i.e. the left block is the
            -- identity.
            checkInvertible reduced
                | unsafeGet n n reduced == 1 = Right reduced
                | otherwise = Left "Attempt to invert a non-invertible matrix"
        in  submatrix 1 n (n + 1) (2 * n)
              <$> (rref adjoinedWId >>= checkInvertible)


-- | Converts a matrix to reduced row echelon form, thus
--   solving a linear system of equations. This requires that (cols >= rows):
--   if cols < rows, then there are fewer variables than equations and the
--   problem cannot be solved consistently, so a 'Left' is returned. If
--   rows = cols, the result is the identity exactly when the matrix is
--   invertible (this case is allowed for robustness).
--   This implementation is taken from rosettacode
--   https://rosettacode.org/wiki/Reduced_row_echelon_form#Haskell
rref :: (Fractional a, Eq a) => Matrix a -> Either String (Matrix a)
rref m
    | cols < rows
        -- (rows, cols) order and no 'show', matching 'getElem'.
        = Left $ "Invalid dimensions " ++ sizeStr rows cols
    | otherwise = Right . fromLists $ go (toLists m) 0 [0 .. rows - 1]
  where
    rows = nrows m
    cols = ncols m

    -- go mat lead rs: for each remaining (0-based) row index in rs, find the
    -- next pivot at or right of column lead, move it to that row, scale it
    -- to 1 and eliminate its column from every other row. Stops early when
    -- no non-zero pivot candidate is left.
    go :: (Eq b, Fractional b) => [[b]] -> Int -> [Int] -> [[b]]
    go mat _ [] = mat
    go mat lead (r : rs) =
      case find nonZero candidates of
        Nothing -> mat
        Just (lead', i) ->
          let pivotRow = mat !! i
              newRow   = map (/ (pivotRow !! lead')) pivotRow
              -- Row r takes the normalised pivot row; row i takes old row r.
              swapped  = replace r newRow (replace i (mat !! r) mat)
              eliminate k row
                | k == r    = row
                | otherwise = zipWith (\p x -> x - p * (row !! lead')) newRow row
          in  go (zipWith eliminate [0 ..] swapped) (lead' + 1) rs
      where
        nonZero (col, row) = mat !! row !! col /= 0
        -- Column-major search order, so pivot columns strictly increase.
        candidates = [ (col, row) | col <- [lead .. cols - 1], row <- [r .. rows - 1] ]

    -- Replaces the element at the given index.
    replace :: Int -> b -> [b] -> [b]
    replace n e t = before ++ e : after
      where (before, _ : after) = splitAt n t


-- | Extend a matrix to a given size adding a default element.
--   If the matrix already has the required size, nothing happens.
--   The matrix is /never/ reduced in size.
--   Example:
--
-- >                            ( 1 2 3 0 0 )
-- >                ( 1 2 3 )   ( 4 5 6 0 0 )
-- >                ( 4 5 6 )   ( 7 8 9 0 0 )
-- > extendTo 0 4 5 ( 7 8 9 ) = ( 0 0 0 0 0 )
--
-- The definition of 'extendTo' is based on 'setSize':
--
-- > extendTo e n m a = setSize e (max n $ nrows a) (max m $ ncols a) a
--
extendTo :: a   -- ^ Element to add when extending.
         -> Int -- ^ Minimal number of rows.
         -> Int -- ^ Minimal number of columns.
         -> Matrix a -> Matrix a
extendTo e n m a = setSize e rows' cols' a
  where
    rows' = max n (nrows a)
    cols' = max m (ncols a)

-- | Set the size of a matrix to given parameters. Use a default element
--   for undefined entries if the matrix has been extended.
--   Positions that exist in the original matrix keep their value; any
--   position beyond the original rows or columns gets the default element.
setSize :: a   -- ^ Default element.
        -> Int -- ^ Number of rows.
        -> Int -- ^ Number of columns.
        -> Matrix a
        -> Matrix a
{-# INLINE setSize #-}
setSize e n m a@(M n0 m0 _ _ _ _) = matrix n m pick
  where
    pick (i, j)
      | i <= n0 && j <= m0 = unsafeGet i j a
      | otherwise          = e

-------------------------------------------------------
-------------------------------------------------------
---- WORKING WITH BLOCKS

-- | /O(1)/. Extract a submatrix given row and column limits.
--   The result shares storage with the original matrix; only the size and
--   offsets change. An 'error' is raised if a limit is out of range or an
--   ending limit is smaller than its starting limit.
--   Example:
--
-- >                   ( 1 2 3 )
-- >                   ( 4 5 6 )   ( 2 3 )
-- > submatrix 1 2 2 3 ( 7 8 9 ) = ( 5 6 )
submatrix :: Int    -- ^ Starting row
          -> Int -- ^ Ending row
          -> Int    -- ^ Starting column
          -> Int -- ^ Ending column
          -> Matrix a
          -> Matrix a
{-# INLINE submatrix #-}
submatrix r1 r2 c1 c2 (M n m ro co w v)
  | r1 < 1  || r1 > n = error badStartRow
  | c1 < 1  || c1 > m = error badStartCol
  | r2 < r1 || r2 > n = error badEndRow
  | c2 < c1 || c2 > m = error badEndCol
  | otherwise = M (r2 - r1 + 1) (c2 - c1 + 1) (ro + r1 - 1) (co + c1 - 1) w v
  where
    badStartRow =
      "submatrix: starting row (" ++ show r1 ++ ") is out of range. Matrix has "
        ++ show n ++ " rows."
    badStartCol =
      "submatrix: starting column (" ++ show c1 ++ ") is out of range. Matrix has "
        ++ show m ++ " columns."
    badEndRow =
      "submatrix: ending row (" ++ show r2 ++ ") is out of range. Matrix has "
        ++ show n ++ " rows, and starting row is " ++ show r1 ++ "."
    badEndCol =
      "submatrix: ending column (" ++ show c2 ++ ") is out of range. Matrix has "
        ++ show m ++ " columns, and starting column is " ++ show c1 ++ "."

-- | /O(rows*cols)/. Remove a row and a column from a matrix.
--   The whole storage vector is filtered, dropping every stored element that
--   lies in the removed storage row or storage column.
--   Example:
--
-- >                 ( 1 2 3 )
-- >                 ( 4 5 6 )   ( 1 3 )
-- > minorMatrix 2 2 ( 7 8 9 ) = ( 7 9 )
minorMatrix :: Int -- ^ Row @r@ to remove.
            -> Int -- ^ Column @c@ to remove.
            -> Matrix a -- ^ Original matrix.
            -> Matrix a -- ^ Matrix with row @r@ and column @c@ removed.
minorMatrix r0 c0 (M n m ro co w v) =
  M (n - 1) (m - 1) ro co (w - 1) (V.ifilter keep v)
  where
    -- Storage coordinates of the row and column to drop.
    r = r0 + ro
    c = c0 + co
    keep k _ = let (i, j) = decode w k in not (i == r || j == c)

-- | /O(1)/. Make a block-partition of a matrix using a given element as reference.
--   The element will stay in the bottom-right corner of the top-left corner matrix.
--
-- >                 (             )   (      |      )
-- >                 (             )   ( ...  | ...  )
-- >                 (    x        )   (    x |      )
-- > splitBlocks i j (             ) = (-------------) , where x = a_{i,j}
-- >                 (             )   (      |      )
-- >                 (             )   ( ...  | ...  )
-- >                 (             )   (      |      )
--
--   Note that some blocks can end up empty: when @i@ is the last row the
--   bottom blocks have no rows, and when @j@ is the last column the right
--   blocks have no columns. We use the following notation for these blocks:
--
-- > ( TL | TR )
-- > (---------)
-- > ( BL | BR )
--
--   Where T = Top, B = Bottom, L = Left, R = Right.
--
--   An 'error' is raised if /(i,j)/ is not a position of the matrix.
--   All four blocks share storage with the original matrix.
--
splitBlocks :: Int      -- ^ Row of the splitting element.
            -> Int      -- ^ Column of the splitting element.
            -> Matrix a -- ^ Matrix to split.
            -> (Matrix a,Matrix a
               ,Matrix a,Matrix a) -- ^ (TL,TR,BL,BR)
{-# INLINE[1] splitBlocks #-}
splitBlocks i j (M n m ro co w v)
  | i < 1 || i > n = error $
      "splitBlocks: row (" ++ show i ++ ") is out of range. Matrix has "
        ++ show n ++ " rows."
  | j < 1 || j > m = error $
      "splitBlocks: column (" ++ show j ++ ") is out of range. Matrix has "
        ++ show m ++ " columns."
  -- The blocks are built directly rather than through 'submatrix', which
  -- rejects empty ranges and therefore failed for i == n or j == m.
  -- Each block is the same storage viewed with adjusted size and offsets.
  | otherwise =
      ( M i       j       ro       co       w v
      , M i       (m - j) ro       (co + j) w v
      , M (n - i) j       (ro + i) co       w v
      , M (n - i) (m - j) (ro + i) (co + j) w v
      )

-- | Join blocks of the form detailed in 'splitBlocks'. Precisely:
--
-- > joinBlocks (tl,tr,bl,br) =
-- >   (tl <|> tr)
-- >       <->
-- >   (bl <|> br)
--
--   The result has @nrows tl + nrows bl@ rows and @ncols tl + ncols tr@
--   columns. Only those four counts are read to size the output; the blocks
--   are not checked against each other.
joinBlocks :: (Matrix a,Matrix a,Matrix a,Matrix a) -> Matrix a
{-# INLINE[1] joinBlocks #-}
joinBlocks (tl,tr,bl,br) =
  M rowsOut colsOut 0 0 colsOut $ V.create $ do
    out <- MV.new (rowsOut * colsOut)
    -- Write a value at the 1-based position (row, col) of the output buffer.
    let put = MV.write out . encode colsOut
    -- Top band: rows of tl followed by rows of tr.
    numLoop 1 topRows $ \i -> do
      numLoop 1 leftCols  $ \j -> put (i, j)            (tl ! (i, j))
      numLoop 1 rightCols $ \j -> put (i, leftCols + j) (tr ! (i, j))
    -- Bottom band: shifted down by the height of the top band.
    numLoop 1 bottomRows $ \i -> do
      let r = topRows + i
      numLoop 1 leftCols  $ \j -> put (r, j)            (bl ! (i, j))
      numLoop 1 rightCols $ \j -> put (r, leftCols + j) (br ! (i, j))
    return out
  where
    topRows    = nrows tl
    bottomRows = nrows bl
    leftCols   = ncols tl
    rightCols  = ncols tr
    rowsOut    = topRows + bottomRows
    colsOut    = leftCols + rightCols

{-# RULES
"matrix/splitAndJoin"
   forall i j m. joinBlocks (splitBlocks i j m) = m
  #-}

-- | Horizontally join two matrices. Visually:
--
-- > ( A ) <|> ( B ) = ( A | B )
--
-- Where both matrices /A/ and /B/ have the same number of rows.
-- /This condition is not checked/.
(<|>) :: Matrix a -> Matrix a -> Matrix a
{-# INLINE (<|>) #-}
left <|> right = matrix (nrows left) (leftCols + ncols right) pick
  where
    leftCols = ncols left
    -- Columns up to leftCols come from the left operand; the rest are
    -- shifted back into the right operand's column range.
    pick (i, j)
      | j <= leftCols = left  ! (i, j)
      | otherwise     = right ! (i, j - leftCols)

-- | Vertically join two matrices. Visually:
--
-- >                   ( A )
-- > ( A ) <-> ( B ) = ( - )
-- >                   ( B )
--
-- Where both matrices /A/ and /B/ have the same number of columns.
-- /This condition is not checked/.
(<->) :: Matrix a -> Matrix a -> Matrix a
{-# INLINE (<->) #-}
top <-> bottom = matrix (topRows + nrows bottom) (ncols top) pick
  where
    topRows = nrows top
    -- Rows up to topRows come from the upper operand; the rest are
    -- shifted back into the lower operand's row range.
    pick (i, j)
      | i <= topRows = top    ! (i, j)
      | otherwise    = bottom ! (i - topRows, j)

-------------------------------------------------------
-------------------------------------------------------
---- MATRIX OPERATIONS

-- | Perform an operation element-wise.
--   The result takes its dimensions from the first matrix, so the second
--   matrix must have at least as many rows and columns as the first one.
--   Any extra rows or columns of the second matrix are ignored; if it is
--   smaller, a run-time error is raised.
--   Use 'elementwiseUnsafe' when you are certain that such an error
--   cannot occur.
elementwise :: (a -> b -> c) -> (Matrix a -> Matrix b -> Matrix c)
elementwise f ma mb = matrix (nrows ma) (ncols ma) combine
  where
    combine pos = f (ma ! pos) (mb ! pos)

-- | Unsafe version of 'elementwise', but faster.
--   Both operands are read with 'unsafeGet' over the first matrix's
--   dimensions, so the second matrix's size is not checked.
elementwiseUnsafe :: (a -> b -> c) -> (Matrix a -> Matrix b -> Matrix c)
{-# INLINE elementwiseUnsafe #-}
elementwiseUnsafe f ma mb = matrix (nrows ma) (ncols ma) combine
  where
    combine (i, j) = f (unsafeGet i j ma) (unsafeGet i j mb)

infixl 6 +., -.

-- | Internal unsafe addition: entry-wise '+' through 'elementwiseUnsafe',
--   which takes the dimensions from the first operand and does not check
--   the second.
(+.) :: Num a => Matrix a -> Matrix a -> Matrix a
{-# INLINE (+.) #-}
x +. y = elementwiseUnsafe (+) x y

-- | Internal unsafe subtraction: entry-wise '-' through 'elementwiseUnsafe',
--   which takes the dimensions from the first operand and does not check
--   the second.
(-.) :: Num a => Matrix a -> Matrix a -> Matrix a
{-# INLINE (-.) #-}
x -. y = elementwiseUnsafe (-) x y

-------------------------------------------------------
-------------------------------------------------------
---- MATRIX MULTIPLICATION

{- $mult

Four methods are provided for matrix multiplication.

* 'multStd':
     Matrix multiplication following directly the definition.
     This is the best choice when you know for sure that your
     matrices are small.

* 'multStd2':
     Matrix multiplication following directly the definition.
     Unlike 'multStd', it first extracts the rows of the first matrix and
     the columns of the second, then computes each entry as their dot product.
     According to our benchmarks with this version, 'multStd2' is
     around 3 times faster than 'multStd'.

* 'multStrassen':
     Matrix multiplication following Strassen's algorithm.
     Its complexity grows more slowly, but extra work is spent
     partitioning the matrices. Also, it only works on square
     matrices of order @2^n@, so matrices that do not meet this
     condition are zero-padded until they do.
     Therefore, its use is not recommended.

* 'multStrassenMixed':
     This function mixes the previous methods.
     It generally provides better performance. Method @(@'*'@)@
     of the 'Num' class uses this function because it gives the best
     average performance. However, if you know for sure that your matrices are
     small (size less than 500x500), you should use 'multStd' or 'multStd2' instead,
     since 'multStrassenMixed' is going to switch to those functions anyway.

We keep researching how to get better performance for matrix multiplication.
If you want to be on the safe side, use ('*').

-}

-- | Standard matrix multiplication by definition.
--   Calls 'error' when the number of columns of the first matrix differs
--   from the number of rows of the second; otherwise delegates to 'multStd_'.
multStd :: Num a => Matrix a -> Matrix a -> Matrix a
{-# INLINE multStd #-}
multStd a1@(M n m _ _ _ _) a2@(M n' m' _ _ _ _)
  | m == n'   = multStd_ a1 a2
  | otherwise = error $ "Multiplication of " ++ sizeStr n m ++ " and "
                     ++ sizeStr n' m' ++ " matrices."

-- | Standard matrix multiplication by definition, computed differently from
--   'multStd': the rows of the first matrix and the columns of the second
--   are extracted once, and every entry is their dot product (see 'multStd__').
--   Calls 'error' when the number of columns of the first matrix differs
--   from the number of rows of the second.
multStd2 :: Num a => Matrix a -> Matrix a -> Matrix a
{-# INLINE multStd2 #-}
multStd2 a1@(M n m _ _ _ _) a2@(M n' m' _ _ _ _)
  | m == n'   = multStd__ a1 a2
  | otherwise = error $ "Multiplication of " ++ sizeStr n m ++ " and "
                     ++ sizeStr n' m' ++ " matrices."

-- | Standard matrix multiplication by definition, without checking if sizes match.
--   Products of two 1x1, two 2x2 or two 3x3 matrices are unrolled by hand;
--   every other shape uses the general sum-of-products formula.
multStd_ :: Num a => Matrix a -> Matrix a -> Matrix a
{-# INLINE multStd_ #-}
multStd_ a@(M 1 1 _ _ _ _) b@(M 1 1 _ _ _ _) = M 1 1 0 0 1 (V.singleton prod11)
  where
    prod11 = (a ! (1,1)) * (b ! (1,1))
multStd_ a@(M 2 2 _ _ _ _) b@(M 2 2 _ _ _ _) =
  M 2 2 0 0 2 $ V.fromList [ dot2 r c | r <- [ra1, ra2], c <- [cb1, cb2] ]
  where
    -- Rows of a and columns of b, each entry fetched once and shared.
    ra1 = (a !. (1,1), a !. (1,2))
    ra2 = (a !. (2,1), a !. (2,2))
    cb1 = (b !. (1,1), b !. (2,1))
    cb2 = (b !. (1,2), b !. (2,2))
    dot2 (x1, x2) (y1, y2) = x1*y1 + x2*y2
multStd_ a@(M 3 3 _ _ _ _) b@(M 3 3 _ _ _ _) =
  M 3 3 0 0 3 $ V.fromList [ dot3 r c | r <- [ra1, ra2, ra3], c <- [cb1, cb2, cb3] ]
  where
    -- Rows of a and columns of b, each entry fetched once and shared.
    ra1 = (a !. (1,1), a !. (1,2), a !. (1,3))
    ra2 = (a !. (2,1), a !. (2,2), a !. (2,3))
    ra3 = (a !. (3,1), a !. (3,2), a !. (3,3))
    cb1 = (b !. (1,1), b !. (2,1), b !. (3,1))
    cb2 = (b !. (1,2), b !. (2,2), b !. (3,2))
    cb3 = (b !. (1,3), b !. (2,3), b !. (3,3))
    dot3 (x1, x2, x3) (y1, y2, y3) = x1*y1 + x2*y2 + x3*y3
multStd_ a@(M n m _ _ _ _) b@(M _ m' _ _ _ _) = matrix n m' entry
  where
    entry (i, j) = sum [ (a !. (i,k)) * (b !. (k,j)) | k <- [1 .. m] ]

-- | Matrix multiplication by definition, without checking if sizes match.
--   The rows of the first matrix and the columns of the second are
--   materialised once, and each entry is their 'dotProduct'.
multStd__ :: Num a => Matrix a -> Matrix a -> Matrix a
{-# INLINE multStd__ #-}
multStd__ a b = matrix nr nc entry
  where
    nr = nrows a
    nc = ncols b
    -- 0-indexed tables: rowsA holds row (k+1) of a, colsB holds column (k+1) of b.
    rowsA = V.generate nr (\k -> getRow (k + 1) a)
    colsB = V.generate nc (\k -> getCol (k + 1) b)
    entry (i, j) = dotProduct (V.unsafeIndex rowsA (i - 1)) (V.unsafeIndex colsB (j - 1))

-- | Dot product of two vectors: the sum of @v1[i] * v2[i]@ over the indices
--   of the first vector. The second vector is read with 'V.unsafeIndex', so
--   it must be at least as long as the first (this is not checked).
--   Empty input yields @0@.
dotProduct :: Num a => V.Vector a -> V.Vector a -> a
{-# INLINE dotProduct #-}
dotProduct v1 v2
  -- numLoopFold only constrains its counter with Eq, so it cannot be relied
  -- on to skip the empty range 0 .. -1. Without this guard, multiplying an
  -- r x 0 matrix by a 0 x c matrix would index an empty vector.
  | V.null v1 = 0
  | otherwise = numLoopFold 0 (V.length v1 - 1) 0 $
      \r i -> V.unsafeIndex v1 i * V.unsafeIndex v2 i + r

{-
dotProduct v1 v2 = go (V.length v1 - 1) 0
  where
    go (-1) a = a
    go i a = go (i-1) $ (V.unsafeIndex v1 i) * (V.unsafeIndex v2 i) + a
-}

-- | Return the first list element that satisfies the predicate.
--   The list is consumed lazily, so an infinite list is fine as long as some
--   element matches. Calls 'error' if the list runs out without a match.
first :: (a -> Bool) -> [a] -> a
first f = fromMaybe (error "first: no element match the condition.") . find f

-- | Strassen's algorithm over square matrices of order @2^n@.
strassen :: Num a => Matrix a -> Matrix a -> Matrix a
-- Trivial 1x1 multiplication.
strassen a@(M 1 1 _ _ _ _) b@(M 1 1 _ _ _ _) = M 1 1 0 0 1 (V.singleton prod11)
  where
    prod11 = (a ! (1,1)) * (b ! (1,1))
-- General case guesses that the input matrices are square matrices
-- whose order is a power of two.
strassen a b = joinBlocks (c11, c12, c21, c22)
  where
    -- Each operand is split into four quadrants of half the order.
    half = nrows a `div` 2
    (a11, a12, a21, a22) = splitBlocks half half a
    (b11, b12, b21, b22) = splitBlocks half half b
    -- The seven recursive Strassen products.
    m1 = strassen (a11 + a22) (b11 + b22)
    m2 = strassen (a21 + a22) b11
    m3 = strassen a11         (b12 - b22)
    m4 = strassen a22         (b21 - b11)
    m5 = strassen (a11 + a12) b22
    m6 = strassen (a21 - a11) (b11 + b12)
    m7 = strassen (a12 - a22) (b21 + b22)
    -- Quadrants of the product.
    c11 = m1 + m4 - m5 + m7
    c12 = m3 + m5
    c21 = m2 + m4
    c22 = m1 - m2 + m3 + m6

-- | Strassen's matrix multiplication.
--   Both operands are zero-padded to a square whose order is the smallest
--   power of two not below any of their dimensions, multiplied with
--   'strassen', and the product is cropped back to @nrows a1@ by @ncols a2@.
--   Calls 'error' when the inner dimensions do not match.
multStrassen :: Num a => Matrix a -> Matrix a -> Matrix a
multStrassen a1@(M n m _ _ _ _) a2@(M n' m' _ _ _ _)
  | m /= n'   = error $ "Multiplication of " ++ sizeStr n m ++ " and "
                     ++ sizeStr n' m' ++ " matrices."
  | otherwise = submatrix 1 n 1 m' $ strassen padded1 padded2
  where
    side    = first (>= maximum [n, m, n', m']) [ 2 ^ k | k <- [0 :: Int ..] ]
    padded1 = setSize 0 side side a1
    padded2 = setSize 0 side side a2

-- | Row-count threshold used by the mixed Strassen multiplication.
--   Operands with fewer rows than this are handed to @multStd__@ instead
--   of being split further.
strmixFactor :: Int
strmixFactor = 300

-- | Strassen's mixed algorithm.
--
--   Expects two square matrices of the same size @r@; the callers pad them
--   first. The three cases are:
--
--   * Below 'strmixFactor' rows it defers to @multStd__@.
--
--   * An odd size is padded with one zero row/column, and the result is
--     trimmed back to @r x r@.
--
--   * Otherwise both operands are split into four @r/2@ blocks. The seven
--     Strassen products are computed recursively, and the four result
--     quadrants are written into a freshly allocated vector.
strassenMixed :: Num a => Matrix a -> Matrix a -> Matrix a
{-# SPECIALIZE strassenMixed :: Matrix Double -> Matrix Double -> Matrix Double #-}
{-# SPECIALIZE strassenMixed :: Matrix Int -> Matrix Int -> Matrix Int #-}
{-# SPECIALIZE strassenMixed :: Matrix Rational -> Matrix Rational -> Matrix Rational #-}
strassenMixed a b
  | r < strmixFactor = multStd__ a b
  | odd r =
      let evenSize = r + 1
          grow mat = setSize 0 evenSize evenSize mat
      in  submatrix 1 r 1 r $ strassenMixed (grow a) (grow b)
  | otherwise = M r r 0 0 r $ V.create $ do
      out <- MV.unsafeNew (r*r)
      -- 'quadrant dr dc f' fills the h x h block of the result whose
      -- top-left corner is (dr+1, dc+1). The value at quadrant-local
      -- position (i,j) is 'f i j'.
      let quadrant dr dc f =
            forM_ [1 .. h] $ \i ->
              forM_ [1 .. h] $ \j ->
                MV.write out (encode r (i + dr, j + dc)) (f i j)
          at mat i j = unsafeGet i j mat
      -- c11 = p1 + p4 - p5 + p7
      quadrant 0 0 $ \i j -> at p1 i j + at p4 i j - at p5 i j + at p7 i j
      -- c12 = p3 + p5
      quadrant 0 h $ \i j -> at p3 i j + at p5 i j
      -- c21 = p2 + p4
      quadrant h 0 $ \i j -> at p2 i j + at p4 i j
      -- c22 = p1 - p2 + p3 + p6
      quadrant h h $ \i j -> at p1 i j - at p2 i j + at p3 i j + at p6 i j
      return out
  where
    r = nrows a
    -- Each subproblem has half the size (r is even on this path).
    h = quot r 2
    -- Split both operands into four h x h blocks.
    (a11, a12, a21, a22) = splitBlocks h h a
    (b11, b12, b21, b22) = splitBlocks h h b
    -- The seven Strassen products.
    p1 = strassenMixed (a11 +. a22) (b11 +. b22)
    p2 = strassenMixed (a21 +. a22) b11
    p3 = strassenMixed a11 (b12 -. b22)
    p4 = strassenMixed a22 (b21 -. b11)
    p5 = strassenMixed (a11 +. a12) b22
    p6 = strassenMixed (a21 -. a11) (b11 +. b12)
    p7 = strassenMixed (a12 -. a22) (b21 +. b22)

-- | Mixed Strassen's matrix multiplication.
--
--   When the first operand has fewer than 'strmixFactor' rows, the product
--   is computed directly by @multStd__@. Otherwise both operands are
--   zero-padded to a common square size with an even side (at least every
--   dimension involved) and multiplied with 'strassenMixed'. The top-left
--   @n x m'@ block of that product is returned.
--   Calls 'error' when the column count of the first matrix differs from
--   the row count of the second.
multStrassenMixed :: Num a => Matrix a -> Matrix a -> Matrix a
{-# INLINE multStrassenMixed #-}
multStrassenMixed a1@(M n m _ _ _ _) a2@(M n' m' _ _ _ _)
  | m /= n'          = error $ "Multiplication of " ++ sizeStr n m ++ " and "
                            ++ sizeStr n' m' ++ " matrices."
  | n < strmixFactor = multStd__ a1 a2
  | otherwise        = submatrix 1 n 1 m' $ strassenMixed (pad a1) (pad a2)
  where
    biggest = maximum [n, m, n', m']
    -- Round the largest dimension up to the next even number.
    side    = biggest + biggest `rem` 2
    pad mat = setSize 0 side side mat

-------------------------------------------------------
-------------------------------------------------------
---- NUMERICAL INSTANCE

-- Matrices form a 'Num' instance:
--
-- * Addition and subtraction are element-wise (via 'elementwise').
--
-- * Multiplication is matrix multiplication (via 'multStrassenMixed').
--
-- * 'negate', 'abs' and 'signum' are applied entry by entry.
--
-- * 'fromInteger' builds a 1x1 matrix.
instance Num a => Num (Matrix a) where
  -- Addition of matrices.
  {-# SPECIALIZE (+) :: Matrix Double -> Matrix Double -> Matrix Double #-}
  {-# SPECIALIZE (+) :: Matrix Int -> Matrix Int -> Matrix Int #-}
  {-# SPECIALIZE (+) :: Matrix Rational -> Matrix Rational -> Matrix Rational #-}
  (+) = elementwise (+)

  -- Subtraction of matrices.
  {-# SPECIALIZE (-) :: Matrix Double -> Matrix Double -> Matrix Double #-}
  {-# SPECIALIZE (-) :: Matrix Int -> Matrix Int -> Matrix Int #-}
  {-# SPECIALIZE (-) :: Matrix Rational -> Matrix Rational -> Matrix Rational #-}
  (-) = elementwise (-)

  -- Multiplication of matrices.
  {-# INLINE (*) #-}
  (*) = multStrassenMixed

  -- Entry-wise unary operations.
  negate = fmap negate
  abs    = fmap abs
  signum = fmap signum

  -- A literal becomes a 1x1 matrix holding that value.
  fromInteger k = M 1 1 0 0 1 (V.singleton (fromInteger k))

-------------------------------------------------------
-------------------------------------------------------
---- TRANSFORMATIONS

-- | Scale a matrix by a given factor: every entry is multiplied
--   (on the left) by the factor.
--   Example:
--
-- >               ( 1 2 3 )   (  2  4  6 )
-- >               ( 4 5 6 )   (  8 10 12 )
-- > scaleMatrix 2 ( 7 8 9 ) = ( 14 16 18 )
scaleMatrix :: Num a => a -> Matrix a -> Matrix a
scaleMatrix factor = fmap (factor *)

-- | Scale a row by a given factor. @scaleRow k r m@ multiplies every
--   entry of row @r@ of @m@ by @k@; the other rows are left untouched.
--   Example:
--
-- >              ( 1 2 3 )   (  1  2  3 )
-- >              ( 4 5 6 )   (  8 10 12 )
-- > scaleRow 2 2 ( 7 8 9 ) = (  7  8  9 )
scaleRow :: Num a => a -> Int -> Matrix a -> Matrix a
scaleRow factor = mapRow (\_ x -> factor * x)

-- | Add to one row a scalar multiple of another row.
--   @combineRows r1 l r2 m@ replaces row @r1@ of @m@ with
--   @row r1 + l * row r2@.
--   Example:
--
-- >                   ( 1 2 3 )   (  1  2  3 )
-- >                   ( 4 5 6 )   (  6  9 12 )
-- > combineRows 2 2 1 ( 7 8 9 ) = (  7  8  9 )
combineRows :: Num a => Int -> a -> Int -> Matrix a -> Matrix a
combineRows r1 l r2 m = mapRow addScaled r1 m
  where
    -- Entry j of row r1 gains l times entry j of row r2.
    addScaled j x = x + l * getElem r2 j m

-- | Switch two rows of a matrix.
--   Example:
--
-- >                ( 1 2 3 )   ( 4 5 6 )
-- >                ( 4 5 6 )   ( 1 2 3 )
-- > switchRows 1 2 ( 7 8 9 ) = ( 7 8 9 )
switchRows :: Int -- ^ Row 1.
           -> Int -- ^ Row 2.
           -> Matrix a -- ^ Original matrix.
           -> Matrix a -- ^ Matrix with rows 1 and 2 switched.
switchRows r1 r2 (M n m ro co w vs) =
  M n m ro co w $ V.modify (\mv -> numLoop 1 m $ \j ->
    MV.swap mv (idx r1 j) (idx r2 j)) vs
  where
    -- Position of entry (i,j) of this view inside the backing vector,
    -- whose rows are w entries wide.
    idx i j = encode w (i + ro, j + co)

-- | Switch two columns of a matrix.
--   Example:
--
-- >                ( 1 2 3 )   ( 2 1 3 )
-- >                ( 4 5 6 )   ( 5 4 6 )
-- > switchCols 1 2 ( 7 8 9 ) = ( 8 7 9 )
switchCols :: Int -- ^ Col 1.
           -> Int -- ^ Col 2.
           -> Matrix a -- ^ Original matrix.
           -> Matrix a -- ^ Matrix with cols 1 and 2 switched.
switchCols c1 c2 (M n m ro co w vs) =
  M n m ro co w $ V.modify (\mv -> numLoop 1 n $ \i ->
    MV.swap mv (idx i c1) (idx i c2)) vs
  where
    -- Position of entry (i,j) of this view inside the backing vector.
    -- The row stride is the backing width w, exactly as in 'switchRows'.
    -- It is not the view's column count m: the two differ when the view
    -- shares a wider vector, and using m swapped the wrong entries there.
    idx i j = encode w (i + ro, j + co)

-------------------------------------------------------
-------------------------------------------------------
---- DECOMPOSITIONS

-- LU DECOMPOSITION

-- | Matrix LU decomposition with /partial pivoting/.
--   The result for a matrix /M/ is given in the format /(U,L,P,d)/ where:
--
--   * /U/ is an upper triangular matrix.
--
--   * /L/ is a /unit/ lower triangular matrix.
--
--   * /P/ is a permutation matrix.
--
--   * /d/ is the determinant of /P/.
--
--   * /PM = LU/.
--
--   These properties are only guaranteed when the input matrix is invertible.
--   An additional property holds thanks to the strategy followed for pivoting:
--
--   * /|L_(i,j)| <= 1/, for all /i,j/.
--
--   This follows from the maximal property of the selected pivots, which also
--   leads to a better numerical stability of the algorithm.
--
--   Example:
--
-- >          ( 1 2 0 )     ( 2 0  2 )   (   1 0 0 )   ( 0 0 1 )
-- >          ( 0 2 1 )     ( 0 2 -1 )   ( 1/2 1 0 )   ( 1 0 0 )
-- > luDecomp ( 2 0 2 ) = ( ( 0 0  2 ) , (   0 1 1 ) , ( 0 1 0 ) , 1 )
--
--   'Nothing' is returned if no LU decomposition exists.
luDecomp :: (Ord a, Fractional a) => Matrix a -> Maybe (Matrix a,Matrix a,Matrix a,a)
luDecomp a = recLUDecomp a ident ident 1 1 steps
  where
    -- L and P both start as the identity of the row dimension.
    ident = identity (nrows a)
    -- One elimination step per row/column of the square part.
    steps = min (nrows a) (ncols a)

-- | One step (and, recursively, all remaining steps) of LU decomposition
--   with partial pivoting, as driven by 'luDecomp'.
--
--   At step @k@ the row with the largest absolute value in column @k@
--   (among rows @k@ .. last) is swapped into row @k@ of U, P and the
--   already-built part of L. The entries below the pivot are then
--   eliminated. Returns 'Nothing' if the chosen pivot is zero.
recLUDecomp ::  (Ord a, Fractional a)
            =>  Matrix a -- ^ U
            ->  Matrix a -- ^ L
            ->  Matrix a -- ^ P
            ->  a        -- ^ d
            ->  Int      -- ^ Current row
            ->  Int      -- ^ Number of elimination steps (min of rows and columns)
            -> Maybe (Matrix a,Matrix a,Matrix a,a)
recLUDecomp u l p d k n =
    if k > n then Just (u,l,p,d)
    else if ukk == 0 then Nothing
                     else recLUDecomp u'' l'' p' d' (k+1) n
 where
  -- Pivot strategy: maximum absolute value in column k, searched over every
  -- row from k down to the last row of U. The search must not stop at n
  -- (= min rows cols). In a tall matrix a row below n may hold the only
  -- non-zero candidate, or the largest one. Skipping it would either reject
  -- a decomposable matrix or yield multipliers with |L_(j,k)| > 1.
  -- Elimination ('go') already runs over all rows of U.
  i  = maximumBy (comparing (\x -> abs (u ! (x,k)))) [ k .. nrows u ]
  -- Switching to place pivot in current row.
  u' = switchRows k i u
  -- Only the first k-1 columns of L are filled so far; swap just those.
  l' = let lw  = vcols l
           en  = encode lw
           lro = rowOffset l
           lco = colOffset l
       in  if i == k
              then l
              else M (nrows l) (ncols l) lro lco lw $
                     V.modify (\mv -> forM_ [1 .. k-1] $ \j ->
                                 MV.swap mv (en (i+lro, j+lco))
                                            (en (k+lro, j+lco))
                              ) $ mvect l
  p' = switchRows k i p
  -- Permutation determinant: every actual row swap flips its sign.
  d' = if i == k then d else negate d
  -- Cancel elements below the pivot.
  (u'',l'') = go u' l' (k+1)
  ukk = u' ! (k,k)
  -- Eliminate row j (and all rows below it) using the pivot row k,
  -- recording each multiplier in L.
  go u_ l_ j =
    if j > nrows u_
    then (u_,l_)
    else let x = (u_ ! (j,k)) / ukk
         in  go (combineRows j (-x) k u_) (setElem x (j,k) l_) (j+1)

-- | Unsafe version of 'luDecomp'. Whenever 'luDecomp' returns 'Nothing'
--   (e.g. for a singular matrix), it fails with 'error' instead.
luDecompUnsafe :: (Ord a, Fractional a) => Matrix a -> (Matrix a, Matrix a, Matrix a, a)
luDecompUnsafe m =
  fromMaybe (error "luDecompUnsafe of singular matrix.") (luDecomp m)

-- | Matrix LU decomposition with /complete pivoting/.
--   The result for a matrix /M/ is given in the format /(U,L,P,Q,d,e)/ where:
--
--   * /U/ is an upper triangular matrix.
--
--   * /L/ is a /unit/ lower triangular matrix.
--
--   * /P,Q/ are permutation matrices.
--
--   * /d,e/ are the determinants of /P/ and /Q/ respectively.
--
--   * /PMQ = LU/.
--
--   These properties are only guaranteed when the input matrix is invertible.
--   An additional property holds thanks to the strategy followed for pivoting:
--
--   * /|L_(i,j)| <= 1/, for all /i,j/.
--
--   This follows from the maximal property of the selected pivots, which also
--   leads to a better numerical stability of the algorithm.
--
--   Example:
--
-- >           ( 1 0 )     ( 2 1 )   (   1    0 0 )   ( 0 0 1 )
-- >           ( 0 2 )     ( 0 2 )   (   0    1 0 )   ( 0 1 0 )   ( 1 0 )
-- > luDecomp' ( 2 1 ) = ( ( 0 0 ) , ( 1/2 -1/4 1 ) , ( 1 0 0 ) , ( 0 1 ) , -1 , 1 )
--
--   'Nothing' is returned if no LU decomposition exists.
luDecomp' :: (Ord a, Fractional a) => Matrix a -> Maybe (Matrix a,Matrix a,Matrix a,Matrix a,a,a)
luDecomp' a = recLUDecomp' a rowIdent rowIdent colIdent 1 1 1 steps
  where
    -- L and P start as the identity of the row dimension.
    rowIdent = identity (nrows a)
    -- Q starts as the identity of the column dimension.
    colIdent = identity (ncols a)
    steps    = min (nrows a) (ncols a)

-- | Unsafe version of 'luDecomp''. Whenever 'luDecomp'' returns 'Nothing'
--   (e.g. for a singular matrix), it fails with 'error' instead.
luDecompUnsafe' :: (Ord a, Fractional a) => Matrix a -> (Matrix a, Matrix a, Matrix a, Matrix a, a, a)
luDecompUnsafe' m =
  fromMaybe (error "luDecompUnsafe' of singular matrix.") (luDecomp' m)

-- Worker for 'luDecomp''. Step @k@ takes as pivot the entry of largest
-- absolute value in the trailing submatrix (rows and columns @k@ onwards),
-- moves it to position @(k,k)@ with one row swap and one column swap,
-- eliminates the entries below it, and then recurses on step @k+1@.
-- Row swaps are recorded in P and their sign in d; column swaps in Q and e;
-- the elimination multipliers are stored in L.
recLUDecomp' ::  (Ord a, Fractional a)
            =>  Matrix a -- ^ U
            ->  Matrix a -- ^ L
            ->  Matrix a -- ^ P
            ->  Matrix a -- ^ Q
            ->  a        -- ^ d
            ->  a        -- ^ e
            ->  Int      -- ^ Current row
            ->  Int      -- ^ Number of pivot steps (min of rows and columns)
            ->  Maybe (Matrix a,Matrix a,Matrix a,Matrix a,a,a)
recLUDecomp' u l p q d e k n
    -- Finished once every pivot step is done, or as soon as the chosen pivot
    -- is zero: it has the largest absolute value in the trailing submatrix,
    -- so that submatrix is entirely zero and needs no further elimination.
    -- The pivot is checked directly, so no elimination pass is performed
    -- only to be discarded. A decomposition is therefore always found.
  | k > n || ukk == 0 = Just (u,l,p,q,d,e)
  | otherwise         = recLUDecomp' u'' l'' p' q' d' e' (k+1) n
 where
  -- Pivot: position of the entry of largest absolute value among the rows
  -- and columns from k onwards.
  (i, j) = maximumBy (comparing (\(i0, j0) -> abs $ u ! (i0, j0)))
             [ (i0, j0) | i0 <- [k .. nrows u], j0 <- [k .. ncols u] ]
  -- Move the pivot to (k,k); P records the row swap and Q the column swap.
  u' = switchCols k j $ switchRows k i u
  p' = switchRows k i p
  q' = switchCols k j q
  -- Exchange the multipliers already computed for rows k and i in L;
  -- swapping the same columns puts the unit diagonal entries back in place.
  l' = switchCols k i $ switchRows k i l
  -- Each effective swap flips the sign of the corresponding permutation
  -- determinant.
  d' = if i == k then d else negate d
  e' = if j == k then e else negate e
  -- The pivot value.
  ukk = u' ! (k,k)
  -- Cancel the entries below the pivot, storing each multiplier in L.
  -- Because the pivot has maximal absolute value, every multiplier x
  -- satisfies |x| <= 1.
  (u'',l'') = go u' l' (k+1)
  go u_ l_ h =
    if h > nrows u_
    then (u_,l_)
    else let x = (u_ ! (h,k)) / ukk
         in  go (combineRows h (-x) k u_) (setElem x (h,k) l_) (h+1)

-- CHOLESKY DECOMPOSITION

-- | Simple Cholesky decomposition of a symmetric, positive definite matrix.
--   The result for a matrix /M/ is a lower triangular matrix /L/ such that:
--
--   * /M = LL^T/.
--
--   Example (entries rounded to two decimals):
--
-- >            (  2 -1  0 )   (  1.41  0     0    )
-- >            ( -1  2 -1 )   ( -0.71  1.22  0    )
-- > cholDecomp (  0 -1  2 ) = (  0.00 -0.82  1.15 )
--
--   The first column of /L/ is computed directly; the function then recurses
--   on the remaining lower-right block, from which the contribution of that
--   first column has been subtracted. Only the lower part of the input is
--   read (the upper-right block is used only for its dimensions).
cholDecomp :: (Floating a) => Matrix a -> Matrix a
cholDecomp a
  | nrows a == 1 && ncols a == 1 = fmap sqrt a
  | otherwise = joinBlocks (topLeft, topRight, bottomLeft, bottomRight)
  where
    -- Peel off the first row and the first column.
    (a11, a12, a21, a22) = splitBlocks 1 1 a
    pivot       = sqrt (a11 ! (1,1))
    topLeft     = fromList 1 1 [pivot]
    -- L is lower triangular, so the block to the right of the pivot is zero.
    topRight    = zero (nrows a12) (ncols a12)
    bottomLeft  = scaleMatrix (1/pivot) a21
    -- Schur complement of the top-left entry (given a symmetric input).
    schur       = a22 - multStd bottomLeft (transpose bottomLeft)
    bottomRight = cholDecomp schur

-------------------------------------------------------
-------------------------------------------------------
---- PROPERTIES

-- Rewrite rules for 'trace'.
-- NOTE(review): these rewrites are identities only under exact arithmetic;
-- for floating-point element types the rewritten expression may round
-- differently. They presumably also assume both operands of '+' have the same
-- dimensions -- confirm against the 'Num' instance of 'Matrix'.
{-# RULES
"matrix/traceOfSum"
    forall m1 m2. trace (m1 + m2) = trace m1 + trace m2

"matrix/traceOfScale"
    forall s m. trace (scaleMatrix s m) = s * trace m
  #-}

-- | Sum of the elements on the main diagonal. See also 'getDiag'.
--   Example:
--
-- >       ( 1 2 3 )
-- >       ( 4 5 6 )
-- > trace ( 7 8 9 ) = 15
trace :: Num a => Matrix a -> a
trace m = V.sum (getDiag m)

-- | Product of the elements on the main diagonal. See also 'getDiag'.
--   Example:
--
-- >          ( 1 2 3 )
-- >          ( 4 5 6 )
-- > diagProd ( 7 8 9 ) = 45
diagProd :: Num a => Matrix a -> a
diagProd m = V.product (getDiag m)

-- DETERMINANT

-- NOTE: no rewrite rules are provided for determinants of products.
-- The identity det (a*b) = det a * det b holds only when both factors are
-- square, and a rule cannot check that. For example, with a 1x2 matrix a and
-- a 2x1 matrix b (taking '*' as matrix multiplication), a*b is 1x1 and
-- 'detLaplace' (a*b) is well defined. But 'detLaplace' a expands down to an
-- empty minor and fails at runtime. Rules such as "matrix/detLaplaceProduct"
-- and "matrix/detLUProduct" would therefore change the meaning of programs
-- compiled with optimisation (and would also alter floating-point rounding),
-- so they are intentionally omitted.

-- | Matrix determinant using Laplace (cofactor) expansion along the first
--   column.
--   If the elements of the 'Matrix' are instances of 'Ord' and 'Fractional',
--   consider using 'detLU' instead, for much better performance.
--   Function 'detLaplace' is /extremely/ slow: it recursively expands every
--   first-column minor, so the work grows factorially with the matrix size.
detLaplace :: Num a => Matrix a -> a
detLaplace m@(M 1 1 _ _ _ _) = m ! (1,1)
detLaplace m = foldl1' (+) (map cofactorTerm [1 .. nrows m])
  where
    -- Signed contribution of entry (i,1) times the determinant of its minor.
    cofactorTerm i = (-1)^(i-1) * m ! (i,1) * detLaplace (minorMatrix i 1 m)

-- | Matrix determinant using LU decomposition.
--   It works even when the input matrix is singular: whenever 'luDecomp'
--   returns 'Nothing', the determinant is reported as 0. Otherwise it is the
--   permutation sign times the product of the diagonal of /U/.
detLU :: (Ord a, Fractional a) => Matrix a -> a
detLU m = maybe 0 (\(u, _, _, d) -> d * diagProd u) (luDecomp m)