module Network.QUIC.Simple.Stream
  ( MessageQueues
  , streamSerialise
  , streamCodec
  ) where

import Codec.Serialise (Serialise, serialise, deserialiseIncremental)
import Codec.Serialise qualified as IDecode (IDecode(..))
import Control.Concurrent.Async (Async, async, concurrently_, race_)
import Control.Concurrent.STM
import Control.Exception (finally, throwIO)
import Control.Monad.ST (stToIO)
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as BSL
import Data.IORef
import Network.QUIC qualified as QUIC

{- | The bounded queues handed out for a wrapped stream, as a
@(send queue, receive queue)@ pair.

Messages put on the first queue are sent over the stream; messages read from
the stream are delivered on the second.
-}
type MessageQueues sendMsg recvMsg =
  ( TBQueue sendMsg -- outgoing: waiting to be encoded and sent
  , TBQueue recvMsg -- incoming: already decoded
  )

{- | Wrap the stream with the CBOR codec for both incoming and outgoing messages.

Outgoing messages are encoded with 'serialise'. Incoming bytes go through an
incremental decoder which emits complete messages. Bytes left over after a
message are handed back to the reader, so several messages arriving in one
chunk are all delivered.

No extra framing is required since CBOR is self-delimiting.

A malformed incoming message throws the decoder's
'Codec.Serialise.DeserialiseFailure' from the reader, which stops the worker
and closes the stream.
-}
streamSerialise
  :: forall sendMsg recvMsg
  . (Serialise sendMsg, Serialise recvMsg)
  => QUIC.Stream
  -> IO (Async (), MessageQueues sendMsg recvMsg)
streamSerialise stream = do
  let
    -- A decoder waiting for the first bytes of the next message.
    freshDecoder = stToIO (deserialiseIncremental @recvMsg)
  state <- newIORef =<< freshDecoder
  let
    -- Feed one chunk to the decoder. Returns the bytes that followed a
    -- completed message (empty otherwise) and that message, if any.
    decode :: BS.ByteString -> IO (BS.ByteString, Maybe recvMsg)
    decode chunk = do
      decoder <- readIORef state
      case decoder of
        IDecode.Partial consume ->
          stToIO (consume (Just chunk)) >>= settle
        IDecode.Done leftovers _consumed msg -> do
          -- Not expected: settle never stores a finished decoder. Handled
          -- anyway so that the incoming chunk is never silently dropped.
          freshDecoder >>= writeIORef state
          pure (leftovers <> chunk, Just msg)
        IDecode.Fail _leftovers _offset err ->
          throwIO err

    -- Inspect the decoder right after it was stepped, so a message completed
    -- by this chunk is returned at once (no dummy empty chunk is fed).
    settle stepped =
      case stepped of
        IDecode.Done leftovers _consumed msg -> do
          freshDecoder >>= writeIORef state -- restart for the next message
          pure (leftovers, Just msg)
        IDecode.Fail _leftovers _offset err ->
          -- crash the reader (thus the worker, the stream and the writer)
          throwIO err
        IDecode.Partial _ -> do
          writeIORef state stepped -- suspend until the next chunk arrives
          pure (BS.empty, Nothing)
  streamCodec serialise decode stream

{- | Wrap the stream with a codec to provide a 'TBQueue' interface to it.

Returns the worker and the queues. Messages written to the first queue are
encoded and sent; decoded incoming messages appear on the second. Each queue
holds at most 1024 messages.

The decoder receives either a freshly read chunk or the leftovers from its
previous call. It returns the unconsumed bytes together with at most one
message. The reader loop keeps no state besides those leftovers, but the
decoder runs in 'IO', so you can use external state and terminate the stream
by throwing an exception.

When the peer finishes its side of the stream, the reader stops: no further
messages appear on the receive queue, and the writer keeps sending. If either
loop throws, the other is cancelled and the exception is rethrown from the
worker. The stream is closed whenever the worker ends, including on
cancellation.
-}
streamCodec
  :: forall sendMsg recvMsg
  . (sendMsg -> BSL.ByteString) -- ^ Encoder for outgoing messages
  -> (BS.ByteString -> IO (BS.ByteString, Maybe recvMsg)) -- ^ Decoder for incoming chunks
  -> QUIC.Stream
  -> IO (Async (), MessageQueues sendMsg recvMsg)
streamCodec encode decode stream = do
  readQ <- newTBQueueIO 1024
  writeQ <- newTBQueueIO 1024
  worker <- async $
    -- concurrently_ rather than race_: an orderly end of the reader (peer
    -- finished its side) must not cancel the writer. An exception in either
    -- loop still cancels the other one.
    concurrently_ (reader BS.empty readQ) (writer writeQ)
      `finally` QUIC.closeStream stream
  pure (worker, (writeQ, readQ))
  where
    -- Receive chunks, decode them and enqueue complete messages. Leftover
    -- bytes are decoded before anything new is read from the stream.
    reader :: BS.ByteString -> TBQueue recvMsg -> IO ()
    reader leftovers readQ = do
      chunk <-
        if BS.null leftovers
          then QUIC.recvStream stream 4096
          else pure leftovers
      if BS.null chunk
        then
          -- Only a fresh read can be empty. recvStream returns an empty
          -- chunk once the peer has finished (FIN) its side of the stream,
          -- and no more data can follow. Stop instead of spinning on
          -- empty reads.
          pure ()
        else do
          (leftovers', message_) <- decode chunk
          mapM_ (atomically . writeTBQueue readQ) message_
          reader leftovers' readQ

    -- Take messages off the queue, encode them and send them, forever.
    writer :: TBQueue sendMsg -> IO b
    writer writeQ = do
      message <- atomically $ readTBQueue writeQ
      QUIC.sendStreamMany stream (BSL.toChunks (encode message))
      writer writeQ