{-# LANGUAGE ForeignFunctionInterface #-}
{- |
Private FFI bindings to the matrix copy routines @cblas_?omatcopy@ and the
matrix addition routines @cblas_?geadd@ (BLAS extensions, e.g. as shipped by
OpenBLAS), together with column-major wrappers that take Haskell 'Int'
dimensions.
-}
module Numeric.CBLAS.FFI.Private (
   Routine.dotu,
   Routine.dotc,
   Routine.sum,
   Order, rowMajor, columnMajor,
   Transpose, noTrans, trans, conjTrans, conjNoTrans,
   omatcopy,
   copyMatrix,
   transferMatrix,
   addMatrix,
   ) where

import qualified Numeric.CBLAS.FFI.Routine as Routine
import Numeric.CBLAS.FFI.Type

import qualified Numeric.Netlib.Modifier as Modi
import qualified Numeric.Netlib.Class as Class

import Foreign.Marshal (with)
-- import Foreign.Storable (peek)
import Foreign.Ptr (Ptr)
-- import Foreign.C.Types

import Control.Monad (unless)
import Data.Complex (Complex)


{- |
Argument list of @cblas_?omatcopy@: storage order, transposition flag,
rows, columns, scaling factor, source pointer, source leading dimension,
destination pointer, destination leading dimension.

The parameter @c@ is the way the scaling factor is passed: by value for the
real bindings, by pointer for the complex ones (see 'COMatCopy').
-}
type OMatCopy c a =
   Order -> Transpose -> CBlasInt -> CBlasInt ->
   c -> Ptr a -> CBlasInt -> Ptr a -> CBlasInt -> IO ()

-- | 'OMatCopy' for complex element types, whose scaling factor is passed by pointer.
type COMatCopy a = OMatCopy (Ptr a) a

foreign import ccall "cblas_somatcopy"
   somatcopy :: OMatCopy Float Float

foreign import ccall "cblas_domatcopy"
   domatcopy :: OMatCopy Double Double

foreign import ccall "cblas_comatcopy"
   comatcopy :: COMatCopy (Complex Float)

foreign import ccall "cblas_zomatcopy"
   zomatcopy :: COMatCopy (Complex Double)

-- | Wrapper so that 'Class.switchFloating' can select a binding per element type.
newtype OMATCOPY a = OMATCOPY {getOMATCOPY :: OMatCopy a a}

{- |
Type-generic @cblas_?omatcopy@.

All four element types take the scaling factor by value.  For the complex
types it is marshalled to a temporary pointer with 'with', because the C
bindings expect it by reference.
-}
omatcopy :: (Class.Floating a) => OMatCopy a a
omatcopy =
   getOMATCOPY $
   Class.switchFloating
      (OMATCOPY somatcopy)
      (OMATCOPY domatcopy)
      (OMATCOPY $ \order transp rows cols alpha a lda b ldb ->
         with alpha $ \alphaPtr ->
            comatcopy order transp rows cols alphaPtr a lda b ldb)
      (OMATCOPY $ \order transp rows cols alpha a lda b ldb ->
         with alpha $ \alphaPtr ->
            zomatcopy order transp rows cols alphaPtr a lda b ldb)


{- |
Argument list of @cblas_?geadd@: storage order, rows, columns, @alpha@,
first matrix pointer, its leading dimension, @beta@, second matrix pointer,
its leading dimension.

As with 'OMatCopy', @c@ is how the scalars are passed: by value for the
real bindings, by pointer for the complex ones (see 'CGeAdd').
-}
type GeAdd c a =
   Order -> CBlasInt -> CBlasInt ->
   c -> Ptr a -> CBlasInt -> c -> Ptr a -> CBlasInt -> IO ()

-- | 'GeAdd' for complex element types, whose scalars are passed by pointer.
type CGeAdd a = GeAdd (Ptr a) a

foreign import ccall "cblas_sgeadd"
   sgeadd :: GeAdd Float Float

foreign import ccall "cblas_dgeadd"
   dgeadd :: GeAdd Double Double

foreign import ccall "cblas_cgeadd"
   cgeadd :: CGeAdd (Complex Float)

foreign import ccall "cblas_zgeadd"
   zgeadd :: CGeAdd (Complex Double)

-- | Wrapper so that 'Class.switchFloating' can select a binding per element type.
newtype GEADD a = GEADD {getGEADD :: GeAdd a a}

{- |
Type-generic @cblas_?geadd@.

For the complex types, both @alpha@ and @beta@ are marshalled to temporary
pointers with 'with' before the C routine is called.
-}
geadd :: (Class.Floating a) => GeAdd a a
geadd =
   getGEADD $
   Class.switchFloating
      (GEADD sgeadd)
      (GEADD dgeadd)
      (GEADD $ \order rows cols alpha a lda beta b ldb ->
         with alpha $ \alphaPtr ->
         with beta $ \betaPtr ->
            cgeadd order rows cols alphaPtr a lda betaPtr b ldb)
      (GEADD $ \order rows cols alpha a lda beta b ldb ->
         with alpha $ \alphaPtr ->
         with beta $ \betaPtr ->
            zgeadd order rows cols alphaPtr a lda betaPtr b ldb)


{- |
@True@ when a @rows@ by @cols@ matrix has no elements.

Only zero is treated as empty.  Negative dimensions are still handed to BLAS
so that it can report them as invalid arguments.
-}
zeroSized :: Int -> Int -> Bool
zeroSized rows cols = rows == 0 || cols == 0

-- | Map the Netlib transposition and conjugation modifiers to the CBLAS flag.
modifierTranspose :: Modi.Transposition -> Modi.Conjugation -> Transpose
modifierTranspose Modi.NonTransposed Modi.NonConjugated = noTrans
modifierTranspose Modi.Transposed    Modi.NonConjugated = trans
modifierTranspose Modi.NonTransposed Modi.Conjugated    = conjNoTrans
modifierTranspose Modi.Transposed    Modi.Conjugated    = conjTrans

{- |
Copy a column-major matrix from @a@ (leading dimension @lda@) to @b@
(leading dimension @ldb@), optionally transposed and without scaling.

This is 'transferMatrix' with no conjugation and a scaling factor of @1@.
It therefore does nothing when @rows@ or @cols@ is zero.
-}
copyMatrix ::
   (Class.Floating a) =>
   Modi.Transposition -> Int -> Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
copyMatrix transp rows cols a lda b ldb =
   transferMatrix transp Modi.NonConjugated rows cols 1 a lda b ldb

{- |
Copy a column-major matrix from @a@ (leading dimension @lda@) to @b@
(leading dimension @ldb@) via @cblas_?omatcopy@, applying the given
transposition and conjugation and scaling by @alpha@.  Presumably this means
@b := alpha * op(a)@; NOTE(review): confirm against the BLAS extension docs.

@rows@ and @cols@ are forwarded unchanged as the routine's dimension
arguments.  When either is zero there is nothing to copy, so BLAS is not
called at all.  Some @?omatcopy@ implementations are believed to report zero
dimensions as an xerbla error.
-}
transferMatrix ::
   (Class.Floating a) =>
   Modi.Transposition -> Modi.Conjugation ->
   Int -> Int -> a -> Ptr a -> Int -> Ptr a -> Int -> IO ()
transferMatrix transp conj rows cols alpha a lda b ldb =
   unless (zeroSized rows cols) $
      omatcopy columnMajor (modifierTranspose transp conj)
         (fromIntegral rows) (fromIntegral cols)
         alpha a (fromIntegral lda) b (fromIntegral ldb)

{- |
Column-major matrix addition via @cblas_?geadd@, with @a@ scaled by @alpha@
and @b@ scaled by @beta@.  Presumably this computes
@b := alpha * a + beta * b@; NOTE(review): confirm against the BLAS
extension docs.

When @rows@ or @cols@ is zero there are no elements to update, so BLAS is
not called.
-}
addMatrix ::
   (Class.Floating a) =>
   Int -> Int -> a -> Ptr a -> Int -> a -> Ptr a -> Int -> IO ()
addMatrix rows cols alpha a lda beta b ldb =
   unless (zeroSized rows cols) $
      geadd columnMajor (fromIntegral rows) (fromIntegral cols)
         alpha a (fromIntegral lda) beta b (fromIntegral ldb)