{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.BLAS.Matrix.RowMajor (
   Matrix,
   Square,
   Vector,
   height, width,
   Array2.singleRow, Array2.flattenRow,
   Array2.singleColumn, Array2.flattenColumn,
   identity,
   takeRow,
   takeColumn,
   fromRows,
   above,
   beside,
   takeTop, takeBottom,
   takeLeft, takeRight,
   tensorProduct,
   decomplex,
   recomplex,
   scaleRows,
   scaleColumns,
   multiplyVectorLeft,
   multiplyVectorRight,
   Transposable(..), nonTransposed, transposed,
   transposeTransposable,
   multiply,
   multiplyTransposable,
   kronecker,
   kroneckerTransposable,
   kroneckerLeftTransposable,
   ) where

import qualified Numeric.BLAS.Private as Private
import Numeric.BLAS.Matrix.Modifier (Conjugation(NonConjugated,Conjugated))
import Numeric.BLAS.Scalar (zero, one)
import Numeric.BLAS.Private (ShapeInt, shapeInt, ComplexShape, pointerSeq, fill)

import qualified Numeric.BLAS.FFI.Generic as Blas
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import Foreign.Marshal.Array (copyArray, advancePtr)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr, castForeignPtr)
import Foreign.Storable (Storable, poke)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Applicative (liftA2)

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable.Dim2 as Array2
import qualified Data.Array.Comfort.Shape.SubSize as SubSize
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Array.Comfort.Shape ((::+))

import Data.Foldable (forM_)
import Data.Complex (Complex)
import Data.Tuple.HT (swap)


{- $setup
>>> import Test.NumberModule.Type (Number_)
>>> import Test.NumberModule.Numeric.BLAS.Vector (forVector, genVector, number_)
>>> import Test.Slice (ShapeInt, shapeInt)
>>> import qualified Numeric.BLAS.Matrix.RowMajor as Matrix
>>> import qualified Numeric.BLAS.Vector as Vector
>>> import qualified Numeric.Netlib.Class as Class
>>> import Numeric.BLAS.Scalar (RealOf)
>>> import qualified Data.Array.Comfort.Storable as Array
>>> import qualified Data.Array.Comfort.Shape as Shape
>>> import qualified Test.QuickCheck as QC
>>>
>>> type Matrix = Matrix.Matrix (Shape.ZeroBased Int) (Shape.ZeroBased Int)
>>> type Real_ = RealOf Number_
>>>
>>> maxDim :: Int
>>> maxDim = 10
>>>
>>> forMatrix ::
>>>    (QC.Testable prop, QC.Arbitrary a, Class.Floating a, Show a) =>
>>>    QC.Gen a -> (Matrix a -> prop) -> QC.Property
>>> forMatrix genElem =
>>>    QC.forAll
>>>       (do height <- fmap shapeInt $ QC.choose (0,maxDim)
>>>           width <- fmap shapeInt $ QC.choose (0,maxDim)
>>>           genVector (height, width) genElem)
>>>
>>> genIdentityTrans ::
>>>    (Shape.C sh, Class.Floating a) =>
>>>    sh -> QC.Gen (Matrix.Transposable sh sh a)
>>> genIdentityTrans sh = do
>>>    trans <- QC.arbitrary
>>>    return $
>>>       if trans
>>>          then Matrix.transposed (Matrix.identity sh)
>>>          else Matrix.nonTransposed (Matrix.identity sh)
>>>
>>> transpose ::
>>>    (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
>>>    Matrix.Matrix height width a -> Matrix.Matrix width height a
>>> transpose a =
>>>    Matrix.multiplyTransposable
>>>       (Matrix.transposed a)
>>>       (Matrix.nonTransposed (Matrix.identity (Matrix.height a)))
-}


{- |
A matrix in row-major layout:
an 'Array' indexed by the shape pair @(height, width)@,
where the elements of each row are stored contiguously.
-}
type Matrix height width = Array (height,width)
{- |
A square matrix: rows and columns share the same shape @sh@.

There is also 'Shape.Square',
but it would be incompatible with the other matrix operations.
This might be addressed in a new @Matrix.Square@ module.
For advanced type-level tricks you can already use the @lapack@ package.
-}
type Square sh = Matrix sh sh
-- | A vector is simply an 'Array' over an arbitrary shape.
type Vector = Array


{- |
The row shape of a matrix, i.e. the first component of its array shape.
-}
height :: Matrix height width a -> height
height m = case Array.shape m of (rows, _) -> rows

{- |
The column shape of a matrix, i.e. the second component of its array shape.
-}
width :: Matrix height width a -> width
width m = case Array.shape m of (_, cols) -> cols


{- |
Identity matrix over the given shape.

The buffer is filled by two BLAS @copy@ calls with source stride 0,
each broadcasting a single scalar:
first @zero@ over all elements,
then @one@ over the diagonal,
whose entries lie @order+1@ elements apart in row-major storage.

>>> Matrix.identity (Shape.ZeroBased 0) :: Matrix.Square (Shape.ZeroBased Int) Real_
StorableArray.fromList (ZeroBased {... 0},ZeroBased {... 0}) []
>>> Matrix.identity (Shape.ZeroBased 3) :: Matrix.Square (Shape.ZeroBased Int) Real_
StorableArray.fromList (ZeroBased {... 3},ZeroBased {... 3}) [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]
-}
identity :: (Shape.C sh, Class.Floating a) => sh -> Square sh a
identity sh =
   Array.unsafeCreateWithAutoSizes (sh,sh) $
      \(SubSize.Sub total (SubSize.Atom order, SubSize.Atom _)) dstPtr ->
   evalContT $ do
      -- count, scalar and destination stride live in mutable cells
      -- because they are re-poked between the two copy calls
      countPtr <- Call.alloca
      scalarPtr <- Call.number zero
      srcIncPtr <- Call.cint 0
      dstIncPtr <- Call.cint 1
      let broadcast =
            Blas.copy countPtr scalarPtr srcIncPtr dstPtr dstIncPtr
      liftIO $ do
         -- pass 1: zero everywhere
         poke countPtr (fromIntegral total)
         broadcast
         -- pass 2: one on the diagonal (stride order+1)
         let orderC = fromIntegral order
         poke countPtr orderC
         poke scalarPtr one
         poke dstIncPtr (orderC + 1)
         broadcast


{- |
Extract one row as a vector.

Rows are contiguous in row-major storage,
so this is a plain memory copy of one row's worth of elements,
starting at @rowLength * offset@ of the selected row index.
-}
takeRow ::
   (Shape.Indexed height, Shape.C width, Shape.Index height ~ ix,
    Storable a) =>
   ix -> Matrix height width a -> Vector width a
takeRow ix (Array (rows, cols) src) =
   Array.unsafeCreateWithSize cols $ \rowLen dstPtr ->
      withForeignPtr src $ \srcPtr -> do
         let rowStart = advancePtr srcPtr (rowLen * Shape.offset rows ix)
         copyArray dstPtr rowStart rowLen

{- |
Extract one column as a vector.

Consecutive entries of a column are one full row apart in row-major storage,
so BLAS @copy@ gathers them with source stride @Shape.size width@,
starting at the column's offset within the first row.
-}
takeColumn ::
   (Shape.C height, Shape.Indexed width, Shape.Index width ~ ix,
    Class.Floating a) =>
   ix -> Matrix height width a -> Vector height a
takeColumn ix (Array (rows, cols) src) =
   Array.unsafeCreateWithSize rows $ \count dstPtr -> evalContT $ do
      countPtr <- Call.cint count
      srcPtr <- ContT $ withForeignPtr src
      srcIncPtr <- Call.cint (Shape.size cols)
      dstIncPtr <- Call.cint 1
      let firstPtr = advancePtr srcPtr (Shape.offset cols ix)
      liftIO $ Blas.copy countPtr firstPtr srcIncPtr dstPtr dstIncPtr


{- |
Stack row vectors into a matrix with one row per list element.

Every vector must have shape @width@;
a mismatch is reported via 'Call.assert' while the result is being filled.
-}
fromRows ::
   (Shape.C width, Eq width, Storable a) =>
   width -> [Vector width a] -> Matrix ShapeInt width a
fromRows width_ rows =
   Array.unsafeCreateWithAutoSizes (shapeInt (length rows), width_) $
      \(SubSize.Atom _, SubSize.Atom rowSize) dstPtr ->
   let copyRow dstRowPtr (Array.Array rowWidth srcFPtr) =
         withForeignPtr srcFPtr $ \srcPtr -> do
            Call.assert
               "Matrix.fromRows: non-matching vector size"
               (width_ == rowWidth)
            copyArray dstRowPtr srcPtr rowSize
   in  sequence_ (zipWith copyRow (pointerSeq rowSize dstPtr) rows)

infixr 2 `above`
infixr 3 `beside`

{- |
Stack two matrices with equal width vertically.
The row shape of the result is the sum @heightA ::+ heightB@.
Forwards to 'Array2.above'.
-}
above ::
   (Shape.C heightA, Shape.C heightB) =>
   (Shape.C width, Eq width) =>
   (Storable a) =>
   Matrix heightA width a ->
   Matrix heightB width a ->
   Matrix (heightA::+heightB) width a
above top bottom = Array2.above top bottom

{- |
Place two matrices with equal height side by side.
The column shape of the result is the sum @widthA ::+ widthB@.
Forwards to 'Array2.beside'.
-}
beside ::
   (Shape.C widthA, Shape.C widthB) =>
   (Shape.C height, Eq height) =>
   (Storable a) =>
   Matrix height widthA a ->
   Matrix height widthB a ->
   Matrix height (widthA::+widthB) a
beside left right = Array2.beside left right

{- |
Select the upper block, whose row shape is @heightA@.
Forwards to 'Array2.takeTop'.
-}
takeTop ::
   (Shape.C heightA, Shape.C heightB, Shape.C width, Storable a) =>
   Matrix (heightA::+heightB) width a ->
   Matrix heightA width a
takeTop m = Array2.takeTop m

{- |
Select the lower block, whose row shape is @heightB@.
Forwards to 'Array2.takeBottom'.
-}
takeBottom ::
   (Shape.C heightA, Shape.C heightB, Shape.C width, Storable a) =>
   Matrix (heightA::+heightB) width a ->
   Matrix heightB width a
takeBottom m = Array2.takeBottom m

{- |
Select the left block, whose column shape is @widthA@.
Forwards to 'Array2.takeLeft'.
-}
takeLeft ::
   (Shape.C height, Shape.C widthA, Shape.C widthB, Storable a) =>
   Matrix height (widthA::+widthB) a ->
   Matrix height widthA a
takeLeft m = Array2.takeLeft m

{- |
Select the right block, whose column shape is @widthB@.
Forwards to 'Array2.takeRight'.
-}
takeRight ::
   (Shape.C height, Shape.C widthA, Shape.C widthB, Storable a) =>
   Matrix height (widthA::+widthB) a ->
   Matrix height widthB a
takeRight m = Array2.takeRight m


{-# WARNING tensorProduct "Don't use conjugation. Left and Right are swapped." #-}
{- |
Outer product of two vectors,
computed by a single BLAS @gemm@ call with inner dimension 1.

@gemm@ works column-major, so the row-major @height × width@ result
is produced as the column-major @width × height@ product
with @y@ as the first and @x@ as the second @gemm@ operand.

Conjugation (transpose flag @'C'@) is applied to @y@ for @Left c@
and to @x@ for @Right c@, i.e. opposite to the argument order;
hence the WARNING.
-}
tensorProduct ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Either Conjugation Conjugation ->
   Vector height a -> Vector width a -> Matrix height width a
tensorProduct side (Array height_ x) (Array width_ y) =
   Array.unsafeCreateWithAutoSizes (height_,width_) $
      \(SubSize.Atom rowCount, SubSize.Atom colCount) cPtr -> do
   let transChar NonConjugated = 'T'
       transChar Conjugated = 'C'
       (transa, transb, lda, ldb) =
         either
            (\c -> (transChar c, 'N', 1, 1))
            (\c -> ('N', transChar c, colCount, rowCount))
            side
   evalContT $ do
      transaPtr <- Call.char transa
      transbPtr <- Call.char transb
      mPtr <- Call.cint colCount
      nPtr <- Call.cint rowCount
      kPtr <- Call.cint 1
      alphaPtr <- Call.number one
      -- gemm operand A is the width vector y, operand B the height vector x
      aPtr <- ContT $ withForeignPtr y
      ldaPtr <- Call.leadingDim lda
      bPtr <- ContT $ withForeignPtr x
      ldbPtr <- Call.leadingDim ldb
      betaPtr <- Call.number zero
      ldcPtr <- Call.leadingDim colCount
      liftIO $
         Blas.gemm
            transaPtr transbPtr mPtr nPtr kPtr alphaPtr
            aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr


{- |
Reinterpret a complex matrix as a real one
with an extra 'ComplexShape' dimension per column.
No data is copied; the same buffer is reused via 'castForeignPtr'.
-}
decomplex ::
   (Class.Real a) =>
   Matrix height width (Complex a) ->
   Matrix height (width, ComplexShape) a
decomplex (Array (rows, cols) complexBuf) =
   Array (rows, (cols, Shape.static)) realBuf
   where realBuf = castForeignPtr complexBuf

{- |
Inverse of 'decomplex':
drop the 'ComplexShape' dimension and view the buffer as complex numbers.
No data is copied; the same buffer is reused via 'castForeignPtr'.
-}
recomplex ::
   (Class.Real a) =>
   Matrix height (width, ComplexShape) a ->
   Matrix height width (Complex a)
recomplex (Array (rows, (cols, Shape.NestedTuple _)) realBuf) =
   Array (rows, cols) complexBuf
   where complexBuf = castForeignPtr realBuf


{- |
Multiply row @i@ of the matrix by element @i@ of the vector.

Each row is copied into the result with BLAS @copy@
and then scaled in place with BLAS @scal@.
The vector's shape must equal the matrix height
(checked with 'Call.assert').
-}
scaleRows ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a -> Matrix height width a -> Matrix height width a
scaleRows (Array heightX x) (Array shape a) =
   Array.unsafeCreateWithAutoSizes shape $
      \(SubSize.Atom rowCount, SubSize.Atom rowLen) bPtr -> do
   Call.assert "scaleRows: sizes mismatch" (heightX == fst shape)
   evalContT $ do
      lenPtr <- Call.cint rowLen
      xPtr <- ContT $ withForeignPtr x
      aPtr <- ContT $ withForeignPtr a
      incaPtr <- Call.cint 1
      incbPtr <- Call.cint 1
      let scaleRow (xkPtr, akPtr, bkPtr) = do
            Blas.copy lenPtr akPtr incaPtr bkPtr incbPtr
            Blas.scal lenPtr xkPtr bkPtr incbPtr
          rowTriples =
            take rowCount $
            zip3
               (pointerSeq 1 xPtr)
               (pointerSeq rowLen aPtr)
               (pointerSeq rowLen bPtr)
      liftIO $ forM_ rowTriples scaleRow

{- |
Multiply column @j@ of the matrix by element @j@ of the vector.

Each row is multiplied by the diagonal matrix @diag x@,
expressed as a @gbmv@ band matrix with zero sub- and super-diagonals
(@kl = ku = 0@, leading dimension 1) whose band storage is the vector itself.
The vector's shape must equal the matrix width
(checked with 'Call.assert').
-}
scaleColumns ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width a -> Matrix height width a -> Matrix height width a
scaleColumns (Array widthX x) (Array shape a) =
   Array.unsafeCreateWithAutoSizes shape $
      \(SubSize.Atom rowCount, SubSize.Atom rowLen) bPtr -> do
   Call.assert "scaleColumns: sizes mismatch" (widthX == snd shape)
   evalContT $ do
      transPtr <- Call.char 'N'
      lenPtr <- Call.cint rowLen
      klPtr <- Call.cint 0
      kuPtr <- Call.cint 0
      alphaPtr <- Call.number one
      xPtr <- ContT $ withForeignPtr x
      ldxPtr <- Call.leadingDim 1
      aPtr <- ContT $ withForeignPtr a
      incaPtr <- Call.cint 1
      betaPtr <- Call.number zero
      incbPtr <- Call.cint 1
      let scaleRow akPtr bkPtr =
            Private.gbmv transPtr
               lenPtr lenPtr klPtr kuPtr alphaPtr xPtr ldxPtr
               akPtr incaPtr betaPtr bkPtr incbPtr
          rowPairs =
            take rowCount $
            zip (pointerSeq rowLen aPtr) (pointerSeq rowLen bPtr)
      liftIO $ mapM_ (uncurry scaleRow) rowPairs


{- |
Row vector times matrix, @x^T · A@.
Delegates to 'multiplyVector' with the matrix wrapped as 'NonTransposed'.

>>> Matrix.multiplyVectorLeft (Array.vectorFromList [3,1,4]) (Array.fromList (Shape.ZeroBased (3::Int), Shape.Range 'a' 'b') [0,1,0,0,1,0::Real_])
StorableArray.fromList (Range {rangeFrom = 'a', rangeTo = 'b'}) [4.0,3.0]

prop> :{
   forVector number_ $ \xs ->
   Matrix.multiplyVectorLeft xs (Matrix.identity (Array.shape xs)) == xs
:}
-}
multiplyVectorLeft ::
   (Eq height, Shape.C height, Shape.C width, Class.Floating a) =>
   Vector height a -> Matrix height width a -> Vector width a
multiplyVectorLeft x = multiplyVector x . NonTransposed

{- |
Multiply a matrix with a vector from the right, i.e. compute @A * x@.
The shape of the vector must equal the width shape of the matrix.
Internally this is the left multiplication @x^T * A^T@,
where the transpose is only a reinterpretation of the stored data.

>>> Matrix.multiplyVectorRight (Array.fromList (Shape.Range 'a' 'b', Shape.ZeroBased (3::Int)) [0,0,1,1,0,0]) (Array.vectorFromList [3,1,4::Real_])
StorableArray.fromList (Range {rangeFrom = 'a', rangeTo = 'b'}) [4.0,3.0]
>>> Matrix.multiplyVectorRight (Array.fromList (Shape.Range 'a' 'b', Shape.ZeroBased (3::Int)) [2,7,1,8,2,8]) (Array.vectorFromList [3,1,4::Real_])
StorableArray.fromList (Range {rangeFrom = 'a', rangeTo = 'b'}) [17.0,58.0]

prop> :{
   forVector number_ $ \xs ->
   Matrix.multiplyVectorRight (Matrix.identity (Array.shape xs)) xs == xs
:}

prop> :{
   forMatrix number_ $ \a ->
   QC.forAll (genVector (snd $ Array.shape a) number_) $ \x ->
   Matrix.singleColumn (Matrix.multiplyVectorRight a x)
   ==
   Matrix.multiply a (Matrix.singleColumn x)
:}

prop> :{
   forMatrix number_ $ \a ->
   QC.forAll (genVector (fst $ Array.shape a) number_) $ \x ->
   QC.forAll (genVector (snd $ Array.shape a) number_) $ \y ->
   Vector.dot x (Matrix.multiplyVectorRight a y)
   ==
   Vector.dot (Matrix.multiplyVectorLeft x a) y
:}

prop> :{
   forMatrix number_ $ \a ->
   QC.forAll (genVector (snd $ Array.shape a) number_) $ \x ->
   Matrix.multiplyVectorRight a x
   ==
   Matrix.multiplyVectorLeft x (transpose a)
:}
-}
multiplyVectorRight ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Matrix height width a -> Vector width a -> Vector height a
multiplyVectorRight a x = multiplyVector x (transposed a)


{- |
A @height@-by-@width@ matrix that is either stored as it is
or represented by the row-major storage of its transpose.
Consumers pass the matching BLAS transposition flag
together with the unchanged storage.
-}
data Transposable height width a =
     NonTransposed (Matrix height width a)
     -- ^ the represented matrix itself
   | Transposed (Matrix width height a)
     -- ^ the transpose of the represented matrix
   deriving (Show)


-- | Use a row-major matrix as a 'Transposable' operand without transposition.
nonTransposed :: Matrix height width a -> Transposable height width a
nonTransposed = NonTransposed

-- | Use a row-major matrix as a 'Transposable' operand that stands for its
-- transpose. The stored data is not touched.
transposed :: Matrix height width a -> Transposable width height a
transposed = Transposed


-- | Transpose a 'Transposable' by flipping its tag.
-- The wrapped matrix is reused unchanged.
transposeTransposable ::
   Transposable height width a -> Transposable width height a
transposeTransposable (NonTransposed m) = Transposed m
transposeTransposable (Transposed m) = NonTransposed m


{- |
Take a 'Transposable' apart into the pieces a BLAS call needs:
the transposition flag (@N@ or @T@),
the @(height, width)@ shape of the represented matrix,
and the buffer of the stored row-major data.
-}
inspectTransposable ::
   Transposable height width a -> (Char, (height, width), ForeignPtr a)
inspectTransposable (NonTransposed (Array shape fptr)) = ('N', shape, fptr)
inspectTransposable (Transposed (Array shape fptr)) = ('T', swap shape, fptr)

{- |
Shared worker of 'multiplyVectorLeft' and 'multiplyVectorRight'.
It computes @x^T * A@ for a possibly transposed matrix @A@
with a single BLAS @gemv@ call.

Row-major storage of a @height@-by-@width@ matrix is the same memory as
column-major storage of its @width@-by-@height@ transpose.
Hence @gemv@ with flag @N@ on the storage of a 'NonTransposed' matrix,
or with flag @T@ on the storage of a 'Transposed' one,
both yield @A^T * x@, which is the wanted vector.
-}
multiplyVector ::
   (Shape.C height, Shape.C width, Eq height, Class.Floating a) =>
   Vector height a -> Transposable height width a -> Vector width a
multiplyVector (Array shX fptrX) at =
   Array.unsafeCreateWithSize widthA $ \sizeY ptrY -> do
      Call.assert "Matrix.RowMajor.multiplyVector: shapes mismatch"
         (heightA == shX)
      let sizeX = Shape.size heightA
          -- dimensions of the stored data viewed as a column-major array
          (rows, cols) =
            case at of
               NonTransposed _ -> (sizeY, sizeX)
               Transposed _ -> (sizeX, sizeY)
      if sizeX == 0
         -- empty sum: every entry is zero;
         -- reference gemv returns early for an empty dimension
         -- and would leave the fresh buffer uninitialised
         then fill zero sizeY ptrY
         else evalContT $ do
            transPtr <- Call.char trans
            mPtr <- Call.cint rows
            nPtr <- Call.cint cols
            alphaPtr <- Call.number one
            aPtr <- ContT $ withForeignPtr fptrA
            ldaPtr <- Call.leadingDim rows
            xPtr <- ContT $ withForeignPtr fptrX
            incxPtr <- Call.cint 1
            betaPtr <- Call.number zero
            incyPtr <- Call.cint 1
            liftIO $
               Blas.gemv
                  transPtr mPtr nPtr alphaPtr aPtr ldaPtr
                  xPtr incxPtr betaPtr ptrY incyPtr
   where
      (trans, (heightA, widthA), fptrA) = inspectTransposable at


{- |
Ordinary product of two row-major matrices.
The width shape of the left factor must equal the height shape
of the right factor; this is asserted by the underlying BLAS wrapper.

>>> :{
   Matrix.multiply
      (Array.fromList (shapeInt 2, shapeInt 2) [1000,100,10,1])
      (Array.fromList (shapeInt 2, shapeInt 3) [0..5::Real_])
:}
... [300.0,1400.0,2500.0,3.0,14.0,25.0]

prop> :{
   forMatrix number_ $ \a ->
      Matrix.multiply (Matrix.identity (Matrix.height a)) a == a
:}
prop> :{
   forMatrix number_ $ \a ->
      Matrix.multiply a (Matrix.identity (Matrix.width a)) == a
:}
prop> :{
   forMatrix number_ $ \a ->
   forMatrix number_ $ \c ->
   QC.forAll (genVector (Matrix.width a, Matrix.height c) number_) $ \b ->
      Matrix.multiply a (Matrix.multiply b c)
      ==
      Matrix.multiply (Matrix.multiply a b) c
:}
-}
multiply ::
   (Shape.C height, Shape.C width, Shape.C fuse, Eq fuse, Class.Floating a) =>
   Matrix height fuse a -> Matrix fuse width a -> Matrix height width a
multiply a b = multiplyTransposable (nonTransposed a) (nonTransposed b)

{- |
Matrix product where each factor may be transposed without copying it.

Row-major storage of a matrix is column-major storage of its transpose,
so the row-major product @A * B@ equals the column-major product
@B^T * A^T@. That is why the operands are swapped
when delegating to the column-major worker.

prop> :{
   forMatrix number_ $ \a ->
   QC.forAll (genIdentityTrans (Matrix.height a)) $ \eye ->
      a == Matrix.multiplyTransposable eye (Matrix.nonTransposed a)
:}
prop> :{
   forMatrix number_ $ \a ->
   QC.forAll (genIdentityTrans (Matrix.width a)) $ \eye ->
      a == Matrix.multiplyTransposable (Matrix.nonTransposed a) eye
:}
prop> :{
   forMatrix number_ $ \a ->
   QC.forAll (genIdentityTrans (Matrix.width a)) $ \leftEye ->
   QC.forAll (genIdentityTrans (Matrix.height a)) $ \rightEye ->
      Matrix.multiplyTransposable leftEye (Matrix.transposed a)
      ==
      Matrix.multiplyTransposable (Matrix.transposed a) rightEye
:}
prop> :{
   forMatrix number_ $ \a ->
   QC.forAll (QC.choose (0,maxDim)) $ \n ->
   QC.forAll (genVector (Matrix.width a, shapeInt n) number_) $ \b ->
      transpose (Matrix.multiply a b)
      ==
      Matrix.multiplyTransposable (Matrix.transposed b) (Matrix.transposed a)
:}
-}
multiplyTransposable ::
   (Shape.C height, Shape.C width, Shape.C fuse, Eq fuse, Class.Floating a) =>
   Transposable height fuse a ->
   Transposable fuse width a ->
   Matrix height width a
multiplyTransposable = flip multiplyColumnMajor

{- |
Worker of 'multiplyTransposable' that performs a single BLAS @gemm@ call.

Everything is seen column-major here: the stored data of a row-major matrix
is the column-major storage of its transpose.
The first operand is therefore the column-major @m@-by-@k@ matrix
(@heightA@-by-@widthA@), and the second is the @k@-by-@n@ matrix
(@heightB@-by-@widthB@).
The column-major @m@-by-@n@ product is returned as a row-major
@width@-by-@height@ matrix.
-}
multiplyColumnMajor ::
   (Shape.C height, Shape.C width, Shape.C fuse, Eq fuse, Class.Floating a) =>
   Transposable fuse height a ->
   Transposable width fuse a ->
   Matrix width height a
multiplyColumnMajor at bt =
   Array.unsafeCreate (widthB, heightA) $ \cPtr -> do
      Call.assert "Matrix.RowMajor.multiply: shapes mismatch"
         (widthA == heightB)
      if k == 0
         -- empty inner dimension: the product is the zero matrix
         then fill zero (m*n) cPtr
         else evalContT $ do
            transaPtr <- Call.char transa
            transbPtr <- Call.char transb
            mPtr <- Call.cint m
            nPtr <- Call.cint n
            kPtr <- Call.cint k
            alphaPtr <- Call.number one
            aPtr <- ContT $ withForeignPtr fptrA
            ldaPtr <- Call.leadingDim lda
            bPtr <- ContT $ withForeignPtr fptrB
            ldbPtr <- Call.leadingDim ldb
            betaPtr <- Call.number zero
            ldcPtr <- Call.leadingDim m
            liftIO $
               Blas.gemm
                  transaPtr transbPtr mPtr nPtr kPtr alphaPtr aPtr ldaPtr
                  bPtr ldbPtr betaPtr cPtr ldcPtr
   where
      (transa, (widthA, heightA), fptrA) = inspectTransposable at
      (transb, (widthB, heightB), fptrB) = inspectTransposable bt
      m = Shape.size heightA
      k = Shape.size widthA
      n = Shape.size widthB
      -- leading dimension = number of columns of the stored row-major matrix
      lda = case at of NonTransposed _ -> m; Transposed _ -> k
      ldb = case bt of NonTransposed _ -> k; Transposed _ -> n


{- |
Kronecker product of two matrices.
Rows and columns of the result are indexed by pairs:
the element at row @(i,j)@ and column @(k,l)@ is @a[i,k] * b[j,l]@.

>>> :{
   Matrix.kronecker
      (Array.fromList (shapeInt 2, shapeInt 2) [0,1,-1,0::Real_])
      (Array.fromList (shapeInt 2, shapeInt 3) [1..6])
:}
... [0.0,0.0,0.0,1.0,2.0,3.0,0.0,0.0,0.0,4.0,5.0,6.0,-1.0,-2.0,-3.0,0.0,0.0,0.0,-4.0,-5.0,-6.0,0.0,0.0,0.0]

>>> :{
   Matrix.kronecker
      (Array.fromList (shapeInt 2, shapeInt 2) [1,2,3,4::Real_])
      (Array.fromList (shapeInt 2, shapeInt 3) [1,2,4,8,16,32])
:}
... [1.0,2.0,4.0,2.0,4.0,8.0,8.0,16.0,32.0,16.0,32.0,64.0,3.0,6.0,12.0,4.0,8.0,16.0,24.0,48.0,96.0,32.0,64.0,128.0]

prop> :{
   QC.forAll (QC.choose (0,5)) $ \m ->
   QC.forAll (QC.choose (0,5)) $ \n ->
      Matrix.kronecker
         (Matrix.identity (shapeInt m))
         (Matrix.identity (shapeInt n))
      ==
      (Matrix.identity (shapeInt m, shapeInt n)
         :: Matrix.Square (ShapeInt, ShapeInt) Number_)
:}
-}
kronecker ::
   (Shape.C heightA, Shape.C widthA, Shape.C heightB, Shape.C widthB,
    Class.Floating a) =>
   Matrix heightA widthA a ->
   Matrix heightB widthB a ->
   Matrix (heightA,heightB) (widthA,widthB) a
kronecker = kroneckerLeftTransposable . nonTransposed

{- |
Kronecker product of two possibly transposed matrices.
A transposed right factor is handled with the identity
@kron(A,B)^T = kron(A^T,B^T)@: the product of the untransposed storage
is computed and the result is tagged as transposed.
-}
kroneckerTransposable ::
   (Shape.C heightA, Shape.C widthA, Shape.C heightB, Shape.C widthB,
    Class.Floating a) =>
   Transposable heightA widthA a ->
   Transposable heightB widthB a ->
   Transposable (heightA,heightB) (widthA,widthB) a
kroneckerTransposable at (NonTransposed b) =
   nonTransposed (kroneckerLeftTransposable at b)
kroneckerTransposable at (Transposed b) =
   transposed (kroneckerLeftTransposable (transposeTransposable at) b)

{- |
Kronecker product where only the left factor may be transposed.

Each result row @(i,j)@ is the outer product of row @i@ of the left factor
with row @j@ of the right factor, laid out row-major as an @na@-by-@nb@ block.
Every such block is produced by one rank-1 @gemm@ call (inner dimension 1).
In column-major terms that call computes the @nb@-by-@na@ product
(row @j@ of the right factor) times (row @i@ of the left factor),
which occupies exactly the memory of the row-major @na@-by-@nb@ block.
-}
kroneckerLeftTransposable ::
   (Shape.C heightA, Shape.C widthA, Shape.C heightB, Shape.C widthB,
    Class.Floating a) =>
   Transposable heightA widthA a ->
   Matrix heightB widthB a ->
   Matrix (heightA,heightB) (widthA,widthB) a
kroneckerLeftTransposable at (Array (heightB, widthB) fptrB) =
   Array.unsafeCreate ((heightA, heightB), (widthA, widthB)) $ \cPtr ->
      evalContT $ do
         transaPtr <- Call.char 'N'
         transbPtr <- Call.char 'T'
         mPtr <- Call.cint na
         nPtr <- Call.cint nb
         kPtr <- Call.cint 1
         alphaPtr <- Call.number one
         aPtr <- ContT $ withForeignPtr fptrA
         ldaPtr <- Call.leadingDim lda
         bPtr <- ContT $ withForeignPtr fptrB
         ldbPtr <- Call.leadingDim 1
         betaPtr <- Call.number zero
         ldcPtr <- Call.leadingDim nb
         liftIO $
            forM_ [(i, j) | i <- [0 .. ma-1], j <- [0 .. mb-1]] $ \(i, j) ->
               Blas.gemm
                  transbPtr transaPtr nPtr mPtr kPtr alphaPtr
                  -- row j of the right factor, contiguous
                  (advancePtr bPtr (nb*j)) ldbPtr
                  -- row i of the left factor, elements lda apart
                  (advancePtr aPtr (istep*i)) ldaPtr
                  betaPtr
                  -- start of result row (i,j)
                  (advancePtr cPtr (na*nb*(j + mb*i))) ldcPtr
   where
      (_, (heightA, widthA), fptrA) = inspectTransposable at
      ma = Shape.size heightA
      na = Shape.size widthA
      mb = Shape.size heightB
      nb = Shape.size widthB
      -- (distance between consecutive elements of a row of the left factor,
      --  distance between the starts of consecutive rows)
      (lda, istep) =
         case at of
            NonTransposed _ -> (1, na)
            Transposed _ -> (ma, 1)