-- | Node with client role (i.e., its peer is a server)
module Network.GRPC.Util.Session.Client (
    ConnectionToServer(..)
  , NoTrailers(..)
  , CancelRequest
  , setupRequestChannel
  ) where

import Control.Concurrent
import Control.Concurrent.STM
import Control.Monad
import Control.Monad.Catch
import Data.ByteString qualified as BS.Strict
import Data.ByteString qualified as Strict (ByteString)
import Data.ByteString.Lazy qualified as BS.Lazy
import Data.ByteString.Lazy qualified as Lazy (ByteString)
import Data.Proxy
import Network.HTTP.Types qualified as HTTP
import Network.HTTP2.Client qualified as Client

import Network.GRPC.Util.HTTP2 (fromHeaderTable)
import Network.GRPC.Util.HTTP2.Stream
import Network.GRPC.Util.RedundantConstraint (addConstraint)
import Network.GRPC.Util.Session.API
import Network.GRPC.Util.Session.Channel
import Network.GRPC.Util.Thread

{-------------------------------------------------------------------------------
  Connection
-------------------------------------------------------------------------------}

-- | Connection to the server, as provided by @http2@
--
-- This wraps the single operation we need from an @http2@ client
-- connection: issuing a request.
data ConnectionToServer = ConnectionToServer {
      -- | Send a request to the server
      --
      -- The continuation is only called once a response has been received
      -- from the server.
      sendRequest ::
           forall a.
           Client.Request
        -> (Client.Response -> IO a)
        -> IO a
    }

{-------------------------------------------------------------------------------
  Initiate request

  Control flow and exception handling here is a little tricky.

  * 'sendRequest' (coming from @http2@) will itself spawn a separate thread to
    deal with sending inputs to the server (here, that is 'sendMessageLoop').

  * As the last thing it does, 'sendRequest' will then call its continuation
    argument in whatever thread it was called. Here, that continuation argument
    is 'recvMessageLoop' (dealing with outputs sent by the server), which we
    want to run in a separate thread also. We must therefore call 'sendRequest'
    in a newly forked thread.

    Note that 'sendRequest' will /only/ call its continuation once it receives
    a response from the server. Some servers will not send the (start of the)
    response until they have received (part of) the request, so it is important
    that we do not wait on 'sendRequest' before we return control to the caller.

  * We need to decide what to do with any exceptions that might arise in these
    threads. One option might be to rethrow those exceptions to the parent
    thread, but this presupposes a certain level of hygiene in client code: if
    the client code spawns threads of their own, and shares the open RPC call
    between them, then we rely on the client code to propagate the exception
    further.

    We therefore choose a different approach: we do not try to propagate the
    exception at all, but instead ensure that any (current or future) call to
    'read' or 'write' (or 'push') on the open 'Peer' will result in an exception
    when the respective threads have died.

  * This leaves one final problem to deal with: if setting up the request
    /itself/ throws an exception, one or both threads might never get started.
    To avoid client code blocking indefinitely we will therefore catch any
    exceptions that arise during the call setup, and use them to 'cancel' both
    threads (which will kill them or ensure they never get started).
-------------------------------------------------------------------------------}

-- | No trailers
--
-- The client does not support outbound trailers, which keeps the control flow
-- simple. When the inbound stream closes (because the server terminated the
-- RPC) we kill the outbound thread. Without outbound trailers, that exception
-- never reaches the trailers maker, where it would cause http2 to panic and
-- shut down the entire connection.
class NoTrailers sess where
  -- | The outbound trailers for @sess@; they carry no interesting information
  noTrailers ::
       Proxy sess
    -> Trailers (Outbound sess)

-- | Action that cancels an outstanding request
--
-- For streaming requests 'setupRequestChannel' installs http2's
-- 'Client.outBodyCancel' here. Non-streaming requests cannot be cancelled, so
-- for those this action ignores its argument and does nothing.
type CancelRequest = (Maybe SomeException -> IO ())

-- | Setup request channel
--
-- This initiates a new request. It builds the @http2@ request from the
-- session's request info and then forks the inbound thread, which calls
-- 'sendRequest'. It does /not/ wait for the server's response before
-- returning.
--
-- Returns the new 'Channel' together with an action to cancel the request.
setupRequestChannel :: forall sess.
     (InitiateSession sess, NoTrailers sess)
  => sess
  -> ConnectionToServer
  -> (InboundResult sess -> SomeException)
  -- ^ We assume that when the server closes their outbound connection to us,
  -- the entire conversation is over (i.e., the server cannot "half-close").
  -> FlowStart (Outbound sess)
  -> IO (Channel sess, CancelRequest)
setupRequestChannel sess
                    ConnectionToServer{sendRequest}
                    terminateCall
                    outboundStart
                  = do
    channel <- initChannel
    let requestInfo = buildRequestInfo sess outboundStart

    -- The real cancellation action is only installed once it is known:
    -- immediately for non-streaming requests, or from 'outboundThread' once
    -- http2 hands us the outbound body interface. Until then 'cancelRequest'
    -- blocks in 'readMVar'.
    --
    -- NOTE(review): if http2 never runs the outbound body (e.g. 'sendRequest'
    -- fails before getting that far), this MVar stays empty and
    -- 'cancelRequest' would block indefinitely. Confirm this cannot happen.
    cancelRequestVar <- newEmptyMVar
    let cancelRequest :: CancelRequest
        cancelRequest e = join . fmap ($ e) $ readMVar cancelRequestVar

    case outboundStart of
      FlowStartRegular headers -> do
        regular <- initFlowStateRegular headers
        let req :: Client.Request
            req = Client.requestStreamingIface
                    (requestMethod  requestInfo)
                    (requestPath    requestInfo)
                    (requestHeaders requestInfo)
                $ outboundThread channel cancelRequestVar regular
        forkRequest channel req
      FlowStartNoMessages trailers -> do
        let state :: FlowState (Outbound sess)
            state = FlowStateNoMessages trailers

            req :: Client.Request
            req = Client.requestNoBody
                    (requestMethod  requestInfo)
                    (requestPath    requestInfo)
                    (requestHeaders requestInfo)
        -- Can't cancel non-streaming request
        putMVar cancelRequestVar $ \_ -> return ()
        -- No outbound thread is ever started here, so mark the outbound side
        -- as done directly. The strict 'modifyTVar'' evaluates the new state
        -- now. An unexpected prior state is therefore reported here, inside
        -- 'setupRequestChannel'. With the lazy variant it would be stored as
        -- an 'error' thunk in the TVar, for some later reader to trip over.
        atomically $
          modifyTVar' (channelOutbound channel) $ \oldState ->
            case oldState of
              ThreadNotStarted debugId ->
                ThreadDone debugId state
              _otherwise ->
                error "setupRequestChannel: expected thread state"
        forkRequest channel req

    return (channel, cancelRequest)
  where
    _ = addConstraint @(NoTrailers sess)

    -- Fork the inbound thread. It issues the request through 'sendRequest';
    -- the response continuation then parses the response and runs the
    -- inbound message loop. Its result is linked to the outbound side.
    forkRequest :: Channel sess -> Client.Request -> IO ()
    forkRequest channel req =
        forkThread "grapesy:clientInbound" (channelInbound channel) $
          \unmask markReady _debugId -> unmask $
            linkOutboundToInbound
              (TerminateWhenInboundClosed terminateCall)
              channel $
              sendRequest req $ \resp -> do
                responseStatus <-
                  maybe (throwM PeerMissingPseudoHeaderStatus) return $
                    Client.responseStatus resp

                -- Read the entire response body in case of a non-OK response
                responseBody :: Maybe Lazy.ByteString <-
                  if HTTP.statusIsSuccessful responseStatus then
                    return Nothing
                  else
                    Just <$> readResponseBody resp

                let responseHeaders =
                      fromHeaderTable $ Client.responseHeaders resp
                    responseInfo = ResponseInfo {
                        responseHeaders
                      , responseStatus
                      , responseBody
                      }

                flowStart <- parseResponse sess responseInfo
                case flowStart of
                  FlowStartRegular headers -> do
                    state  <- initFlowStateRegular headers
                    stream <- clientInputStream resp
                    markReady $ FlowStateRegular state
                    Right <$> recvMessageLoop sess state stream
                  FlowStartNoMessages trailers -> do
                    markReady $ FlowStateNoMessages trailers
                    return $ Left trailers

    -- Body of the outbound thread. http2 runs it with the outbound body
    -- interface of the streaming request.
    outboundThread ::
         Channel sess
      -> MVar CancelRequest
      -> RegularFlowState (Outbound sess)
      -> Client.OutBodyIface
      -> IO ()
    outboundThread channel cancelRequestVar regular iface =
        threadBody "grapesy:clientOutbound" (channelOutbound channel) $
          \markReady _debugId -> do
            markReady $ FlowStateRegular regular
            -- Now that we have the body interface, install the real
            -- cancellation action (unblocking any pending 'cancelRequest')
            putMVar cancelRequestVar (Client.outBodyCancel iface)
            stream <- clientOutputStream iface
            -- Unlike the client inbound thread, or the inbound/outbound threads
            -- of the server, http2 knows about this particular thread and may
            -- raise an exception on it when the server dies. This results in a
            -- race condition between that exception and the exception we get
            -- from attempting to read the next message. No matter who wins that
            -- race, we need to mark that as 'ServerDisconnected'.
            --
            -- We don't have this top-level exception handler in other places
            -- because we don't want to mark /our own/ exceptions as
            -- 'ServerDisconnected' or 'ClientDisconnected'.
            wrapStreamExceptionsWith ServerDisconnected $
              Client.outBodyUnmask iface $ sendMessageLoop sess regular stream

{-------------------------------------------------------------------------------
   Auxiliary http2
-------------------------------------------------------------------------------}

-- | Drain the body of an @http2@ response into a lazy bytestring
--
-- Chunks are requested one at a time until http2 returns an empty chunk.
-- The chunks received up to that point are then concatenated in arrival
-- order.
readResponseBody :: Client.Response -> IO Lazy.ByteString
readResponseBody resp = loop []
  where
    -- Chunks are accumulated newest-first and reversed once at the end.
    loop :: [Strict.ByteString] -> IO Lazy.ByteString
    loop chunksSoFar = do
        next <- Client.getResponseBodyChunk resp
        if not (BS.Strict.null next)
          then loop (next : chunksSoFar)
          else pure . BS.Lazy.fromChunks . reverse $ chunksSoFar