{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.BLAS.Matrix.RowMajor (
Matrix,
Square,
Vector,
height, width,
Array2.singleRow, Array2.flattenRow,
Array2.singleColumn, Array2.flattenColumn,
identity,
takeRow,
takeColumn,
fromRows,
above,
beside,
takeTop, takeBottom,
takeLeft, takeRight,
tensorProduct,
decomplex,
recomplex,
scaleRows,
scaleColumns,
multiplyVectorLeft,
multiplyVectorRight,
Transposable(..), nonTransposed, transposed,
transposeTransposable,
multiply,
multiplyTransposable,
kronecker,
kroneckerTransposable,
kroneckerLeftTransposable,
) where
import qualified Numeric.BLAS.Private as Private
import Numeric.BLAS.Matrix.Modifier (Conjugation(NonConjugated,Conjugated))
import Numeric.BLAS.Scalar (zero, one)
import Numeric.BLAS.Private (ShapeInt, shapeInt, ComplexShape, pointerSeq, fill)
import qualified Numeric.BLAS.FFI.Generic as Blas
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import Foreign.Marshal.Array (copyArray, advancePtr)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr, castForeignPtr)
import Foreign.Storable (Storable, poke)
import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Applicative (liftA2)
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable.Dim2 as Array2
import qualified Data.Array.Comfort.Shape.SubSize as SubSize
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Array.Comfort.Shape ((::+))
import Data.Foldable (forM_)
import Data.Complex (Complex)
import Data.Tuple.HT (swap)
-- | A matrix stored as a comfort-array 'Array' whose shape is the pair
-- @(height, width)@.  Functions in this module treat the storage as
-- row-major (see e.g. 'takeRow', which copies one contiguous run).
type Matrix height width = Array (height,width)
-- | A matrix whose row and column shapes are the same type and value.
type Square sh = Matrix sh sh
-- | A vector is just a one-dimensional 'Array'.
type Vector = Array
-- | The row shape of a matrix, i.e. the first component of its
-- @(height, width)@ shape.
height :: Matrix height width a -> height
height = fst . Array.shape
-- | The column shape of a matrix, i.e. the second component of its
-- @(height, width)@ shape.
width :: Matrix height width a -> width
width = snd . Array.shape
-- | @identity sh@ is the square matrix over @sh@ with 'one' on the
-- diagonal and 'zero' everywhere else.
--
-- It uses two BLAS @copy@ calls whose source is a single scalar read with
-- increment 0.  The first call writes 'zero' into every element.  The
-- second writes 'one' at stride @n+1@, which in an @n@-by-@n@ row-major
-- buffer visits exactly the diagonal.
identity :: (Shape.C sh, Class.Floating a) => sh -> Square sh a
identity sh =
   Array.unsafeCreateWithAutoSizes (sh,sh) $
      \(SubSize.Sub blockSize (SubSize.Atom nint, SubSize.Atom _)) yPtr ->
      evalContT $ do
         nPtr <- Call.alloca
         xPtr <- Call.number zero
         incxPtr <- Call.cint 0
         incyPtr <- Call.cint 1
         liftIO $ do
            -- Pass 1: clear the whole block.  incx = 0 repeats the scalar.
            poke nPtr $ fromIntegral blockSize
            Blas.copy nPtr xPtr incxPtr yPtr incyPtr
            -- Pass 2: reuse the same argument cells to write 'one' on the
            -- diagonal (n elements, destination stride n+1).
            let n = fromIntegral nint
            poke nPtr n
            poke xPtr one
            poke incyPtr (n+1)
            Blas.copy nPtr xPtr incxPtr yPtr incyPtr
-- | Extract the row at index @ix@ as a vector.
--
-- Rows are contiguous in row-major storage, so this is a single
-- 'copyArray' of @n@ elements (@n@ = size of the width shape).  The copy
-- starts at element @n * offset height ix@.
takeRow ::
   (Shape.Indexed height, Shape.C width, Shape.Index height ~ ix,
    Storable a) =>
   ix -> Matrix height width a -> Vector width a
takeRow ix (Array (height_,width_) x) =
   Array.unsafeCreateWithSize width_ $ \n yPtr ->
   withForeignPtr x $ \xPtr ->
      copyArray yPtr (advancePtr xPtr (n * Shape.offset height_ ix)) n
-- | Extract the column at index @ix@ as a vector.
--
-- Elements of a column lie @size width@ apart in row-major storage.  BLAS
-- @copy@ gathers them with that source increment into the contiguous
-- result, starting at the column's offset within the first row.
takeColumn ::
   (Shape.C height, Shape.Indexed width, Shape.Index width ~ ix,
    Class.Floating a) =>
   ix -> Matrix height width a -> Vector height a
takeColumn ix (Array (height_,width_) x) =
   Array.unsafeCreateWithSize height_ $ \n yPtr -> evalContT $ do
      let offset = Shape.offset width_ ix
      nPtr <- Call.cint n
      xPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint $ Shape.size width_
      incyPtr <- Call.cint 1
      liftIO $ Blas.copy nPtr (advancePtr xPtr offset) incxPtr yPtr incyPtr
-- | Stack row vectors into a matrix.  The height is a 'ShapeInt' equal to
-- the number of rows; the width is the given shape.
--
-- Each row's shape is compared with @width@ using 'Call.assert' (message
-- "Matrix.fromRows: non-matching vector size") before it is copied into
-- its slot.
fromRows ::
   (Shape.C width, Eq width, Storable a) =>
   width -> [Vector width a] -> Matrix ShapeInt width a
fromRows width_ rows =
   Array.unsafeCreateWithAutoSizes (shapeInt $ length rows, width_) $
      \(SubSize.Atom _, SubSize.Atom widthSize) dstPtr ->
      -- pointerSeq yields the start of each consecutive destination row.
      forM_ (zip (pointerSeq widthSize dstPtr) rows) $
         \(dstRowPtr, Array.Array rowWidth srcFPtr) ->
         withForeignPtr srcFPtr $ \srcPtr -> do
            Call.assert
               "Matrix.fromRows: non-matching vector size"
               (width_ == rowWidth)
            copyArray dstRowPtr srcPtr widthSize
infixr 2 `above`
infixr 3 `beside`

-- | Place the first matrix on top of the second.  Both share the width
-- shape; the height becomes the concatenation @heightA ::+ heightB@.
-- Delegates to 'Array2.above'.
above ::
   (Shape.C heightA, Shape.C heightB) =>
   (Shape.C width, Eq width) =>
   (Storable a) =>
   Matrix heightA width a ->
   Matrix heightB width a ->
   Matrix (heightA::+heightB) width a
above = Array2.above
-- | Place the first matrix to the left of the second.  Both share the
-- height shape; the width becomes the concatenation @widthA ::+ widthB@.
-- Delegates to 'Array2.beside'.
beside ::
   (Shape.C widthA, Shape.C widthB) =>
   (Shape.C height, Eq height) =>
   (Storable a) =>
   Matrix height widthA a ->
   Matrix height widthB a ->
   Matrix height (widthA::+widthB) a
beside = Array2.beside
-- | The rows belonging to the @heightA@ part of a vertically concatenated
-- height.  Delegates to 'Array2.takeTop'.
takeTop ::
   (Shape.C heightA, Shape.C heightB, Shape.C width, Storable a) =>
   Matrix (heightA::+heightB) width a ->
   Matrix heightA width a
takeTop = Array2.takeTop
-- | The rows belonging to the @heightB@ part of a vertically concatenated
-- height.  Delegates to 'Array2.takeBottom'.
takeBottom ::
   (Shape.C heightA, Shape.C heightB, Shape.C width, Storable a) =>
   Matrix (heightA::+heightB) width a ->
   Matrix heightB width a
takeBottom = Array2.takeBottom
-- | The columns belonging to the @widthA@ part of a horizontally
-- concatenated width.  Delegates to 'Array2.takeLeft'.
takeLeft ::
   (Shape.C height, Shape.C widthA, Shape.C widthB, Storable a) =>
   Matrix height (widthA::+widthB) a ->
   Matrix height widthA a
takeLeft = Array2.takeLeft
-- | The columns belonging to the @widthB@ part of a horizontally
-- concatenated width.  Delegates to 'Array2.takeRight'.
takeRight ::
   (Shape.C height, Shape.C widthA, Shape.C widthB, Storable a) =>
   Matrix height (widthA::+widthB) a ->
   Matrix height widthB a
takeRight = Array2.takeRight
{-# WARNING tensorProduct "Don't use conjugation. Left and Right are swapped." #-}
-- | Outer product of two vectors: a @height@-by-@width@ matrix whose entry
-- @(i,j)@ is the product of element @i@ of the first vector and element
-- @j@ of the second, up to the requested conjugation.  It is computed as a
-- BLAS @gemm@ with inner dimension @k = 1@.
--
-- The conjugation argument is applied crosswise, which is what the
-- WARNING refers to:
--
-- * @Left c@ sets @transa@, which acts on @y@ (the /second/ vector).
-- * @Right c@ sets @transb@, which acts on @x@ (the /first/ vector).
--
-- In either case 'NonConjugated' maps to @\'T\'@ and 'Conjugated' to
-- @\'C\'@.  The behaviour is left as is because the WARNING is already
-- part of the published interface.
tensorProduct ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Either Conjugation Conjugation ->
   Vector height a -> Vector width a -> Matrix height width a
tensorProduct side (Array height_ x) (Array width_ y) =
   Array.unsafeCreateWithAutoSizes (height_,width_) $
      \(SubSize.Atom n, SubSize.Atom m) cPtr -> do
         let trans conjugated =
               case conjugated of
                  NonConjugated -> 'T'
                  Conjugated -> 'C'
         -- In column-major BLAS terms the result is m-by-n (m = width size,
         -- n = height size), i.e. C = y * x^T with y as an m-by-1 operand.
         -- The leading dimensions depend on which operand is stored
         -- transposed.
         let ((transa,transb),(lda,ldb)) =
               case side of
                  Left c -> ((trans c, 'N'), (1,1))
                  Right c -> (('N', trans c), (m,n))
         evalContT $ do
            transaPtr <- Call.char transa
            transbPtr <- Call.char transb
            mPtr <- Call.cint m
            nPtr <- Call.cint n
            kPtr <- Call.cint 1
            alphaPtr <- Call.number one
            aPtr <- ContT $ withForeignPtr y
            ldaPtr <- Call.leadingDim lda
            bPtr <- ContT $ withForeignPtr x
            ldbPtr <- Call.leadingDim ldb
            betaPtr <- Call.number zero
            ldcPtr <- Call.leadingDim m
            liftIO $
               Blas.gemm
                  transaPtr transbPtr mPtr nPtr kPtr alphaPtr
                  aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr
-- | View a complex matrix as a real matrix.  Each complex entry becomes a
-- trailing 'ComplexShape' dimension holding its two real components.
--
-- No data is copied: the same buffer is reinterpreted via
-- 'castForeignPtr'.
decomplex ::
   (Class.Real a) =>
   Matrix height width (Complex a) ->
   Matrix height (width, ComplexShape) a
decomplex (Array (height_,width_) a) =
   Array (height_, (width_, Shape.static)) (castForeignPtr a)
-- | Inverse of 'decomplex': drop the trailing 'ComplexShape' dimension and
-- reinterpret each real pair as one complex entry.  The buffer is shared,
-- not copied.
recomplex ::
   (Class.Real a) =>
   Matrix height (width, ComplexShape) a ->
   Matrix height width (Complex a)
recomplex (Array (height_, (width_, Shape.NestedTuple _)) a) =
   Array (height_,width_) (castForeignPtr a)
-- | Multiply row @k@ of the matrix by element @k@ of the vector.
--
-- The vector's shape must equal the matrix height; this is checked with
-- 'Call.assert' (message "scaleRows: sizes mismatch").  Each row is first
-- copied into the result with BLAS @copy@ and then scaled in place with
-- BLAS @scal@.
scaleRows ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a -> Matrix height width a -> Matrix height width a
scaleRows (Array heightX x) (Array shape a) =
   Array.unsafeCreateWithAutoSizes shape $
      \(SubSize.Atom m, SubSize.Atom n) bPtr -> do
         Call.assert "scaleRows: sizes mismatch" (heightX == fst shape)
         evalContT $ do
            nPtr <- Call.cint n
            xPtr <- ContT $ withForeignPtr x
            aPtr <- ContT $ withForeignPtr a
            incaPtr <- Call.cint 1
            incbPtr <- Call.cint 1
            -- Walk m rows in lockstep: scalar k, source row k, target row k.
            liftIO $ sequence_ $ take m $
               zipWith3
                  (\xkPtr akPtr bkPtr -> do
                     Blas.copy nPtr akPtr incaPtr bkPtr incbPtr
                     Blas.scal nPtr xkPtr bkPtr incbPtr)
                  (pointerSeq 1 xPtr)
                  (pointerSeq n aPtr)
                  (pointerSeq n bPtr)
-- | Multiply column @j@ of the matrix by element @j@ of the vector.
--
-- The vector's shape must equal the matrix width; this is checked with
-- 'Call.assert' (message "scaleColumns: sizes mismatch").
--
-- Each row is produced by a banded matrix-vector product (@gbmv@) with no
-- sub- or super-diagonals (@kl = ku = 0@) and leading dimension 1.  The
-- vector @x@ is the single stored band, so it acts as a diagonal matrix
-- applied to the row.
scaleColumns ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width a -> Matrix height width a -> Matrix height width a
scaleColumns (Array widthX x) (Array shape a) =
   Array.unsafeCreateWithAutoSizes shape $
      \(SubSize.Atom m, SubSize.Atom n) bPtr -> do
         Call.assert "scaleColumns: sizes mismatch" (widthX == snd shape)
         evalContT $ do
            transPtr <- Call.char 'N'
            nPtr <- Call.cint n
            klPtr <- Call.cint 0
            kuPtr <- Call.cint 0
            alphaPtr <- Call.number one
            xPtr <- ContT $ withForeignPtr x
            ldxPtr <- Call.leadingDim 1
            aPtr <- ContT $ withForeignPtr a
            incaPtr <- Call.cint 1
            betaPtr <- Call.number zero
            incbPtr <- Call.cint 1
            -- One gbmv per row: row k of b = diag(x) * row k of a.
            liftIO $ sequence_ $ take m $
               zipWith
                  (\akPtr bkPtr ->
                     Private.gbmv transPtr
                        nPtr nPtr klPtr kuPtr alphaPtr xPtr ldxPtr
                        akPtr incaPtr betaPtr bkPtr incbPtr)
                  (pointerSeq n aPtr)
                  (pointerSeq n bPtr)
-- | Row vector times matrix: @x * a@, where @x@ has the matrix's height
-- shape.  Implemented as 'multiplyVector' with the matrix used as stored.
multiplyVectorLeft ::
   (Eq height, Shape.C height, Shape.C width, Class.Floating a) =>
   Vector height a -> Matrix height width a -> Vector width a
multiplyVectorLeft x a = multiplyVector x (NonTransposed a)
-- | Matrix times column vector: @a * x@, where @x@ has the matrix's width
-- shape.  Implemented as 'multiplyVector' with the matrix wrapped as
-- 'Transposed', i.e. @x * a^T@.
multiplyVectorRight ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Matrix height width a -> Vector width a -> Vector height a
multiplyVectorRight a x = multiplyVector x (Transposed a)
-- | A row-major matrix tagged with how it is to be used.  The logical
-- shape is always @height@-by-@width@:
--
-- * 'NonTransposed' stores exactly that.
-- * 'Transposed' stores a @width@-by-@height@ matrix that is to be read
--   as its transpose.
data Transposable height width a =
     NonTransposed (Matrix height width a)
   | Transposed (Matrix width height a)
   deriving (Show)
-- | Use a matrix as stored.  Synonym for the 'NonTransposed' constructor.
nonTransposed :: Matrix height width a -> Transposable height width a
nonTransposed = NonTransposed
-- | Use a matrix as its transpose, without moving any data.  Synonym for
-- the 'Transposed' constructor.
transposed :: Matrix height width a -> Transposable width height a
transposed = Transposed
-- | Transpose a 'Transposable' by flipping its tag.  The underlying
-- matrix is reused unchanged.
transposeTransposable ::
   Transposable height width a -> Transposable width height a
transposeTransposable at =
   case at of
      NonTransposed a -> Transposed a
      Transposed a -> NonTransposed a
-- | Unpack a 'Transposable' into three pieces:
--
-- * a BLAS transposition flag: @\'N\'@ for 'NonTransposed', @\'T\'@ for
--   'Transposed';
-- * the logical @(height, width)@ shape (swapped back from storage order
--   for 'Transposed');
-- * the raw storage pointer.
inspectTransposable ::
   Transposable height width a -> (Char, (height, width), ForeignPtr a)
inspectTransposable at =
   case at of
      NonTransposed (Array shA fptr) -> ('N', shA, fptr)
      Transposed (Array shA fptr) -> ('T', swap shA, fptr)
-- | Multiply a vector from the left onto a possibly transposed row-major
-- matrix, i.e. compute @x^T * A@, giving a vector indexed by the matrix width.
--
-- Row-major storage of an @h x w@ matrix is column-major storage of its
-- @w x h@ transpose, so the product is a single column-major BLAS @gemv@.
-- The vector shape must equal the matrix height (checked at runtime).
-- If the matrix height is empty, the result is filled with zeros instead
-- of calling BLAS.
multiplyVector ::
   (Shape.C height, Shape.C width, Eq height, Class.Floating a) =>
   Vector height a -> Transposable height width a -> Vector width a
multiplyVector (Array xShape xFPtr) at =
   let (transChar, (heightA, widthA), aFPtr) = inspectTransposable at
   in Array.unsafeCreateWithSize widthA $ \resultSize yPtr -> do
         Call.assert "Matrix.RowMajor.multiplyVector: shapes mismatch"
            (heightA == xShape)
         let sumSize = Shape.size heightA
             -- rows and columns of the stored array in its column-major view
             (rows, cols) =
                case at of
                   NonTransposed _ -> (resultSize, sumSize)
                   Transposed _ -> (sumSize, resultSize)
         if sumSize == 0
            then fill zero resultSize yPtr
            else evalContT $ do
               transPtr <- Call.char transChar
               mPtr <- Call.cint rows
               nPtr <- Call.cint cols
               alphaPtr <- Call.number one
               aPtr <- ContT $ withForeignPtr aFPtr
               ldaPtr <- Call.leadingDim rows
               xPtr <- ContT $ withForeignPtr xFPtr
               incxPtr <- Call.cint 1
               betaPtr <- Call.number zero
               incyPtr <- Call.cint 1
               liftIO $
                  Blas.gemv
                     transPtr mPtr nPtr alphaPtr aPtr ldaPtr
                     xPtr incxPtr betaPtr yPtr incyPtr
-- | Ordinary product of two row-major matrices.
--
-- The inner (fused) dimensions must be equal; this is checked at runtime.
multiply ::
   (Shape.C height, Shape.C width, Shape.C fuse, Eq fuse, Class.Floating a) =>
   Matrix height fuse a -> Matrix fuse width a -> Matrix height width a
multiply a b = multiplyTransposable (nonTransposed a) (nonTransposed b)
-- | Matrix product in which either factor may be used transposed.
--
-- BLAS works column-major, and there a row-major matrix appears transposed.
-- Since @(A*B)^T = B^T * A^T@, the row-major product @A*B@ is the
-- column-major product of the factors in swapped order.
multiplyTransposable ::
   (Shape.C height, Shape.C width, Shape.C fuse, Eq fuse, Class.Floating a) =>
   Transposable height fuse a ->
   Transposable fuse width a ->
   Matrix height width a
multiplyTransposable left right = multiplyColumnMajor right left
-- | Column-major matrix product @C = op(A) * op(B)@ via one BLAS @gemm@ call.
--
-- In the column-major view:
--
-- * the first argument is a @height x fuse@ factor;
-- * the second argument is a @fuse x width@ factor;
-- * the @height x width@ result is stored as the row-major
--   @Matrix width height@ of the return type.
--
-- The fused dimensions of both factors must be equal (checked at runtime).
-- If the fused dimension is empty, the result is filled with zeros instead
-- of calling BLAS.
multiplyColumnMajor ::
   (Shape.C height, Shape.C width, Shape.C fuse, Eq fuse, Class.Floating a) =>
   Transposable fuse height a ->
   Transposable width fuse a ->
   Matrix width height a
multiplyColumnMajor at bt =
   let (transa, (fuseA, heightA), aFPtr) = inspectTransposable at
       (transb, (widthB, fuseB), bFPtr) = inspectTransposable bt
   in Array.unsafeCreate (widthB, heightA) $ \cPtr -> do
         Call.assert "Matrix.RowMajor.multiply: shapes mismatch"
            (fuseA == fuseB)
         let m = Shape.size heightA
             k = Shape.size fuseA
             n = Shape.size widthB
             -- The leading dimension is the row count of the column-major
             -- array as actually stored, which depends on whether the
             -- factor is used transposed.
             lda = case at of NonTransposed _ -> m; Transposed _ -> k
             ldb = case bt of NonTransposed _ -> k; Transposed _ -> n
         evalContT $
            if k == 0
               then liftIO $ fill zero (m*n) cPtr
               else do
                  transaPtr <- Call.char transa
                  transbPtr <- Call.char transb
                  mPtr <- Call.cint m
                  nPtr <- Call.cint n
                  kPtr <- Call.cint k
                  alphaPtr <- Call.number one
                  aPtr <- ContT $ withForeignPtr aFPtr
                  ldaPtr <- Call.leadingDim lda
                  bPtr <- ContT $ withForeignPtr bFPtr
                  ldbPtr <- Call.leadingDim ldb
                  betaPtr <- Call.number zero
                  ldcPtr <- Call.leadingDim m
                  liftIO $
                     Blas.gemm
                        transaPtr transbPtr mPtr nPtr kPtr alphaPtr
                        aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr
-- | Kronecker product of two row-major matrices.
--
-- Rows of the result are indexed by @(heightA, heightB)@ and columns by
-- @(widthA, widthB)@.
kronecker ::
   (Shape.C heightA, Shape.C widthA, Shape.C heightB, Shape.C widthB,
    Class.Floating a) =>
   Matrix heightA widthA a ->
   Matrix heightB widthB a ->
   Matrix (heightA,heightB) (widthA,widthB) a
kronecker = kroneckerLeftTransposable . nonTransposed
-- | Kronecker product where both factors may be used transposed.
--
-- When the right factor is transposed, the identity
-- @(A (x) B)^T = A^T (x) B^T@ is used: the left factor is transposed too,
-- the product is formed from the stored data, and the result is marked
-- as transposed.
kroneckerTransposable ::
   (Shape.C heightA, Shape.C widthA, Shape.C heightB, Shape.C widthB,
    Class.Floating a) =>
   Transposable heightA widthA a ->
   Transposable heightB widthB a ->
   Transposable (heightA,heightB) (widthA,widthB) a
kroneckerTransposable at (NonTransposed b) =
   NonTransposed (kroneckerLeftTransposable at b)
kroneckerTransposable at (Transposed b) =
   Transposed (kroneckerLeftTransposable (transposeTransposable at) b)
-- | Kronecker product where only the left factor may be used transposed.
--
-- The result element at @((i,j),(k,l))@ is @A(i,k) * B(j,l)@.  Each result
-- row @(i,j)@ is the outer product of row @i@ of @A@ and row @j@ of @B@.
-- It is written by its own BLAS @gemm@ call with inner dimension 1.
kroneckerLeftTransposable ::
   (Shape.C heightA, Shape.C widthA, Shape.C heightB, Shape.C widthB,
    Class.Floating a) =>
   Transposable heightA widthA a ->
   Matrix heightB widthB a ->
   Matrix (heightA,heightB) (widthA,widthB) a
kroneckerLeftTransposable at (Array (heightB, widthB) bFPtr) =
   let (_trans, (heightA, widthA), aFPtr) = inspectTransposable at
   in Array.unsafeCreate ((heightA, heightB), (widthA, widthB)) $ \cPtr ->
      evalContT $ do
         let ma = Shape.size heightA
             na = Shape.size widthA
             mb = Shape.size heightB
             nb = Shape.size widthB
             -- How to walk a logical row of A in storage:
             -- lda is the distance between consecutive elements of a row,
             -- istep the distance between the starts of consecutive rows.
             (lda, istep) =
                case at of
                   NonTransposed _ -> (1, na)
                   Transposed _ -> (ma, 1)
         noTransPtr <- Call.char 'N'
         transPtr <- Call.char 'T'
         naPtr <- Call.cint na
         nbPtr <- Call.cint nb
         onePtr <- Call.cint 1
         alphaPtr <- Call.number one
         aPtr <- ContT $ withForeignPtr aFPtr
         ldaPtr <- Call.leadingDim lda
         bPtr <- ContT $ withForeignPtr bFPtr
         ldbPtr <- Call.leadingDim 1
         betaPtr <- Call.number zero
         ldcPtr <- Call.leadingDim nb
         -- Row (i,j) of the result, read as a row-major na x nb block, equals
         -- the column-major nb x na outer product (row j of B)^T * (row i of A).
         liftIO $
            forM_ [0 .. ma-1] $ \i ->
               forM_ [0 .. mb-1] $ \j ->
                  Blas.gemm
                     transPtr noTransPtr nbPtr naPtr onePtr alphaPtr
                     (advancePtr bPtr (nb*j)) ldbPtr
                     (advancePtr aPtr (istep*i)) ldaPtr
                     betaPtr
                     (advancePtr cPtr (na*nb*(j + mb*i))) ldcPtr