{- | Implementation of matrices.

Matrices are the backbone of any machine learning library,
since most operations are implemented as combinations of matrix operations
(matrix multiplication, elementwise operations).

The 'Mat' datatype provides an interface for all of those operations.
-}


-- 'TypeFamilies' is needed to define the 'Indexable', 'ElementwiseScalarOps', 'SingletonOps' and 'MatOps' instances.

{-# LANGUAGE TypeFamilies #-}


module Synapse.Tensors.Mat
    ( --  * 'Mat' datatype and simple getters.


      Mat (nRows, nCols)

    , nElements
    , size

    , isTransposed
    , isSubmatrix

      -- * Utility


    , force
    , toLists

      -- * Constructors


    , empty
    , singleton
    , fromList
    , fromLists
    , generate
    , replicate

      -- * Vec operations


    , rowVec
    , colVec
    , fromVec

    , indexRow
    , indexCol
    , safeIndexRow
    , safeIndexCol

    , diagonal
    , flatten

     -- * Combining


    , map
    , mapRow
    , mapCol
    , for
    , imap
    , elementwise

      -- * Operations with matrices


    , setSize
    , extend
    , shrink

    , swapRows
    , swapCols
    , transpose

      -- * Submatrices


    , minor
    , submatrix
    , split
    , join
    , (<|>)
    , (<->)

      -- * Mathematics


    , zeroes
    , ones
    , identity

    , adamarMul
    , matMul

    , det
    , rref
    , inverse
    , orthogonalized
    ) where


import Synapse.Tensors (DType, Indexable(..), (!), ElementwiseScalarOps(..), SingletonOps(..), MatOps(..))

import Synapse.Tensors.Vec (Vec(Vec))
import qualified Synapse.Tensors.Vec as SV

import Prelude hiding (map, replicate, zip)
import Data.Foldable (Foldable(..))
import Data.List (find)
import Data.Tuple (swap)

import qualified Data.Vector as V


{- | Mathematical matrix (collection of elements).

This implementation focuses on sharing parts of matrices and clever indexing
to reduce the overhead of several essential operations.
Those include splitting matrices into submatrices and transposing, whose asymptotic complexity becomes O(1).
However, there are a few downsides:
the first is that a heavily split matrix is hard to garbage collect and is not cache-friendly,
and the second is that mass traversal operations on such sparse matrices might not fuse and combine well.
The 'force' function and some functions that are forced by their nature address those issues,
but most of the time those problems are not significant enough, and you are better off using the convenient functions instead of workarounds.

Because views share 'storage', the vector of a submatrix may hold elements that lie outside the matrix itself;
an element at @(r, c)@ lives at @(rowOffset + r) * rowStride + (colOffset + c) * colStride@ in 'storage'.
-}
data Mat a = Mat
    { nRows     :: {-# UNPACK #-} !Int         -- ^ Number of rows.
    , nCols     :: {-# UNPACK #-} !Int         -- ^ Number of columns.
    , rowStride :: {-# UNPACK #-} !Int         -- ^ How far one step of the row index moves within 'storage'.
    , colStride :: {-# UNPACK #-} !Int         -- ^ How far one step of the column index moves within 'storage'.
    , rowOffset :: {-# UNPACK #-} !Int         -- ^ Row offset (the row of the underlying storage at which this matrix starts).
    , colOffset :: {-# UNPACK #-} !Int         -- ^ Column offset (the column of the underlying storage at which this matrix starts).
    , storage   ::                 V.Vector a  -- ^ Internal storage (elements are stored in a vector using row-major ordering).
    }

-- | Number of elements in a matrix, i.e. rows times columns.
nElements :: Mat a -> Int
nElements (Mat rows cols _ _ _ _ _) = rows * cols

-- | Size of a matrix as a @(rows, columns)@ pair.
size :: Mat a -> (Int, Int)
size mat = (rows, cols)
  where
    rows = nRows mat
    cols = nCols mat


-- | Returns whether the matrix is transposed: its row stride is 1 while its column stride is not.
-- A matrix built by 'singleton' has both strides equal to 1, so it is never reported as transposed.
isTransposed :: Mat a -> Bool
isTransposed mat = rowStride mat == 1 && colStride mat /= 1

-- | Returns whether the matrix is a submatrix of another matrix.
-- The check looks only at the offsets: any non-zero row or column offset counts as a submatrix.
-- NOTE(review): a view that starts at offset (0, 0) is not detected by this check.
isSubmatrix :: Mat a -> Bool
isSubmatrix mat = rowOffset mat /= 0 || colOffset mat /= 0


-- | Converts a two dimensional matrix index into a one dimensional index into 'storage'.
-- Each coordinate is first shifted by its offset and then scaled by its stride.
indexMatToVec :: Mat a -> (Int, Int) -> Int
indexMatToVec mat (r, c) = rowStride mat * (rowOffset mat + r) + colStride mat * (colOffset mat + c)

-- | Converts a one dimensional index into 'storage' back into a two dimensional matrix index.
--
-- For a transposed matrix, the index is divided by the column stride and the
-- @(quotient, remainder)@ pair is read as @(column, row)@. Otherwise the index is
-- divided by the row stride and the pair is read as @(row, column)@. In both cases
-- the offsets are subtracted afterwards.
--
-- NOTE(review): this is meant to undo 'indexMatToVec'. That only holds when the
-- stride not used as the divisor is 1; confirm for other stride layouts.
indexVecToMat :: Mat a -> Int -> (Int, Int)
indexVecToMat mat i
    | isTransposed mat = let (c', r') = i `quotRem` colStride mat in unshift r' c'
    | otherwise        = let (r', c') = i `quotRem` rowStride mat in unshift r' c'
  where
    unshift r' c' = (r' - rowOffset mat, c' - colOffset mat)



-- | Copies the matrix elements into fresh row-major storage with zero offsets.
-- This drops any extra memory that may be held if the given matrix is a submatrix.
force :: Mat a -> Mat a
force x = Mat rows cols cols 1 0 0 (V.fromList (concat (toLists x)))
  where
    rows = nRows x
    cols = nCols x

-- | Converts a matrix to a list of its rows, each row being a list of elements from left to right.
toLists :: Mat a -> [[a]]
toLists x = fmap row [0 .. nRows x - 1]
  where
    row r = fmap (\c -> unsafeIndex x (r, c)) [0 .. nCols x - 1]


-- Typeclasses


-- | Shows the size followed by the rows, e.g. @(2,2): [[1,2],[3,4]]@.
instance Show a => Show (Mat a) where
    show mat = concat ["(", show (size mat), "): ", show (toLists mat)]


-- The element type of @Mat a@ is @a@.
type instance DType (Mat a) = a


-- | Matrices are indexed by @(row, column)@ pairs, both counted from zero.
instance Indexable (Mat a) where
    type Index (Mat a) = (Int, Int)

    -- No bounds checking: the index is translated directly into a 'storage' position.
    unsafeIndex x (r, c) = V.unsafeIndex (storage x) (indexMatToVec x (r, c))

    -- Bounds-checked indexing that raises an error for an out-of-range index.
    x ! (r, c)
        | inBounds  = unsafeIndex x (r, c)
        | otherwise = error $ "Index " ++ show (r, c) ++ " is out of bounds for matrix with size " ++ show (size x)
      where
        inBounds = 0 <= r && r < nRows x && 0 <= c && c < nCols x

    -- Bounds-checked indexing that returns 'Nothing' for an out-of-range index.
    x !? (r, c)
        | inBounds  = Just (unsafeIndex x (r, c))
        | otherwise = Nothing
      where
        inBounds = 0 <= r && r < nRows x && 0 <= c && c < nCols x


-- | Elementwise arithmetic.
--
-- Binary operators combine their operands with 'elementwise'.
-- @(*)@ delegates to 'adamarMul'; matrix multiplication is the separate 'matMul'.
-- Integer literals become singleton matrices.
instance Num a => Num (Mat a) where
    (+) = elementwise (+)
    (-) = elementwise (-)
    -- Use the element type's own 'negate' rather than @(0 -)@. The two differ,
    -- for example, for IEEE floats: @0 - 0.0 == 0.0@ but @negate 0.0 == -0.0@.
    negate = fmap negate
    (*) = adamarMul
    abs = fmap abs
    signum = fmap signum
    fromInteger = singleton . fromInteger

-- | Elementwise division. Rational literals become singleton matrices.
instance Fractional a => Fractional (Mat a) where
    x / y = elementwise (/) x y
    recip m = (\e -> 1 / e) <$> m
    fromRational r = singleton (fromRational r)

-- | Elementwise floating-point functions.
-- 'pi' is a singleton matrix, and @(**)@ combines its operands with 'elementwise'.
instance Floating a => Floating (Mat a) where
    pi = singleton pi
    x ** y = elementwise (**) x y

    sqrt m = sqrt <$> m
    exp m = exp <$> m
    log m = log <$> m

    sin m = sin <$> m
    cos m = cos <$> m
    asin m = asin <$> m
    acos m = acos <$> m
    atan m = atan <$> m

    sinh m = sinh <$> m
    cosh m = cosh <$> m
    asinh m = asinh <$> m
    acosh m = acosh <$> m
    atanh m = atanh <$> m


-- | Operations between every element of a matrix and a single scalar.
-- The matrix element is always the left operand.
instance ElementwiseScalarOps (Mat a) where
    x +. n = (\e -> e + n) <$> x
    x -. n = (\e -> e - n) <$> x
    x *. n = (\e -> e * n) <$> x
    x /. n = (\e -> e / n) <$> x
    x **. n = (\e -> e ** n) <$> x

    -- Clamp every element from above by @n@ (the scalar is the first argument of 'min').
    elementsMin x n = min n <$> x
    -- Clamp every element from below by @n@ (the scalar is the first argument of 'max').
    elementsMax x n = max n <$> x

-- | Singleton (1x1) matrices and reductions that produce them.
instance SingletonOps (Mat a) where
    singleton e = Mat 1 1 1 1 0 0 (V.singleton e)

    isSingleton = (1 ==) . nElements

    unSingleton mat
        | isSingleton mat = unsafeIndex mat (0, 0)
        | otherwise       = error "Matrix is not a singleton"

    -- Broadcast the single element of @mat@ to the size of @reference@.
    extendSingleton mat reference = extend mat (unSingleton mat) (size reference)

    -- The @Mat{}@ pattern keeps the reductions strict in the matrix argument.
    elementsSum mat@Mat{} = singleton (sum (concat (toLists mat)))
    elementsProduct mat@Mat{} = singleton (product (concat (toLists mat)))

    mean x = elementsSum x /. fromIntegral (nElements x)

    -- Square root of the sum of squared elements.
    norm x = sqrt (elementsSum (x * x))


-- | Two matrices are equal when they have the same size and equal elements at every position.
-- The underlying storage layout (strides, offsets) is irrelevant.
instance Eq a => Eq (Mat a) where
    a == b = size a == size b && toLists a == toLists b


-- | Maps over the whole 'storage' vector and keeps the layout (sizes, strides, offsets) unchanged.
instance Functor Mat where
    fmap f mat = mat { storage = fmap f (storage mat) }
    x <$ mat = fmap (const x) mat

-- | 'pure' builds a 1x1 matrix. @(<*>)@ applies functions to arguments using 'elementwise'.
instance Applicative Mat where
    pure = singleton
    (<*>) = elementwise ($)

-- | Folds visit the elements of the matrix itself in row-major order, the same order as 'toLists'.
--
-- The raw 'storage' vector is deliberately not folded. For a submatrix it can
-- hold elements outside the view, and for a transposed matrix its layout is not
-- the row-major order of the matrix. Folding it would disagree with 'length'
-- (which is 'nElements') and with 'sequenceA' (which 'force's first).
instance Foldable Mat where
    toList mat = [unsafeIndex mat (r, c) | r <- [0 .. nRows mat - 1], c <- [0 .. nCols mat - 1]]

    foldr f z = foldr f z . toList
    foldl f z = foldl f z . toList

    foldr' f z = foldr' f z . toList
    foldl' f z = foldl' f z . toList

    foldr1 f = foldr1 f . toList
    foldl1 f = foldl1 f . toList

    null x = nElements x == 0

    length = nElements

-- | Traverses the elements of the matrix in row-major order.
-- The matrix is 'force'd first, so only its own elements are visited, and the
-- result is rebuilt as a contiguous row-major matrix without offsets.
instance Traversable Mat where
    sequenceA mat = rebuild <$> sequenceA (storage (force mat))
      where
        rebuild = Mat (nRows mat) (nCols mat) (nCols mat) 1 0 0


-- Constructors


-- | Creates an empty 0x0 'Mat' with no storage.
empty :: Mat a
empty = Mat
    { nRows     = 0
    , nCols     = 0
    , rowStride = 0
    , colStride = 0
    , rowOffset = 0
    , colOffset = 0
    , storage   = V.empty
    }

-- | Creates a 'Mat' of the given @(rows, columns)@ size from a list in row-major order.
-- Throws an error if the list length is not @rows * columns@.
fromList :: (Int, Int) -> [a] -> Mat a
fromList (rows, cols) xs
    | V.length vec == rows * cols = Mat rows cols cols 1 0 0 vec
    | otherwise                   = error "Given dimensions do not match with list length"
  where
    vec = V.fromList xs

-- | Creates a 'Mat' of the given @(rows, columns)@ size from a list of rows.
--
-- Throws an error unless there are exactly @rows@ rows and each of them has exactly
-- @columns@ elements. Previously, ragged rows were silently accepted whenever their
-- total length happened to match. A well-formed input is then passed to
-- @fromList (rows, cols) (concat xs)@.
fromLists :: (Int, Int) -> [[a]] -> Mat a
fromLists (rows, cols) xs
    | length xs /= rows || any ((/= cols) . length) xs = error "Given dimensions do not match with lists lengths"
    | otherwise                                         = fromList (rows, cols) (concat xs)

-- | Creates a 'Mat' of the given @(rows, columns)@ size by calling the function on every @(row, column)@ index.
generate :: (Int, Int) -> ((Int, Int) -> a) -> Mat a
generate (rows, cols) f = Mat rows cols cols 1 0 0 (V.generate (rows * cols) toElement)
  where
    -- Storage is row-major, so dividing a flat index by the column count yields (row, column).
    toElement i = f (i `quotRem` cols)

-- | Creates a 'Mat' of the given @(rows, columns)@ size with every element equal to the given value.
replicate :: (Int, Int) -> a -> Mat a
replicate (rows, cols) x = Mat rows cols cols 1 0 0 filled
  where
    filled = V.replicate (rows * cols) x


-- Vec operations


-- | Converts a 'Synapse.Tensors.Vec.Vec' into a one-row 'Mat' that shares the vector's storage.
rowVec :: Vec a -> Mat a
rowVec (Vec x) = Mat 1 n n 1 0 0 x
  where
    n = V.length x

-- | Converts a 'Synapse.Tensors.Vec.Vec' into a one-column 'Mat' that shares the vector's storage.
colVec :: Vec a -> Mat a
colVec (Vec x) = Mat n 1 1 1 0 0 x
  where
    n = V.length x

-- | Initializes 'Mat' from given 'Synapse.Tensors.Vec.Vec'.
--
-- The vector is used as storage with the same layout as 'generate'.
-- An error is thrown if its length is not @rows * cols@.

fromVec :: (Int, Int) -> Vec a -> Mat a
fromVec (rows, cols) (Vec x)
    | V.length x == rows * cols = Mat rows cols cols 1 0 0 x
    | otherwise                 = error "Given dimensions do not match with vector length"


-- | Extracts row from 'Mat'. If row is not present, an error is thrown.
--
-- For a non-transposed matrix the row is an O(1) slice of the underlying storage;
-- for a transposed one it is assembled element by element.

indexRow :: Mat a -> Int -> Vec a
indexRow mat@(Mat rows cols _ _ _ _ x) r
    | not (0 <= r && r < rows) = error "Given row is not present in the matrix"
    | otherwise                = Vec row
  where
    row
        | isTransposed mat = V.generate cols (\c -> unsafeIndex mat (r, c))
        | otherwise        = V.slice (indexMatToVec mat (r, 0)) cols x

-- | Extracts column from 'Mat'. If column is not present, an error is thrown.
--
-- For a transposed matrix the column is an O(1) slice of the underlying storage;
-- otherwise it is assembled element by element.

indexCol :: Mat a -> Int -> Vec a
indexCol mat@(Mat rows cols _ _ _ _ x) c
    | not (0 <= c && c < cols) = error "Given column is not present in the matrix"
    | otherwise                = Vec column
  where
    column
        | isTransposed mat = V.slice (indexMatToVec mat (0, c)) rows x
        | otherwise        = V.generate rows (\r -> unsafeIndex mat (r, c))

-- | Extracts row from 'Mat'.
--
-- Returns 'Nothing' when the row is not present, otherwise the same row 'indexRow' returns.

safeIndexRow :: Mat a -> Int -> Maybe (Vec a)
safeIndexRow mat@(Mat rows _ _ _ _ _ _) r
    | r < 0 || r >= rows = Nothing
    | otherwise          = Just (indexRow mat r)

-- | Extracts column from 'Mat'.
--
-- Returns 'Nothing' when the column is not present, otherwise the same column 'indexCol' returns.

safeIndexCol :: Mat a -> Int -> Maybe (Vec a)
safeIndexCol mat@(Mat _ cols _ _ _ _ _) c
    | c < 0 || c >= cols = Nothing
    | otherwise          = Just (indexCol mat c)


-- | Extracts diagonal from 'Mat'.
--
-- The result has @min nRows nCols@ elements, so non-square matrices are accepted.

diagonal :: Mat a -> Vec a
diagonal mat = Vec (V.generate len (\n -> unsafeIndex mat (n, n)))
  where
    len = min (nRows mat) (nCols mat)

-- | Flattens 'Mat' to a 'Vec'.
--
-- The matrix is passed through 'force' first and its storage is returned.

flatten :: Mat a -> Vec a
flatten mat = Vec (storage (force mat))


-- Combining


-- | Applies function to every element of 'Mat'. Same as 'fmap'.

map :: (a -> b) -> Mat a -> Mat b
map f mat = f <$> mat

-- | Applies function to a given row.
--
-- If the new 'Synapse.Tensors.Vec.Vec' is longer than the row, it is truncated;
-- if it is shorter, the elements it does not cover keep their original values.

mapRow :: Int -> (Vec a -> Vec a) -> Mat a -> Mat a
mapRow row f mat = imap replace mat
  where
    Vec newRow = f (indexRow mat row)

    -- Bounds-checked lookup: a result vector shorter than the row must never be
    -- read past its end (the previous 'unsafeIndex' did exactly that).
    replace (r, c) x
        | r == row  = maybe x id (newRow V.!? c)
        | otherwise = x

-- | Applies function to a given column.
--
-- If the new 'Synapse.Tensors.Vec.Vec' is longer than the column, it is truncated;
-- if it is shorter, the elements it does not cover keep their original values.

mapCol :: Int -> (Vec a -> Vec a) -> Mat a -> Mat a
mapCol col f mat = imap replace mat
  where
    Vec newCol = f (indexCol mat col)

    -- Bounds-checked lookup: a result vector shorter than the column must never be
    -- read past its end (the previous 'unsafeIndex' did exactly that).
    replace (r, c) x
        | c == col  = maybe x id (newCol V.!? r)
        | otherwise = x

-- | Flipped 'map': the matrix comes first, the function second.

for :: Mat a -> (a -> b) -> Mat b
for mat f = map f mat

-- | Applies function to every element and its position of 'Mat'.
--
-- The result keeps the same dimensions, strides and offsets as the input.
-- Each storage index is translated to its matrix position with 'indexVecToMat'.
-- NOTE(review): the whole storage vector is mapped. For a shared (sub)matrix this also
-- covers elements outside the view, which stay unevaluated thunks unless forced.

imap :: ((Int, Int) -> a -> b) -> Mat a -> Mat b
imap f mat@(Mat rows cols rk ck r0 c0 x) = Mat rows cols rk ck r0 c0 mapped
  where
    mapped = V.imap (\i -> f (indexVecToMat mat i)) x

-- | Zips two 'Mat's together using given function.
--
-- Both matrices must have the same size; otherwise an error is thrown.

elementwise :: (a -> b -> c) -> Mat a -> Mat b -> Mat c
elementwise f a b
    | size a == size b = generate (size a) combine
    | otherwise        = error "Two matrices have different sizes"
  where
    combine p = f (unsafeIndex a p) (unsafeIndex b p)


-- Operations with matrices


-- | Sets new size for a matrix relative to top left corner and uses given element for new entries if the matrix is extended.
--
-- Positions that exist in the original matrix keep their elements; all others get the filler.

setSize :: Mat a -> a -> (Int, Int) -> Mat a
setSize mat x newSize = generate newSize pick
  where
    (rows, cols) = (nRows mat, nCols mat)

    pick (r, c)
        | r < rows && c < cols = unsafeIndex mat (r, c)
        | otherwise            = x

-- | Extends matrix size relative to top left corner using given element for new entries. The matrix is never reduced in size.

extend :: Mat a -> a -> (Int, Int) -> Mat a
extend mat x (rows, cols) = setSize mat x newSize
  where
    newSize = (max (nRows mat) rows, max (nCols mat) cols)

-- | Shrinks matrix size relative to top left corner. The matrix is never extended in size.

shrink :: Mat a -> (Int, Int) -> Mat a
shrink mat (rows, cols) = setSize mat undefined newSize
  where
    -- The filler passed to 'setSize' is never evaluated: the new size never exceeds
    -- the current one, so every generated position already exists in 'mat'.
    newSize = (min (nRows mat) rows, min (nCols mat) cols)


-- | Swaps two rows. An error is thrown if either row index is out of bounds.

swapRows :: Mat a -> Int -> Int -> Mat a
swapRows mat@(Mat rows cols _ _ _ _ _) row1 row2
    | outOfBounds row1 || outOfBounds row2 = error "Given row indices are out of bounds"
    | otherwise = generate (rows, cols) $ \(r, c) -> unsafeIndex mat (source r, c)
  where
    outOfBounds i = i < 0 || i >= rows

    -- Row of the original matrix that ends up at row r of the result.
    source r
        | r == row1 = row2
        | r == row2 = row1
        | otherwise = r

-- | Swaps two columns. An error is thrown if either column index is out of bounds.

swapCols :: Mat a -> Int -> Int -> Mat a
swapCols mat@(Mat rows cols _ _ _ _ _) col1 col2
    | outOfBounds col1 || outOfBounds col2 = error "Given column indices are out of bounds"
    | otherwise = generate (rows, cols) $ \(r, c) -> unsafeIndex mat (r, source c)
  where
    outOfBounds i = i < 0 || i >= cols

    -- Column of the original matrix that ends up at column c of the result.
    source c
        | c == col1 = col2
        | c == col2 = col1
        | otherwise = c


-- Submatrices


-- | Extracts minor matrix, skipping given row and column.
--
-- An error is thrown if the given row or column is not present in the matrix.
-- Previously such indices produced a 'Mat' whose declared size did not match its storage.

minor :: Mat a -> (Int, Int) -> Mat a
minor mat@(Mat rows cols _ _ _ _ _) (r', c')
    | r' < 0 || r' >= rows || c' < 0 || c' >= cols = error "Given row and column are not present in the matrix"
    | otherwise = generate (rows - 1, cols - 1) $ \(r, c) -> unsafeIndex mat (skip r' r, skip c' c)
  where
    -- Positions at or past the removed index are taken one step further in the original.
    skip removed i = if i >= removed then i + 1 else i

-- | Extracts submatrix, that is located between given two positions.
--
-- The first position is the inclusive top-left corner and the second one is the exclusive
-- bottom-right corner, so the result has size @(r2 - r1, c2 - c1)@.
-- The result shares storage with the original matrix; only its size and offsets change.
-- An error is thrown if the limits are negative, inverted, or exceed the matrix size.

submatrix :: Mat a -> ((Int, Int), (Int, Int)) -> Mat a
submatrix (Mat rows cols rk ck r0 c0 x) ((r1, c1), (r2, c2))
    | r1 < 0 || c1 < 0 || r2 < 0 || c2 < 0 ||
      r2 < r1 || r2 > rows || c2 < c1 || c2 > cols = error "Given row and column limits are incorrect"
    | otherwise = Mat (r2 - r1) (c2 - c1) rk ck (r0 + r1) (c0 + c1) x

-- | Splits matrix into 4 parts, given position is a pivot, that corresponds to first element of bottom-right subpart.
--
-- Parts are returned as @(topLeft, topRight, bottomLeft, bottomRight)@, each built with 'submatrix'.

split :: Mat a -> (Int, Int) -> (Mat a, Mat a, Mat a, Mat a)
split x (r, c) = (topLeft, topRight, bottomLeft, bottomRight)
  where
    rows = nRows x
    cols = nCols x

    topLeft     = submatrix x ((0, 0), (r, c))
    topRight    = submatrix x ((0, c), (r, cols))
    bottomLeft  = submatrix x ((r, 0), (rows, c))
    bottomRight = submatrix x ((r, c), (rows, cols))

-- | Joins 4 blocks of matrices given as @(topLeft, topRight, bottomLeft, bottomRight)@.
--
-- Blocks in the same block-row must have equal row counts, and blocks in the same
-- block-column must have equal column counts; otherwise an error is thrown.

join :: (Mat a, Mat a, Mat a, Mat a) -> Mat a
join ( tl@(Mat rowsTL colsTL _ _ _ _ _), tr@(Mat rowsTR colsTR _ _ _ _ _)
     , bl@(Mat rowsBL colsBL _ _ _ _ _), br@(Mat rowsBR colsBR _ _ _ _ _) )
    | rowsTL /= rowsTR || rowsBL /= rowsBR ||
      colsTL /= colsBL || colsTR /= colsBR = error "Matrices dimensions do not match"
    | otherwise = generate (rowsTL + rowsBL, colsTL + colsTR) pick
  where
    -- Choose the block by comparing the position against the top-left block's size.
    pick (r, c) = case (r < rowsTL, c < colsTL) of
        (True,  True ) -> unsafeIndex tl (r, c)
        (True,  False) -> unsafeIndex tr (r, c - colsTL)
        (False, True ) -> unsafeIndex bl (r - rowsTL, c)
        (False, False) -> unsafeIndex br (r - rowsTL, c - colsTL)

infixl 9 <|>
-- | Joins two matrices horizontally. Both matrices must have the same number of rows.

(<|>) :: Mat a -> Mat a -> Mat a
a@(Mat rows1 cols1 _ _ _ _ _) <|> b@(Mat rows2 cols2 _ _ _ _ _)
    | rows1 /= rows2 = error "Given matrices must have the same number of rows"
    | otherwise      = generate (rows1, cols1 + cols2) pick
  where
    pick (r, c)
        | c < cols1 = unsafeIndex a (r, c)
        | otherwise = unsafeIndex b (r, c - cols1)

infixl 9 <->
-- | Joins two matrices vertically. Both matrices must have the same number of columns.

(<->) :: Mat a -> Mat a -> Mat a
a@(Mat rows1 cols1 _ _ _ _ _) <-> b@(Mat rows2 cols2 _ _ _ _ _)
    | cols1 /= cols2 = error "Given matrices must have the same number of columns"
    | otherwise      = generate (rows1 + rows2, cols1) pick
  where
    pick (r, c)
        | r < rows1 = unsafeIndex a (r, c)
        | otherwise = unsafeIndex b (r - rows1, c)


-- Functions that work on mathematical matrices (the type constraint requires numeric elements)


-- | Creates 'Mat' of given size that is filled with zeroes.

zeroes :: Num a => (Int, Int) -> Mat a
zeroes dims = replicate dims 0

-- | Creates 'Mat' of given size that is filled with ones.

ones :: Num a => (Int, Int) -> Mat a
ones dims = replicate dims 1

-- | Creates an @n x n@ identity matrix.

identity :: Num a => Int -> Mat a
identity n = generate (n, n) kronecker
  where
    kronecker (r, c)
        | r == c    = 1
        | otherwise = 0


-- | Adamar (Hadamard) multiplication: elementwise multiplication of two equally sized matrices.

adamarMul :: Num a => Mat a -> Mat a -> Mat a
adamarMul a b = elementwise (*) a b

instance Num a => MatOps (Mat a) where
    -- O(1) transposition: swaps sizes, strides and offsets while sharing the same storage.
    transpose (Mat rows cols rk ck r0 c0 x) = Mat cols rows ck rk c0 r0 x

    -- Adds the one-row matrix 'row' to every row of 'mat'.
    addMatRow mat row
        | nRows row /= 1         = error "Given row matrix is not a row"
        | nCols row /= nCols mat = error "Number of columns does not match"
        | otherwise              = imap (\(_, c) -> (+ unsafeIndex row (0, c))) mat

    -- Classic matrix product; element (r, c) is the dot product of row r of 'a' and column c of 'b'.
    matMul a@(Mat rows1 cols1 _ _ _ _ _) b@(Mat rows2 cols2 _ _ _ _ _)
        | cols1 /= rows2 = error "Matrices dimensions do not match"
        | otherwise      = generate (rows1, cols2) $ \(r, c) ->
            unSingleton $ (rowsA V.! r) `SV.dot` (colsB V.! c)
      where
        -- Rows of 'a' and columns of 'b' are extracted once and shared by every output cell.
        -- Previously they were rebuilt per cell.
        rowsA = V.generate rows1 (indexRow a)
        colsB = V.generate cols2 (indexCol b)

-- | Determinant of a square matrix. If matrix is empty, zero is returned.
--
-- Computed by cofactor (Laplace) expansion along the first row, which needs
-- only 'Num' operations but makes @n@ recursive calls on @(n-1) x (n-1)@
-- minors, i.e. factorial time in the matrix size.
-- Calls 'error' if the matrix is not square.
det :: Num a => Mat a -> a
det mat@(Mat rows cols _ _ _ _ _)
    | rows /= cols       = error "Matrix is not square"
    | nElements mat == 0 = 0
    | nElements mat == 1 = unsafeIndex mat (0, 0)
    | otherwise          = laplace oriented
  where
    -- A transposed view is flipped back with 'transpose' before expanding;
    -- since det(A^T) = det(A) the result is unaffected.
    oriented = if isTransposed mat then transpose mat else mat

    -- Cofactor expansion along row 0. Only matrices of size >= 2 reach here,
    -- and minors of size >= 3 stay >= 2, so the 2x2 case is the only base case.
    laplace :: Num b => Mat b -> b
    laplace m@(Mat 2 _ _ _ _ _ _) =
        unsafeIndex m (0, 0) * unsafeIndex m (1, 1) -
        unsafeIndex m (1, 0) * unsafeIndex m (0, 1)
    laplace m = sum (map cofactor [0 .. nRows m - 1])
      where
        firstRow = indexRow m 0
        -- Signed term for column i: (-1)^i * m[0][i] * det(minor m (0, i)).
        cofactor i = ((-1) ^ i) * (firstRow ! i) * laplace (minor m (0, i))

-- | Row reduced echelon form of matrix.
--
-- Gauss-Jordan elimination. Rows are filled top to bottom. For each target row
-- the pivot is the first non-zero entry found by scanning columns left to
-- right (starting after the previous pivot column) and, within a column, rows
-- from the target row downwards. The pivot row is swapped into place and
-- scaled so the pivot becomes 1, and the pivot column is cleared in every
-- other row. Elimination stops early once no non-zero candidate remains.
--
-- Pivots are selected by exact comparison with zero (no partial pivoting), so
-- floating point input may be numerically sensitive.
rref :: (Eq a, Fractional a) => Mat a -> Mat a
rref mat@(Mat rows cols _ _ _ _ _) = eliminate mat 0 [0 .. rows - 1]
  where
    -- @eliminate m lead targets@: @lead@ is the leftmost column that may
    -- still contain a pivot; @targets@ are the rows still to be filled.
    eliminate m _ [] = m
    eliminate m lead (target:targets) =
        case find ((0 /=) . unsafeIndex m) candidates of
            Nothing                   -> m
            Just (pivotRow, pivotCol) ->
                eliminate (pivotOn m target pivotRow pivotCol) (pivotCol + 1) targets
      where
        -- Column-major order: leftmost column first, then topmost row.
        candidates = [(i, j) | j <- [lead .. cols - 1], i <- [target .. rows - 1]]

    -- Move the pivot row to @target@, normalize it so the pivot equals 1,
    -- then subtract multiples of it to zero the pivot column elsewhere.
    pivotOn m target pivotRow pivotCol = imap clear placed
      where
        unitRow = SV.map (/ unsafeIndex m (pivotRow, pivotCol)) (indexRow m pivotRow)
        placed  = mapRow target (const unitRow) (swapRows m pivotRow target)
        clear (i, j)
            | i == target = id
            | otherwise   = subtract ((unitRow ! j) * unsafeIndex placed (i, pivotCol))

-- | Inverse of a square matrix. If given matrix is empty, empty matrix is returned.
--
-- Returns @Just empty@ for an empty matrix and 'Nothing' when the matrix is
-- singular. Calls 'error' if the matrix is not square.
inverse :: (Eq a, Fractional a) => Mat a -> Maybe (Mat a)
inverse mat@(Mat rows cols _ _ _ _ _)
    | rows /= cols       = error "Matrix is not square"
    | nElements mat == 0 = Just empty
    | singular           = Nothing
    | otherwise          = Just right
  where
    -- Reduce the augmented matrix [A | I]; for invertible A the left half
    -- becomes I and the right half becomes A^-1.
    (left, right, _, _) = split (rref (mat <|> identity rows)) (rows, cols)

    -- A zero on the diagonal of the reduced left half means A lacks a full
    -- set of pivots, i.e. it is not invertible.
    singular = V.any (== 0) (SV.unVec (diagonal left))

-- | Orthogonalizes matrix by rows using Gram-Schmidt algorithm.
--
-- Rows are processed top to bottom (modified Gram-Schmidt). From row @i@ the
-- projections onto the already processed rows @0 .. i - 1@ are subtracted one
-- at a time, and the remainder is scaled with 'SV.normalized'. For linearly
-- independent rows the result has orthonormal rows. A row that depends
-- linearly on earlier rows is reduced to (numerically) zero before
-- normalization, so its result depends on how 'SV.normalized' treats zero.
orthogonalized :: Floating a => Mat a -> Mat a
orthogonalized mat = foldl' orthonormalizeRow mat [0 .. nRows mat - 1]
  where
    -- Orthogonalize row @row@ against all previous rows, then give it unit length.
    orthonormalizeRow mat' row = mapRow row SV.normalized $ iterationGramSchmidt mat' row

    -- Subtract from row @row@ its projection onto every earlier row. Those rows
    -- are already unit length, so the projection onto row @r@ is
    -- @v_r * (v_r . v_row)@. The range deliberately stops at @row - 1@:
    -- projecting the row onto itself would scale it by @1 - |v_row|^2@,
    -- turning any unit-length row (e.g. a row of the identity) into zero.
    iterationGramSchmidt mat' row = foldl' removeProjection mat' [0 .. row - 1]
      where
        removeProjection mat'' r =
            let basis = indexRow mat'' r
                coeff = unSingleton (basis `SV.dot` indexRow mat'' row)
            in  mapRow row (subtract (basis *. coeff)) mat''