{-# LANGUAGE OverloadedStrings #-}

module Network.GRPC.Client.Session (
    ClientSession(..)
  , ClientInbound
  , ClientOutbound
  , Headers(..)
    -- * Exceptions
  , CallSetupFailure(..)
  , InvalidTrailers(..)
  ) where

import Control.Exception
import Data.Proxy
import Data.Void
import Network.HTTP.Types qualified as HTTP

import Network.GRPC.Client.Connection (Connection, ConnParams(..))
import Network.GRPC.Client.Connection qualified as Connection
import Network.GRPC.Common
import Network.GRPC.Common.Compression qualified as Compr
import Network.GRPC.Common.Headers
import Network.GRPC.Spec
import Network.GRPC.Spec.Serialization
import Network.GRPC.Util.Session

{-------------------------------------------------------------------------------
  Definition
-------------------------------------------------------------------------------}

-- | Client-side session for RPC @rpc@.
--
-- The @rpc@ parameter is phantom: it does not appear in any field and only
-- selects the 'IsSession' / 'InitiateSession' instances for this session.
-- At runtime the session is just the underlying 'Connection'.
data ClientSession rpc = ClientSession {
      -- | Connection over which the call is made
      clientConnection :: Connection
    }

{-------------------------------------------------------------------------------
  Instances
-------------------------------------------------------------------------------}

-- | Type-level tag for the inbound direction of a client session, i.e.
-- data flowing from the server to the client (response headers, 'Output'
-- messages, response trailers). Uninhabited; used only as an index.
data ClientInbound rpc

-- | Type-level tag for the outbound direction of a client session, i.e.
-- data flowing from the client to the server (request headers, 'Input'
-- messages). Uninhabited; used only as an index.
data ClientOutbound rpc

instance IsRPC rpc => DataFlow (ClientInbound rpc) where
  -- Response headers, paired with the compression that is used when
  -- parsing inbound messages.
  data Headers (ClientInbound rpc) = InboundHeaders {
        inbHeaders     :: ResponseHeaders' HandledSynthesized
      , inbCompression :: Compression
      }
    deriving (Show)

  type Message    (ClientInbound rpc) = (InboundMeta, Output rpc)
  type Trailers   (ClientInbound rpc) = ProperTrailers'

  -- A Trailers-Only response from the server carries no messages at all.
  type NoMessages (ClientInbound rpc) = TrailersOnly' HandledSynthesized

instance IsRPC rpc => DataFlow (ClientOutbound rpc) where
  -- Request headers, paired with the compression that is used when
  -- serializing outbound messages.
  data Headers (ClientOutbound rpc) = OutboundHeaders {
        outHeaders     :: RequestHeaders
      , outCompression :: Compression
      }
    deriving (Show)

  type Message  (ClientOutbound rpc) = (OutboundMeta, Input rpc)
  type Trailers (ClientOutbound rpc) = NoMetadata

  -- gRPC does not support a Trailers-Only case for requests
  -- (indeed, does not support request trailers at all).
  type NoMessages (ClientOutbound rpc) = Void

-- | How a client session builds and parses trailers and messages for @rpc@.
instance SupportsClientRpc rpc => IsSession (ClientSession rpc) where
  type Inbound  (ClientSession rpc) = ClientInbound  rpc
  type Outbound (ClientSession rpc) = ClientOutbound rpc

  -- Outbound (request) trailers are always 'NoMetadata', so they serialize
  -- to an empty header list.
  buildOutboundTrailers _ = \NoMetadata -> []

  -- Parse the server's trailers. An empty trailer list is reported as a
  -- 'GrpcException' with 'GrpcUnknown'; otherwise parsing is delegated to
  -- 'parseProperTrailers''.
  parseInboundTrailers  _ = \trailers ->
      if null trailers then
        -- Although we parse the trailers in a lenient fashion (like all
        -- headers), only throwing errors for headers that we really need, if we
        -- get no trailers at /all/, then most likely something has gone wrong;
        -- for example, perhaps an intermediate cache has dropped the gRPC
        -- trailers entirely. We therefore check for this case separately and
        -- throw a different error.
        --
        -- We /must/ throw a GrpcException (rather than some kind of custom one)
        -- because the spec mandates that we synthesize a status and status
        -- message when the peer omits them.
        throwIO $ GrpcException {
            grpcError         = GrpcUnknown
          , grpcErrorMessage  = Just "Call closed without trailers"
          , grpcErrorDetails  = Nothing
          , grpcErrorMetadata = []
          }
      else
        return $ parseProperTrailers' (Proxy @rpc) trailers

  -- Inbound messages are parsed using the compression recorded in the
  -- inbound headers.
  parseMsg _ = parseOutput (Proxy @rpc) . inbCompression

  -- Outbound messages are serialized using the compression recorded in the
  -- outbound headers.
  buildMsg _ = buildInput  (Proxy @rpc) . outCompression

-- | Starting a client session: interpreting the server's response and
-- building the outgoing request line and headers.
instance SupportsClientRpc rpc => InitiateSession (ClientSession rpc) where
  -- Classify the server response as either Trailers-Only (no messages will
  -- follow) or regular response headers, then verify the headers.
  --
  -- Error paths:
  --
  -- * synthesized 'GrpcException's produced while parsing the headers are
  --   rethrown via 'throwSynthesized' / 'throwIO';
  -- * header verification failures ('verifyAllIf', which receives the
  --   @connVerifyHeaders@ setting) throw 'CallSetupInvalidResponseHeaders';
  -- * a response compression algorithm rejected by 'Compr.getSupported'
  --   throws 'CallSetupUnsupportedCompression'.
  parseResponse (ClientSession conn) (ResponseInfo status headers body) =
      case classifyServerResponse (Proxy @rpc) status headers body of
        Left parsed -> do
          trailersOnly <- throwSynthesized throwIO parsed
          -- We classify the response as Trailers-Only if the grpc-status header
          -- is present, or when the HTTP status is anything other than 200 OK
          -- (which we treat, as per the spec, as an implicit grpc-status).
          -- The "closed without trailers" case is therefore not relevant.
          case verifyAllIf connVerifyHeaders trailersOnly of
            Left  err   -> throwIO $ CallSetupInvalidResponseHeaders err
            Right _hdrs -> return $ FlowStartNoMessages trailersOnly
        Right parsed -> do
          responseHeaders <- throwSynthesized throwIO parsed
          case verifyAllIf connVerifyHeaders responseHeaders of
            Left  err  -> throwIO $ CallSetupInvalidResponseHeaders err
            Right hdrs -> do
              -- NOTE(review): only regular responses reach this call; the
              -- Trailers-Only branch above does not update connection meta.
              Connection.updateConnectionMeta conn responseHeaders
              cIn <- getCompression $ requiredResponseCompression hdrs
              return $ FlowStartRegular $ InboundHeaders {
                  inbHeaders     = responseHeaders
                , inbCompression = cIn
                }
    where
      ConnParams{
          connCompression
        , connVerifyHeaders
        } = Connection.connParams conn

      -- Resolve the compression algorithm named by the server. No name
      -- means no compression; a name that 'Compr.getSupported' does not
      -- accept under our negotiation settings is a call setup failure.
      getCompression :: Maybe CompressionId -> IO Compression
      getCompression Nothing    = return noCompression
      getCompression (Just cid) =
          case Compr.getSupported connCompression cid of
            Just compr -> return compr
            Nothing    -> throwIO $ CallSetupUnsupportedCompression cid

  -- Every request is a POST to the RPC's path. The request headers come
  -- from the outbound flow-start headers.
  buildRequestInfo _ start = RequestInfo {
        requestMethod  = rawMethod resourceHeaders
      , requestPath    = rawPath resourceHeaders
      , requestHeaders = buildRequestHeaders (Proxy @rpc) $
            case start of
              FlowStartRegular    headers    -> outHeaders headers
              -- Outbound 'NoMessages' is 'Void', so this case cannot occur.
              FlowStartNoMessages impossible -> absurd impossible
      }
    where
      resourceHeaders :: RawResourceHeaders
      resourceHeaders = buildResourceHeaders $ ResourceHeaders {
            resourceMethod = Post
          , resourcePath   = rpcPath (Proxy @rpc)
          }

-- | Client sessions never send trailers. The only outbound trailer value is
-- 'NoMetadata'.
instance NoTrailers (ClientSession rpc) where
  noTrailers _ = NoMetadata

{-------------------------------------------------------------------------------
  Exceptions
-------------------------------------------------------------------------------}

-- | Failure to set up a call.
--
-- Within this module, both constructors are thrown from 'parseResponse'
-- while processing the server's response headers.
data CallSetupFailure =
    -- | Server chose an unsupported compression algorithm
    CallSetupUnsupportedCompression CompressionId

    -- | We failed to parse the response headers
  | CallSetupInvalidResponseHeaders (InvalidHeaders HandledSynthesized)
  deriving stock (Show)
  deriving anyclass (Exception)

-- | We failed to parse the response trailers
--
-- NOTE(review): this module exports but never throws this exception
-- ('parseInboundTrailers' parses leniently and throws a 'GrpcException' only
-- for missing trailers). Confirm where it is raised.
data InvalidTrailers =
    -- | Some of the trailers could not be parsed
    InvalidTrailers {
        -- | The trailers associated with the failure
        invalidTrailers      :: [HTTP.Header]

        -- | Description of the parse error
      , invalidTrailersError :: String
      }
  deriving stock (Show)
  deriving anyclass (Exception)