module Network.QUIC.Simple.Stream
( MessageQueues
, streamSerialise
, streamCodec
) where
import Codec.Serialise (Serialise, serialise, deserialiseIncremental)
import Codec.Serialise qualified as IDecode (IDecode(..))
import Control.Concurrent.Async (Async, async, race_)
import Control.Concurrent.STM
import Control.Exception (finally, throwIO)
import Control.Monad.ST (stToIO)
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as BSL
import Data.IORef
import Network.QUIC qualified as QUIC
-- | The pair of bounded queues handed back by 'streamSerialise' and
-- 'streamCodec'.
--
-- The first component carries outgoing messages: the caller writes them and
-- the worker drains them onto the stream. The second carries incoming
-- messages: the worker fills it and the caller reads from it.
type MessageQueues outgoing incoming =
  ( TBQueue outgoing
  , TBQueue incoming
  )
-- | Run a QUIC stream as a pair of message queues, using CBOR
-- ("Codec.Serialise") to frame messages in both directions.
--
-- Outgoing messages are encoded with 'serialise'. Incoming bytes are fed to
-- an incremental 'deserialiseIncremental' decoder. The decoder's in-progress
-- state is kept across chunks, so one message may be split over several
-- 'QUIC.recvStream' reads. Bytes left over after a completed message are
-- handed back to the reader loop, so one read may also carry several
-- messages.
--
-- See 'streamCodec' for the worker and queue semantics. If the peer sends
-- bytes that fail to decode, the reader throws the
-- 'Codec.Serialise.DeserialiseFailure', which ends the worker with that
-- exception. The stream is closed on the way out.
streamSerialise
  :: forall sendMsg recvMsg
  . (Serialise sendMsg, Serialise recvMsg)
  => QUIC.Stream
  -> IO (Async (), MessageQueues sendMsg recvMsg)
streamSerialise stream = do
  initial <- stToIO $ deserialiseIncremental @recvMsg
  -- Unless a failure has been recorded, this always holds a 'Partial'
  -- decoder: either part-way through a message or freshly started.
  state <- newIORef initial
  let
    -- Feed one chunk to the current decoder and report the result in the
    -- shape 'streamCodec' expects. That shape is the bytes left over after
    -- a completed message, plus that message if one finished.
    decode chunk = do
      current <- readIORef state
      next <- case current of
        IDecode.Partial consume -> stToIO (consume (Just chunk))
        -- A stored 'Fail' is re-raised by 'settle'. A stored 'Done' never
        -- occurs, because 'settle' replaces it with a fresh decoder.
        finished -> pure finished
      settle next

    -- Record the decoder's new state and translate it for the reader loop.
    settle next = case next of
      IDecode.Fail _leftovers _offset err -> do
        writeIORef state next
        throwIO err
      IDecode.Done leftovers _consumed msg -> do
        -- Start a fresh decoder for the next message. The reader hands
        -- 'leftovers' back to 'decode' before it reads the stream again.
        stToIO deserialiseIncremental >>= writeIORef state
        pure (leftovers, Just msg)
      IDecode.Partial _ -> do
        -- The chunk was fully consumed but the message is incomplete, so
        -- keep the progress and wait for more bytes.
        writeIORef state next
        pure (BS.empty, Nothing)
  streamCodec serialise decode stream
-- | Run a QUIC stream as a pair of bounded message queues, with the framing
-- supplied by the caller.
--
-- This starts a worker that races two loops over @stream@:
--
-- * the writer takes each message from the send queue, encodes it with
--   @encode@ and sends the resulting chunks;
--
-- * the reader receives bytes from the stream, passes them to @decode@ and
--   pushes every decoded message onto the receive queue.
--
-- @decode@ is called with a non-empty chunk. It returns the bytes it did not
-- consume, together with the message it completed, if any. Unconsumed bytes
-- go straight back to @decode@ before anything more is read. @decode@ must
-- therefore remember its own progress on an incomplete message, as the
-- decoder built by 'streamSerialise' does.
--
-- Each queue holds up to 1024 messages, and a full queue blocks whoever is
-- filling it. In particular, when the receive queue is full the reader
-- stops reading from the stream until the caller drains it.
--
-- The worker finishes when either loop ends, and in every case the other
-- loop is cancelled and the stream is closed:
--
-- * the reader stops once the peer finishes its side of the stream
--   ('QUIC.recvStream' returns no bytes);
--
-- * an exception from either loop, including one thrown by @decode@, ends
--   the worker with that exception.
--
-- Bytes of a message that was only partly received when the peer finished
-- are discarded.
--
-- Returns the worker together with @(sendQueue, receiveQueue)@.
streamCodec
  :: (sendMsg -> BSL.ByteString)
  -> (BS.ByteString -> IO (BS.ByteString, Maybe recvMsg))
  -> QUIC.Stream
  -> IO (Async (), MessageQueues sendMsg recvMsg)
streamCodec encode decode stream = do
  readQ <- newTBQueueIO queueCapacity
  writeQ <- newTBQueueIO queueCapacity
  worker <- async $
    race_ (reader BS.empty readQ) (writer writeQ)
      `finally` QUIC.closeStream stream
  pure (worker, (writeQ, readQ))
  where
    -- Maximum number of messages buffered in each queue.
    queueCapacity = 1024

    -- Maximum number of bytes requested from the stream per read.
    recvChunkSize = 4096

    reader leftovers readQ = do
      chunk <-
        if BS.null leftovers
          then QUIC.recvStream stream recvChunkSize
          else pure leftovers
      -- 'QUIC.recvStream' returns an empty chunk once the peer has finished
      -- the stream. Asking again would return empty immediately and spin
      -- forever, so stop here instead.
      if BS.null chunk
        then pure ()
        else do
          (leftovers', message_) <- decode chunk
          mapM_ (atomically . writeTBQueue readQ) message_
          reader leftovers' readQ

    writer writeQ = do
      message <- atomically $ readTBQueue writeQ
      QUIC.sendStreamMany stream (BSL.toChunks (encode message))
      writer writeQ