module Hpgsql.Networking
( recvNonBlocking,
socketWaitRead,
socketWaitWrite,
sendNonBlocking,
)
where
import Control.Concurrent (threadWaitRead, threadWaitWrite)
import Control.Exception.Safe (throw)
import Data.ByteString (ByteString)
import Data.ByteString.Internal (createAndTrim)
import qualified Data.ByteString.Lazy as L
import Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import Data.Int (Int64)
import Foreign (Ptr, Storable (..), Word8, allocaArray, castPtr, nullPtr, plusPtr)
import Foreign.C (CChar (..), CInt (..), CSize (..), eAGAIN, eINTR, eWOULDBLOCK, errnoToIOError, getErrno)
import Network.Socket (Socket, withFdSocket)
import System.Posix.Types (CSsize (..))
-- | Suspend the calling Haskell thread until data is available to read on the
-- socket's underlying file descriptor (see 'threadWaitRead').
socketWaitRead :: Socket -> IO ()
socketWaitRead sock =
  withFdSocket sock $ \rawFd ->
    threadWaitRead (fromIntegral rawFd)
-- | Suspend the calling Haskell thread until the socket's underlying file
-- descriptor can be written to (see 'threadWaitWrite').
socketWaitWrite :: Socket -> IO ()
socketWaitWrite sock =
  withFdSocket sock $ \rawFd ->
    threadWaitWrite (fromIntegral rawFd)
-- | Read up to @nbytes@ bytes from the socket with a single @recv@ call.
--
-- Returns the bytes received, which may be fewer than @nbytes@. An empty
-- 'ByteString' is returned both when @recv@ reports 0 bytes and when it fails
-- with @EAGAIN@ / @EWOULDBLOCK@, so callers cannot tell those two cases apart
-- from the result alone.
--
-- A call interrupted by a signal (@EINTR@) is retried. Any other failure is
-- thrown as an 'IOError' that carries the @errno@ description.
--
-- NOTE(review): this function does not put the descriptor into non-blocking
-- mode itself; it presumably relies on the socket already being non-blocking
-- (the FFI call is @unsafe@) — confirm at the call site.
recvNonBlocking :: Socket -> CSize -> IO ByteString
recvNonBlocking s nbytes =
  withFdSocket s $ \fd ->
    -- createAndTrim allocates nbytes and shrinks the result to the count
    -- returned by the action below.
    createAndTrim (fromIntegral nbytes) $ \buffer ->
      let attempt = do
            r <- c_recv fd (castPtr buffer) nbytes 0
            if r >= 0
              then pure (fromIntegral r)
              else do
                err <- getErrno
                if err == eINTR
                  then attempt -- interrupted before reading anything; retry
                  else
                    if err == eAGAIN || err == eWOULDBLOCK
                      then pure 0 -- nothing available right now
                      else
                        throw
                          ( errnoToIOError
                              "Internal error in hpgsql's recvNonBlocking"
                              err
                              Nothing
                              Nothing
                          )
       in attempt
-- | Hand as much of the lazy 'L.ByteString' as possible to a single @writev@
-- call and return the number of bytes the kernel accepted.
--
-- At most 'maxNumChunks' strict chunks are passed to @writev@, and chunks
-- stop being added once the running byte total reaches 'maxNumBytes' (the
-- chunk that crosses the limit is still included). The result can therefore
-- be smaller than the input length; the caller is responsible for sending the
-- remainder.
--
-- Error handling:
--
-- * @EAGAIN@ / @EWOULDBLOCK@: returns @-1@ (what @writev@ itself returned)
--   and nothing was written. This keeps the existing contract for callers.
-- * @EINTR@: the @writev@ call is retried.
-- * any other @errno@: thrown as an 'IOError' carrying the @errno@
--   description, instead of being returned as an indistinguishable @-1@.
sendNonBlocking :: Socket -> L.ByteString -> IO Int64
sendNonBlocking s lbs = do
  let cs = take maxNumChunks (L.toChunks lbs)
      len = length cs
  siz <- withFdSocket s $ \fd ->
    allocaArray len $ \ptr ->
      withPokes cs ptr $ \niovs ->
        writevRetrying fd ptr niovs
  return $ fromIntegral siz
  where
    -- Call writev, retrying on EINTR. Would-block is reported as the raw
    -- negative return value; every other failure is thrown with its errno.
    writevRetrying fd ptr niovs = do
      r <- c_writev fd ptr niovs
      if r >= 0
        then pure r
        else do
          err <- getErrno
          if err == eINTR
            then writevRetrying fd ptr niovs
            else
              if err == eAGAIN || err == eWOULDBLOCK
                then pure r
                else
                  throw
                    ( errnoToIOError
                        "Internal error in hpgsql's sendNonBlocking"
                        err
                        Nothing
                        Nothing
                    )

    -- Fill the iovec array from the chunks and run @f@ with the number of
    -- entries written. Each chunk's pointer is only valid inside its
    -- 'unsafeUseAsCStringLen' continuation, so the calls are nested and @f@
    -- runs from the innermost one, while every chunk is still in use.
    withPokes ss p f = loop ss p 0 0
      where
        loop (c : rest) q k niovs
          | k < maxNumBytes = unsafeUseAsCStringLen c $ \(cptr, clen) -> do
              poke q $ IOVec (castPtr cptr) (fromIntegral clen)
              loop
                rest
                (q `plusPtr` sizeOf (IOVec nullPtr 0))
                (k + clen)
                (niovs + 1)
          | otherwise = f niovs
        loop _ _ _ niovs = f niovs
-- | Byte budget for one 'sendNonBlocking' call: once the chunks gathered so
-- far reach this many bytes, no further chunks are added (4 MiB).
maxNumBytes :: Int
maxNumBytes = 4 * 1024 * 1024
-- | Upper bound on the number of strict chunks (iovec entries) handed to a
-- single @writev@ call. NOTE(review): presumably chosen to stay within
-- @IOV_MAX@ (1024 on Linux) — confirm for other targets.
maxNumChunks :: Int
maxNumChunks = 2 ^ (10 :: Int)
-- | Haskell mirror of the C @struct iovec@ passed to @writev@.
data IOVec = IOVec
  { -- | Start of the buffer (@iov_base@).
    iovBase :: Ptr Word8
  , -- | Number of bytes in the buffer (@iov_len@).
    iovLen :: CSize
  }
-- | Marshals 'IOVec' to and from the C @struct iovec@ layout expected by
-- @writev@: @iov_base@ at offset 0, followed by @iov_len@.
--
-- The layout is computed from the field types instead of hard-coded 64-bit
-- constants. On 64-bit targets it yields the same values as before (size 16,
-- @iov_len@ at offset 8). Alignment follows the pointer field rather than
-- 'CInt'.
--
-- NOTE(review): this assumes @sizeof(size_t) == sizeof(void *)@ and no padding
-- between the two fields, which holds on the mainstream ILP32/LP64 ABIs.
instance Storable IOVec where
  sizeOf :: IOVec -> Int
  sizeOf _ = iovLenOffset + sizeOf (0 :: CSize)

  alignment :: IOVec -> Int
  alignment _ = alignment (nullPtr :: Ptr Word8)

  peek :: Ptr IOVec -> IO IOVec
  peek p = do
    base <- peekByteOff p 0
    len <- peekByteOff p iovLenOffset
    return $ IOVec base len

  poke :: Ptr IOVec -> IOVec -> IO ()
  poke p iov = do
    pokeByteOff p 0 (iovBase iov)
    pokeByteOff p iovLenOffset (iovLen iov)

-- | Byte offset of @iov_len@ inside @struct iovec@: it directly follows the
-- pointer-sized @iov_base@ field.
iovLenOffset :: Int
iovLenOffset = sizeOf (nullPtr :: Ptr Word8)
-- | @ssize_t recv(int sockfd, void *buf, size_t len, int flags)@.
-- Bound with a 'CSsize' result to match the C return type (the previous 'CInt'
-- binding truncated the result whenever it did not fit in a C @int@).
foreign import ccall unsafe "recv"
  c_recv :: CInt -> Ptr CChar -> CSize -> CInt -> IO CSsize

-- | @ssize_t writev(int fd, const struct iovec *iov, int iovcnt)@.
foreign import ccall unsafe "writev"
  c_writev :: CInt -> Ptr IOVec -> CInt -> IO CSsize