-- |
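-- This module contains code largely copied from the @network@ library
-- (BSD-3-Clause), modified so that the send and receive functions never
-- block: unlike their @network@ counterparts they do not call
-- @threadWaitWrite@ / @threadWaitRead@ internally. Callers that need to wait
-- for readiness do so explicitly with 'socketWaitRead' / 'socketWaitWrite'.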
--
-- Original copyright:
--   Copyright (c) 2002-2010, The University Court of the University of Glasgow
--   Copyright (c) 2007-2010, Johan Tibell
--
-- See: https://github.com/haskell/network/blob/master/LICENSE
module Hpgsql.Networking
  ( recvNonBlocking,
    socketWaitRead,
    socketWaitWrite,
    sendNonBlocking,
  )
where

import Control.Concurrent (threadWaitRead, threadWaitWrite)
import Control.Exception.Safe (throw)
import Data.ByteString (ByteString)
import Data.ByteString.Internal (createAndTrim)
import qualified Data.ByteString.Lazy as L
import Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import Data.Int (Int64)
import Foreign (Ptr, Storable (..), Word8, allocaArray, castPtr, nullPtr, plusPtr)
import Foreign.C (CChar (..), CInt (..), CSize (..), eAGAIN, eINTR, eWOULDBLOCK, errnoToIOError, getErrno)
import Network.Socket (Socket, withFdSocket)
import System.Posix.Types (CSsize (..))

-- | Suspend the calling Haskell thread until the socket's underlying file
-- descriptor becomes readable, by handing the descriptor to 'threadWaitRead'.
socketWaitRead :: Socket -> IO ()
socketWaitRead sock =
  withFdSocket sock $ \rawFd -> threadWaitRead (fromIntegral rawFd)

-- | Suspend the calling Haskell thread until the socket's underlying file
-- descriptor is ready for writing, by handing the descriptor to
-- 'threadWaitWrite'.
socketWaitWrite :: Socket -> IO ()
socketWaitWrite sock = withFdSocket sock waitOnFd
  where
    -- withFdSocket hands us a CInt; threadWaitWrite wants an Fd.
    waitOnFd rawFd = threadWaitWrite (fromIntegral rawFd)

-- | Read at most @nbytes@ bytes from the socket with a single @recv(2)@ call,
-- never waiting for the socket to become readable.
--
-- * If @recv(2)@ fails with @EAGAIN@ / @EWOULDBLOCK@ (no data available right
--   now), an empty 'ByteString' is returned immediately; use 'socketWaitRead'
--   to wait for readiness. This relies on the socket's descriptor being in
--   non-blocking mode, otherwise @recv(2)@ itself may block.
-- * If @recv(2)@ returns 0 (which POSIX uses for an orderly shutdown by the
--   peer), the result is also empty, so an empty result alone does not tell
--   "no data yet" apart from "connection closed".
-- * @EINTR@ is retried.
-- * Any other failure is thrown as an 'IOError' built from the errno.
recvNonBlocking :: Socket -> CSize -> IO ByteString
recvNonBlocking s nbytes =
  withFdSocket s $ \fd ->
    createAndTrim (fromIntegral nbytes) $ \buffer -> do
      -- Largely copied from https://hackage-content.haskell.org/package/network-3.2.8.0/docs/src/Network.Socket.Buffer.html#recvBufNoWait
      -- and other functions from the network library, but modified so that it
      -- never calls threadWaitRead.
      let recvLoop = do
            r <- c_recv fd (castPtr buffer) nbytes 0 {- flags -}
            if r >= 0
              then pure (fromIntegral r)
              else getErrno >>= onError
          onError err
            | err == eINTR = recvLoop
            | err == eAGAIN || err == eWOULDBLOCK = pure 0
            | otherwise =
                -- Keep the errno so callers can see *why* recv failed.
                throw $ errnoToIOError "Hpgsql.Networking.recvNonBlocking" err Nothing Nothing
      recvLoop

-- | Send as much of a lazy 'L.ByteString' as the kernel accepts right now,
-- using a single @writev(2)@ call and never waiting for the socket to become
-- writable.
--
-- At most 'maxNumChunks' chunks are passed to @writev@. No further chunks are
-- added once at least 'maxNumBytes' bytes have been queued. A single call may
-- therefore send only a prefix of the input.
--
-- Returns the number of bytes written; callers must resend the remainder.
--
-- * If the socket is not ready for writing (@EAGAIN@ / @EWOULDBLOCK@), @-1@ is
--   returned immediately; use 'socketWaitWrite' to wait for readiness.
-- * @EINTR@ is retried.
-- * Any other failure (e.g. a reset or closed connection) is thrown as an
--   'IOError' built from the errno, instead of being reported as @-1@.
sendNonBlocking :: Socket -> L.ByteString -> IO Int64
sendNonBlocking s lbs = do
  -- Largely copied from https://hackage-content.haskell.org/package/network-3.2.8.0/docs/src/Network.Socket.ByteString.Lazy.Posix.html#send,
  -- but then modified to our needs.
  let cs = take maxNumChunks (L.toChunks lbs)
  siz <- withFdSocket s $ \fd ->
    allocaArray (length cs) $ \ptr ->
      withPokes cs ptr $ \niovs ->
        -- The original has `throwSocketErrorWaitWrite s "writev"` here, which
        -- calls threadWaitWrite (and therefore blocks) on EAGAIN. We report
        -- EAGAIN to the caller instead, but still surface genuine errors.
        writevNoWait fd ptr niovs
  return $ fromIntegral siz
  where
    -- Calls writev(2), retrying on EINTR. Returns -1 on EAGAIN/EWOULDBLOCK
    -- and throws on any other error.
    writevNoWait :: CInt -> Ptr IOVec -> CInt -> IO CSsize
    writevNoWait fd iov niovs = do
      r <- c_writev fd iov niovs
      if r >= 0
        then pure r
        else getErrno >>= onError
      where
        onError err
          | err == eINTR = writevNoWait fd iov niovs
          | err == eAGAIN || err == eWOULDBLOCK = pure (-1)
          | otherwise =
              throw $ errnoToIOError "Hpgsql.Networking.sendNonBlocking" err Nothing Nothing

    -- Fills the iovec array at @p@ with one entry per chunk and passes the
    -- number of entries filled in to @f@. @f@ runs inside every
    -- 'unsafeUseAsCStringLen' scope, so the chunk pointers stored in the array
    -- are still valid when @f@ issues the system call.
    withPokes :: [ByteString] -> Ptr IOVec -> (CInt -> IO CSsize) -> IO CSsize
    withPokes ss p f = loop ss p 0 0
      where
        -- q: next free iovec slot; k: bytes queued so far; niovs: slots used.
        loop :: [ByteString] -> Ptr IOVec -> Int -> CInt -> IO CSsize
        loop (c : rest) q k niovs
          | k < maxNumBytes = unsafeUseAsCStringLen c $ \(cptr, clen) -> do
              poke q $ IOVec (castPtr cptr) (fromIntegral clen)
              loop
                rest
                (q `plusPtr` sizeOf (IOVec nullPtr 0))
                (k + clen)
                (niovs + 1)
          | otherwise = f niovs
        loop _ _ _ niovs = f niovs

    maxNumBytes, maxNumChunks :: Int
    maxNumBytes = 4194304 -- maximum number of bytes to transmit in one system call
    maxNumChunks = 1024 -- maximum number of chunks to transmit in one system call

-- | Haskell mirror of POSIX @struct iovec@ (used by @writev(2)@): a base
-- pointer followed by a length, i.e. @{ void *iov_base; size_t iov_len; }@.
data IOVec = IOVec
  { iovBase :: Ptr Word8,
    iovLen :: CSize
  }

-- | Memory layout of @struct iovec@, derived from the sizes of its two fields
-- rather than hard-coded 64-bit constants.
--
-- This assumes (as on mainstream POSIX platforms) two things:
--
-- * @iov_base@ comes first.
-- * @void *@ and @size_t@ have the same size, so the struct contains no
--   padding.
instance Storable IOVec where
  sizeOf _ = iovLenOffset + sizeOf (0 :: CSize)

  -- The struct holds a pointer, so it must be at least pointer-aligned.
  alignment _ = alignment (nullPtr :: Ptr Word8)

  peek p = IOVec <$> peekByteOff p 0 <*> peekByteOff p iovLenOffset

  poke p iov = do
    pokeByteOff p 0 (iovBase iov)
    pokeByteOff p iovLenOffset (iovLen iov)

-- | Byte offset of @iov_len@ within @struct iovec@: it directly follows the
-- @iov_base@ pointer.
iovLenOffset :: Int
iovLenOffset = sizeOf (nullPtr :: Ptr Word8)

-- -- | @withIOVec cs f@ executes the computation @f@, passing as argument a pair
-- -- consisting of a pointer to a temporarily allocated array of pointers to
-- -- IOVec made from @cs@ and the number of pointers (@length cs@).
-- -- /Unix only/.
-- withIOVec :: [(Ptr Word8, Int)] -> ((Ptr IOVec, Int) -> IO a) -> IO a
-- withIOVec [] f = f (nullPtr, 0)
-- withIOVec cs f =
--   allocaArray csLen $ \aPtr -> do
--     zipWithM_ pokeIov (ptrs aPtr) cs
--     f (aPtr, csLen)
--   where
--     csLen = length cs
--     ptrs = iterate (`plusPtr` sizeOf (IOVec nullPtr 0))
--     pokeIov ptr (sPtr, sLen) = poke ptr $ IOVec sPtr (fromIntegral sLen)

-- | Raw binding to POSIX @recv(2)@:
-- @ssize_t recv(int socket, void *buffer, size_t length, int flags)@.
--
-- The result is declared as 'CSsize' to match the C prototype. A 32-bit
-- 'CInt' result would truncate byte counts above 2^31-1 on LP64 platforms.
--
-- Declared @unsafe@, so the calling capability is not released while the call
-- runs. This is presumably acceptable only because the socket is expected to be
-- in non-blocking mode.
foreign import ccall unsafe "recv"
  c_recv :: CInt -> Ptr CChar -> CSize -> CInt -> IO CSsize

-- | Raw binding to POSIX @writev(2)@:
-- @ssize_t writev(int fd, const struct iovec *iov, int iovcnt)@.
--
-- The 'Ptr' 'IOVec' argument must point to @iovcnt@ consecutive 'IOVec'
-- entries laid out like @struct iovec@ (see the 'Storable' 'IOVec' instance).
-- Returns the number of bytes written, or -1 with errno set on failure.
--
-- Declared @unsafe@, so the calling capability is not released while the call
-- runs. This is presumably acceptable only because the socket is expected to be
-- in non-blocking mode.
foreign import ccall unsafe "writev"
  c_writev :: CInt -> Ptr IOVec -> CInt -> IO CSsize