-- | Internal helpers (vector sum, real dot product, matrix copy and
-- matrix update) implemented on top of the BLAS FFI bindings.
module Numeric.CBLAS.FFI.Private (
   dotReal,
   sum,
   copyMatrix,
   -- transferMatrix,
   addMatrix,
   ) where

import Numeric.CBLAS.FFI.Common (pointerSeq)

import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.BLAS.FFI.Generic as Blas
import qualified Numeric.Netlib.Modifier as Modi
import qualified Numeric.Netlib.Class as Class
import qualified Numeric.Netlib.Utility as Call

import qualified Control.Monad.Trans.Cont as MC
import Control.Monad.IO.Class (liftIO)
import Control.Applicative (liftA2)

import Foreign.Marshal.Array (advancePtr)
import Foreign.Ptr (Ptr, castPtr)

import Data.Complex (Complex((:+)))

import Prelude hiding (sum)


-- | A summation routine wrapped so that 'Class.switchFloating' can
-- select one implementation per element type.
newtype Sum a = Sum {runSum :: Int -> Ptr a -> Int -> IO a}

-- | Sum of the @n@ elements of the vector at @xPtr@ with increment @incx@.
--
-- The two real element types are handled by 'sumReal', the two complex
-- element types by 'sumComplex'.
sum :: (Class.Floating a) => Int -> Ptr a -> Int -> IO a
sum n xPtr incx =
   runSum
      (Class.switchFloating
         (Sum sumReal) (Sum sumReal)
         (Sum sumComplex) (Sum sumComplex))
      n xPtr incx

-- | Real sum, computed as the dot product of @x@ with a single stored @1@.
--
-- The @1@ is passed with increment 0, so the same cell is read for every
-- element of @x@. NOTE(review): this relies on the linked BLAS @dot@
-- accepting a zero increment — confirm for the BLAS builds in use.
sumReal :: Class.Real a => Int -> Ptr a -> Int -> IO a
sumReal n xPtr incx =
   MC.evalContT $ do
      onePtr <- Call.real 1
      liftIO $ dotReal n xPtr incx onePtr 0

-- | Complex sum, computed as two real sums.
--
-- The complex array is reinterpreted as a real array: real parts start at
-- the cast pointer, imaginary parts one element further. Each part is
-- summed with twice the complex increment.
sumComplex :: Class.Real a => Int -> Ptr (Complex a) -> Int -> IO (Complex a)
sumComplex n xPtr incx =
   MC.evalContT $ do
      onePtr <- Call.real 1
      let realPartPtr = castPtr xPtr
          imagPartPtr = advancePtr realPartPtr 1
          sumPart partPtr = dotReal n partPtr (2*incx) onePtr 0
      liftIO $ liftA2 (:+) (sumPart realPartPtr) (sumPart imagPartPtr)

-- | Dot product of two real vectors of length @n@ with increments @incx@
-- and @incy@.
--
-- This is a thin wrapper around 'BlasReal.dot' that marshals the 'Int'
-- arguments.
dotReal :: Class.Real a => Int -> Ptr a -> Int -> Ptr a -> Int -> IO a
dotReal n xPtr incx yPtr incy =
   MC.evalContT $
      liftIO =<<
         (BlasReal.dot
            <$> Call.cint n <*> pure xPtr
            <*> Call.cint incx <*> pure yPtr
            <*> Call.cint incy)


-- | Copy a @rows@ x @cols@ block from @aPtr@ (leading dimension @lda@)
-- to @bPtr@ (leading dimension @ldb@).
--
-- These descriptions assume @pointerSeq d p@ enumerates
-- @p, p+d, p+2d, ...@ (TODO confirm against "Numeric.CBLAS.FFI.Common").
--
-- * 'Modi.NonTransposed': copies @rows@ consecutive elements for each of
--   the @cols@ column starts. If both leading dimensions equal @rows@, the
--   data is contiguous and is copied with a single call instead.
--
-- * 'Modi.Transposed': for each @k < rows@, copies @cols@ elements read
--   from @aPtr + k@ with stride @lda@ into a contiguous run starting at
--   @bPtr + k*ldb@.
copyMatrix ::
   (Class.Floating a) =>
   Modi.Transposition ->
   Int -> Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
copyMatrix transp rows cols aPtr lda bPtr ldb =
   MC.evalContT $ case transp of
      Modi.NonTransposed -> do
         incPtr <- Call.cint 1
         let contiguous = rows == lda && rows == ldb
         nPtr <- Call.cint (if contiguous then rows*cols else rows)
         let copyRun srcPtr dstPtr =
               Blas.copy nPtr srcPtr incPtr dstPtr incPtr
         liftIO $
            if contiguous
               then copyRun aPtr bPtr
               else
                  mapM_ (uncurry copyRun) $
                  take cols $ zip (pointerSeq lda aPtr) (pointerSeq ldb bPtr)
      Modi.Transposed -> do
         nPtr <- Call.cint cols
         incaPtr <- Call.cint lda
         incbPtr <- Call.cint 1
         let copyRow srcPtr dstPtr =
               Blas.copy nPtr srcPtr incaPtr dstPtr incbPtr
         liftIO $
            mapM_ (uncurry copyRow) $
            take rows $ zip (pointerSeq 1 aPtr) (pointerSeq ldb bPtr)

{-
Not implemented yet; only the intended signature is recorded here.

transferMatrix ::
   (Class.Floating a) =>
   Modi.Transposition -> Modi.Conjugation ->
   Int -> Int -> a -> Ptr a -> Int -> Ptr a -> Int -> IO ()
transferMatrix transp conj rows cols alpha a lda b ldb =
-}

-- | In-place update @B := beta*B + alpha*A@ of a @rows@ x @cols@ block.
--
-- @A@ is at @aPtr@ with leading dimension @lda@; @B@ is at @bPtr@ with
-- leading dimension @ldb@. Each column is handled by BLAS @scal@ (by
-- @beta@) followed by @axpy@ (adding @alpha@ times the matching column of
-- @A@). If both leading dimensions equal @rows@, the whole block is
-- processed with one @scal@ and one @axpy@.
--
-- NOTE(review): with @beta = 0@ this still calls @scal@, so whether
-- NaN/Inf values already in @B@ survive depends on the BLAS @scal@
-- implementation — confirm that callers never rely on @B@ being ignored.
addMatrix ::
   (Class.Floating a) =>
   Int -> Int -> a -> Ptr a -> Int -> a -> Ptr a -> Int -> IO ()
addMatrix rows cols alpha aPtr lda beta bPtr ldb =
   MC.evalContT $ do
      alphaPtr <- Call.number alpha
      betaPtr <- Call.number beta
      incPtr <- Call.cint 1
      let contiguous = rows == lda && rows == ldb
      nPtr <- Call.cint (if contiguous then rows*cols else rows)
      let scaleAndAdd srcPtr dstPtr = do
            Blas.scal nPtr betaPtr dstPtr incPtr
            Blas.axpy nPtr alphaPtr srcPtr incPtr dstPtr incPtr
      liftIO $
         if contiguous
            then scaleAndAdd aPtr bPtr
            else
               mapM_ (uncurry scaleAndAdd) $
               take cols $ zip (pointerSeq lda aPtr) (pointerSeq ldb bPtr)