{-# LANGUAGE OverloadedStrings #-}

-- | Connection to a server
--
-- Intended for qualified import.
--
-- > import Network.GRPC.Client.Connection (Connection, withConnection)
-- > import Network.GRPC.Client.Connection qualified as Connection
module Network.GRPC.Client.Connection (
    -- * Definition
    Connection -- opaque
  , withConnection
    -- * Configuration
  , Server(..)
  , ServerValidation(..)
  , SslKeyLog(..)
  , ConnParams(..)
  , ReconnectPolicy(..)
  , ReconnectTo(..)
  , exponentialBackoff
    -- * Using the connection
  , connParams
  , getConnectionToServer
  , getOutboundCompression
  , updateConnectionMeta
  ) where

import Control.Concurrent
import Control.Concurrent.STM
import Control.Monad
import Control.Monad.Catch
import Data.Default
import GHC.Stack
import Network.HPACK qualified as HPACK
import Network.HTTP2.Client qualified as HTTP2.Client
import Network.HTTP2.TLS.Client qualified as HTTP2.TLS.Client
import Network.Run.TCP qualified as Run
import Network.Socket
import Network.TLS (TLSException)
import System.Random

import Network.GRPC.Client.Meta (Meta)
import Network.GRPC.Client.Meta qualified as Meta
import Network.GRPC.Common.Compression qualified as Compr
import Network.GRPC.Common.HTTP2Settings
import Network.GRPC.Spec
import Network.GRPC.Util.GHC
import Network.GRPC.Util.Session qualified as Session
import Network.GRPC.Util.TLS (ServerValidation(..), SslKeyLog(..))
import Network.GRPC.Util.TLS qualified as Util.TLS

{-------------------------------------------------------------------------------
  Connection API

  'Connection' is kept abstract (opaque) in the user-facing API.

  The closest concept on the server side is 'Network.GRPC.Server.Context':
  this does not identify a connection from a particular client (@http2@ gives
  us each request separately, without identifying which requests come from the
  same client), but keeps track of the overall server state.
-------------------------------------------------------------------------------}

-- | An open connection to a server
--
-- Obtain one through 'withConnection'.
--
-- RPC requests can only be sent once we have connected to one specific
-- server; from then on, that single connection can carry any number of RPC
-- requests. 'Connection' wraps this underlying connection, together with some
-- state we keep about the server.
data Connection = Connection {
      -- | Parameters the connection was opened with
      connParams :: ConnParams

      -- | Connection metadata (updated by 'updateConnectionMeta')
    , connMetaVar :: MVar Meta

      -- | Current state of the underlying connection
    , connStateVar :: TVar ConnectionState
    }

{-------------------------------------------------------------------------------
  Config
-------------------------------------------------------------------------------}

-- | Connection configuration
--
-- You may wish to override 'connReconnectPolicy'.
data ConnParams = ConnParams {
      -- | Compression negotiation
      connCompression :: Compr.Negotation

      -- | Default timeout
      --
      -- Individual RPC calls can override this through 'CallParams'.
    , connDefaultTimeout :: Maybe Timeout

      -- | Reconnection policy
      --
      -- NOTE: The default 'ReconnectPolicy' is 'DontReconnect', as per the
      -- spec (see 'ReconnectPolicy'). You may wish to override this in order
      -- to enable Wait for Ready semantics (retry connecting to a server
      -- when it is not yet ready) as well as automatic reconnects (reconnecting
      -- after a server disappears). The latter can be especially important
      -- when there are proxies, which tend to drop connections after a certain
      -- amount of time.
    , connReconnectPolicy :: ReconnectPolicy

      -- | Optionally override the content type
      --
      -- If 'Nothing', the @Content-Type@ header will be omitted entirely
      -- (this does not conform to the gRPC spec).
    , connContentType :: Maybe ContentType

      -- | Should we verify all request headers?
      --
      -- This is the client analogue of
      -- 'Network.GRPC.Server.Context.serverVerifyHeaders'; see detailed
      -- discussion there.
      --
      -- Arguably, it is less essential to verify headers on the client: a
      -- server must deal with all kinds of different clients, and might want to
      -- know if any of those clients has expectations that it cannot fulfill. A
      -- client however connects to a known server, and knows what information
      -- it wants from the server.
    , connVerifyHeaders :: Bool

      -- | Optionally set the initial compression algorithm
      --
      -- Under normal circumstances, the @grapesy@ client will only start using
      -- compression once the server has informed it what compression algorithms
      -- it supports. This means the first message will necessarily be
      -- uncompressed. 'connInitCompression' can be used to override this
      -- behaviour, but should be used with care: if the server does not support
      -- the selected compression algorithm, it will not be able to decompress
      -- any messages sent by the client to the server.
    , connInitCompression :: Maybe Compression

      -- | HTTP2 settings
    , connHTTP2Settings :: HTTP2Settings
    }

-- | Default connection parameters
--
-- Compression negotiation, reconnection policy and HTTP2 settings use their
-- own defaults; there is no default timeout, the default content type is
-- used, headers are not verified, and no initial compression is set.
instance Default ConnParams where
  def :: ConnParams
  def = ConnParams {
        -- Components with their own 'Default' instance
        connCompression     = def
      , connReconnectPolicy = def
      , connHTTP2Settings   = def
        -- Remaining fields
      , connDefaultTimeout  = Nothing
      , connContentType     = Just ContentTypeDefault
      , connVerifyHeaders   = False
      , connInitCompression = Nothing
      }

{-------------------------------------------------------------------------------
  Reconnection policy
-------------------------------------------------------------------------------}

-- | Policy deciding whether (and how) to reconnect
--
-- 'exponentialBackoff' is a convenient way to construct a policy.
data ReconnectPolicy =
    -- | Never try to reconnect
    --
    -- Once the connection to the server is lost (or cannot be set up in the
    -- first place), give up.
    DontReconnect

    -- | Run the IO action, then reconnect to the (possibly different) server
    -- selected by the 'ReconnectTo'
    --
    -- The action returns the policy to use for any subsequent reconnection.
    --
    -- The 'ReconnectTo' makes a rudimentary redundancy scheme possible; for
    -- instance, after connecting to a main server has failed a number of
    -- times, you could switch over to a known fallback server.
    --
    -- The API is deliberately general: typically the IO action calls
    -- 'threadDelay' (usually with some randomness in the delay), but it can
    -- also be used for other things, such as telling the user somewhere that
    -- the client is reconnecting.
  | ReconnectAfter ReconnectTo (IO ReconnectPolicy)

-- | Which server to try when reconnecting
--
-- * 'ReconnectToPrevious': the server of the most recent connection attempt,
--   regardless of whether that attempt succeeded.
-- * 'ReconnectToOriginal': the server originally passed to 'withConnection'.
-- * 'ReconnectToNew': the server specified here.
data ReconnectTo =
      ReconnectToPrevious
    | ReconnectToOriginal
    | ReconnectToNew Server

-- | By default we do not reconnect ('DontReconnect')
--
-- This matches the gRPC specification of Wait for Ready semantics
-- <https://github.com/grpc/grpc/blob/master/doc/wait-for-ready.md>.
instance Default ReconnectPolicy where
  def :: ReconnectPolicy
  def = DontReconnect

-- | By default we reconnect to the previous server ('ReconnectToPrevious')
instance Default ReconnectTo where
  def :: ReconnectTo
  def = ReconnectToPrevious

-- | Exponential backoff
--
-- Before every reconnection attempt we wait for a delay picked at random from
-- the current delay range; after every attempt, both bounds of that range are
-- multiplied by the exponent. With an exponent of @1@ the range therefore
-- never changes; with an exponent greater than @1@ the waits grow longer each
-- step. Once the maximum number of attempts is used up, the policy becomes
-- 'DontReconnect'.
exponentialBackoff ::
     (Int -> IO ())
     -- ^ Execute the delay (in microseconds)
     --
     -- The default choice here can simply be 'threadDelay', but it is also
     -- possible to use this to add some logging. Simple example:
     --
     -- > waitFor :: Int -> IO ()
     -- > waitFor delay = do
     -- >   putStrLn $ "Disconnected. Reconnecting after " ++ show delay ++ "μs"
     -- >   threadDelay delay
     -- >   putStrLn "Reconnecting now."
  -> Double
     -- ^ Exponent (factor applied to the delay range after every attempt)
  -> (Double, Double)
     -- ^ Initial delay range @(lo, hi)@, in seconds
  -> Word
     -- ^ Maximum number of attempts
  -> ReconnectPolicy
exponentialBackoff waitFor e = backoff
  where
    backoff :: (Double, Double) -> Word -> ReconnectPolicy
    backoff _        0         = DontReconnect
    backoff (lo, hi) remaining = ReconnectAfter def $ do
        delaySec <- randomRIO (lo, hi)
        -- 'waitFor' expects microseconds
        waitFor $ round (delaySec * 1_000_000)
        pure $ backoff (lo * e, hi * e) (remaining - 1)

{-------------------------------------------------------------------------------
  Fatal exceptions (no point reconnecting)
-------------------------------------------------------------------------------}

-- | Is this exception fatal?
--
-- A fatal exception means we give up on the connection regardless of the
-- 'ReconnectPolicy'. At present, every 'TLSException' is considered fatal.
isFatalException :: SomeException -> Bool
isFatalException err =
    case fromException err of
      Just (_ :: TLSException) -> True
      Nothing                  -> False

{-------------------------------------------------------------------------------
  Server address
-------------------------------------------------------------------------------}

-- | The server to connect to
data Server =
    -- | Insecure (non-TLS) connection to the given address
    ServerInsecure Address

    -- | Secure (TLS) connection to the given address
  | ServerSecure ServerValidation SslKeyLog Address

    -- | Local connection through the Unix domain socket at the given path
  | ServerUnix FilePath
  deriving stock (Show)

{-------------------------------------------------------------------------------
  Open a new connection
-------------------------------------------------------------------------------}

-- | Open a connection to the server
--
-- Individual RPCs on the new connection are made with
-- 'Network.GRPC.Client.withRPC'.
--
-- The connection is set up asynchronously, in a separate thread; the first
-- call to 'withRPC' blocks until the connection is established.
--
-- What happens when the server is unreachable is determined by
-- 'connReconnectPolicy': if the policy permits another attempt, we wait as long
-- as the policy specifies and then try again. This implements the gRPC
-- "Wait for ready" semantics.
--
-- Should the connection be lost /after/ it was established, all RPC calls in
-- progress at that point are closed, and further communication on any of them
-- throws an exception. If the 'ReconnectPolicy' permits, however, we then
-- automatically try to connect to the server again. This is especially
-- relevant when there is a proxy between client and server, since proxies may
-- drop a connection after some period of time.
--
-- NOTE: 'DontReconnect' is the /default/ 'ReconnectPolicy', following the gRPC
-- specification of "Wait for ready" semantics; you may wish to override it.
--
-- Prefer sending many calls over a single connection to sending a few calls
-- each over many connections: keeping the number of connections opened
-- through this interface small results in better memory behaviour. See
-- <https://github.com/well-typed/grapesy/issues/133> for discussion.
withConnection ::
     ConnParams
  -> Server
  -> (Connection -> IO a)
  -> IO a
withConnection connParams server k = do
    connMetaVar    <- newMVar $ Meta.init (connInitCompression connParams)
    connStateVar   <- newTVarIO ConnectionNotReady
    connOutOfScope <- newEmptyMVar

    -- Deliberately not 'withAsync': rather than being cancelled, the thread
    -- should shut down cleanly once the connection is no longer needed, which
    -- we signal by filling 'connOutOfScope'.
    _stayConnectedTid <- forkLabelled "grapesy:stayConnected" $
      stayConnected connParams server connStateVar connOutOfScope

    let conn :: Connection
        conn = Connection{connParams, connMetaVar, connStateVar}
    k conn `finally` putMVar connOutOfScope ()

{-------------------------------------------------------------------------------
  Making use of the connection
-------------------------------------------------------------------------------}

-- | Get connection to the server
--
-- Returns two things: the connection to the server, as well as a @TMVar@ that
-- should be monitored to see if that connection is still live.
--
-- Blocks (via 'retry') while the connection is not yet established, and
-- rethrows the exception recorded when we gave up trying to (re)connect.
--
-- Calls 'error' if the 'Connection' is used after the body passed to
-- 'withConnection' has terminated (for example, because the 'Connection'
-- escaped its scope, or a thread using it outlived that scope).
getConnectionToServer ::
     HasCallStack
  => Connection
  -> IO (TMVar (Maybe SomeException), Session.ConnectionToServer)
getConnectionToServer Connection{connStateVar} = atomically $ do
    connState <- readTVar connStateVar
    case connState of
      ConnectionNotReady ->
        retry
      ConnectionReady connClosed conn ->
        return (connClosed, conn)
      ConnectionAbandoned err ->
        throwSTM err
      ConnectionOutOfScope ->
        -- Reachable: nothing prevents a 'Connection' from being used after
        -- 'withConnection' is done with it, so report that clearly.
        error $ "getConnectionToServer: connection used after the scope "
             ++ "of 'withConnection' ended"

-- | Get outbound compression algorithm
--
-- The result depends on connection state: it changes once compression
-- negotiation has happened. Until the remote peer has told us which
-- compression algorithms it supports, we must not use compression.
getOutboundCompression :: Connection -> IO (Maybe Compression)
getOutboundCompression conn = do
    meta <- readMVar (connMetaVar conn)
    return $ Meta.outboundCompression meta

-- | Update connection metadata from the server's response headers
--
-- Among other things, this determines the compression algorithm to use from
-- now on (see also 'getOutboundCompression'). Any exception thrown by
-- 'Meta.update' propagates, leaving the metadata unchanged.
updateConnectionMeta ::
     Connection
  -> ResponseHeaders' HandledSynthesized
  -> IO ()
updateConnectionMeta conn hdrs =
    modifyMVar_ (connMetaVar conn) (Meta.update negotiation hdrs)
  where
    negotiation :: Compr.Negotation
    negotiation = connCompression (connParams conn)

{-------------------------------------------------------------------------------
  Internal auxiliary
-------------------------------------------------------------------------------}

-- | State of the connection to the server, as maintained by 'stayConnected'
data ConnectionState =
    -- | No connection established (yet)
    --
    -- This is also the state while waiting to reconnect.
    ConnectionNotReady

    -- | Connected
    --
    -- The nested @TMVar@ is filled once this connection closes, with the
    -- exception that closed it (if any).
  | ConnectionReady (TMVar (Maybe SomeException)) Session.ConnectionToServer

    -- | Gave up trying to (re)establish the connection, due to this exception
  | ConnectionAbandoned SomeException

    -- | Closed because the connection is no longer needed
  | ConnectionOutOfScope

-- | A single attempt at connecting to the server
--
-- Internal to 'stayConnected' and its helpers.
data Attempt = ConnectionAttempt {
      -- | Connection parameters
      attemptParams     :: ConnParams

      -- | Connection state; set to 'ConnectionReady' once connected
    , attemptState      :: TVar ConnectionState

      -- | Filled once the 'Connection' is no longer needed
    , attemptOutOfScope :: MVar ()

      -- | Filled once the connection of this attempt has closed
    , attemptClosed     :: TMVar (Maybe SomeException)
    }

-- | Allocate the state for a new connection attempt
--
-- Only 'attemptClosed' is freshly allocated; all other fields are taken from
-- the arguments.
newConnectionAttempt ::
     ConnParams
  -> TVar ConnectionState
  -> MVar ()
  -> IO Attempt
newConnectionAttempt params stateVar outOfScope =
    ConnectionAttempt params stateVar outOfScope <$> newEmptyTMVarIO

-- | Stay connected to the server
--
-- Repeatedly tries to connect to the server, as permitted by the
-- 'ReconnectPolicy', and records the outcome in the connection state:
--
-- * when the connection ends without an exception, the state becomes
--   'ConnectionOutOfScope';
-- * on a fatal exception (see 'isFatalException'), or when the policy does not
--   permit reconnecting, the state becomes 'ConnectionAbandoned';
-- * otherwise the state is reset to 'ConnectionNotReady' and, after running the
--   policy's action, we try again.
stayConnected ::
     ConnParams
  -> Server
  -> TVar ConnectionState
  -> MVar ()
  -> IO ()
stayConnected connParams initialServer connStateVar connOutOfScope =
    loop initialServer (connReconnectPolicy connParams)
  where
    loop :: Server -> ReconnectPolicy -> IO ()
    loop server remainingReconnectPolicy = do
        -- Start new attempt (this just allocates some internal state)
        attempt <- newConnectionAttempt connParams connStateVar connOutOfScope

        -- Just like in 'runHandler' on the server side, it is important that
        -- 'stayConnected' runs in a separate thread. If it does not, then the
        -- moment we disconnect @http2[-tls]@ will throw an exception and we
        -- will not get the chance to process any other messages. This is
        -- especially important when we fail to setup a call: the server will
        -- respond with an informative gRPC error message (which we will raise
        -- as a 'GrpcException' in the client), and then disconnect. If we do
        -- not call @run@ in a separate thread, the only exception we will see
        -- is the low-level exception reported by @http2@ (something about
        -- stream errors), rather than the informative gRPC exception we want.
        mRes <- try $
          case server of
            ServerInsecure addr ->
              connectInsecure connParams attempt addr
            ServerSecure validation sslKeyLog addr ->
              connectSecure connParams attempt validation sslKeyLog addr
            ServerUnix path ->
              connectUnix connParams attempt path

        thisReconnectPolicy <- atomically $ do
          putTMVar (attemptClosed attempt) $ either Just (\() -> Nothing) mRes
          connState <- readTVar connStateVar
          return $ case connState of
            ConnectionReady{} ->
              -- Suppose we have a maximum of 5x to try and connect to a server.
              -- Then if we manage to connect, and /then/ lose the connection,
              -- we should have those same 5x tries again.
              connReconnectPolicy connParams
            _otherwise ->
              remainingReconnectPolicy

        case mRes of
          Right () ->
            atomically $ writeTVar connStateVar ConnectionOutOfScope
          Left err ->
            case (isFatalException err, thisReconnectPolicy) of
              (True, _) ->
                atomically $ writeTVar connStateVar $ ConnectionAbandoned err
              (False, DontReconnect) ->
                atomically $ writeTVar connStateVar $ ConnectionAbandoned err
              (False, ReconnectAfter to f) -> do
                let nextServer :: Server
                    nextServer =
                      case to of
                        ReconnectToPrevious -> server
                        ReconnectToOriginal -> initialServer
                        ReconnectToNew new  -> new
                atomically $ writeTVar connStateVar ConnectionNotReady
                loop nextServer =<< f

-- | Unix domain socket connection
--
-- The socket is always closed when we are done with it: when 'connect' fails,
-- when the connection is lost with an exception, and when the connection goes
-- out of scope. 'connectSocket' does not close the socket itself.
connectUnix :: ConnParams -> Attempt -> FilePath -> IO ()
connectUnix connParams attempt path =
    bracket (socket AF_UNIX Stream defaultProtocol) close $ \client -> do
      connect client $ SockAddrUnix path
      connectSocket connParams attempt "localhost" client

-- | Insecure connection (no TLS)
--
-- Opens a TCP connection to the given address, using the socket opening
-- behaviour derived from the connection's HTTP2 settings, and runs the HTTP2
-- client over it through 'connectSocket'.
connectInsecure :: ConnParams -> Attempt -> Address -> IO ()
connectInsecure connParams attempt addr =
    Run.runTCPClientWithSettings tcpSettings host port $
      connectSocket connParams attempt (authority addr)
  where
    host, port :: String
    host = addressHost addr
    port = show (addressPort addr)

    tcpSettings :: Run.Settings
    tcpSettings = Run.defaultSettings {
          Run.settingsOpenClientSocket =
            openClientSocket (connHTTP2Settings connParams)
        }

-- | Insecure connection over the given socket
--
-- Runs an HTTP2 client over @sock@, publishes the resulting connection as
-- 'ConnectionReady' in the attempt's connection state, and then keeps the
-- connection open until the 'Connection' goes out of scope. The socket itself
-- is not closed here; that is the caller's responsibility.
connectSocket :: ConnParams -> Attempt -> String -> Socket -> IO ()
connectSocket connParams attempt connAuthority sock =
    bracket
      (HTTP2.Client.allocSimpleConfig sock writeBufferSize)
      HTTP2.Client.freeSimpleConfig
      $ \conf ->
        HTTP2.Client.run clientConfig conf $ \sendRequest _aux -> do
          publish $ Session.ConnectionToServer sendRequest
          -- Keep the connection open until it is no longer needed
          takeMVar (attemptOutOfScope attempt)
  where
    http2Settings :: HTTP2Settings
    http2Settings = connHTTP2Settings connParams

    -- Make the connection available to 'getConnectionToServer'
    publish :: Session.ConnectionToServer -> IO ()
    publish conn = atomically $
      writeTVar (attemptState attempt) $
        ConnectionReady (attemptClosed attempt) conn

    settings :: HTTP2.Client.Settings
    settings = HTTP2.Client.defaultSettings {
          HTTP2.Client.maxConcurrentStreams =
            Just . fromIntegral $ http2MaxConcurrentStreams http2Settings
        , HTTP2.Client.initialWindowSize =
            fromIntegral $ http2StreamWindowSize http2Settings
        }

    clientConfig :: HTTP2.Client.ClientConfig
    clientConfig = overrideRateLimits connParams $
      HTTP2.Client.defaultClientConfig {
          HTTP2.Client.authority = connAuthority
        , HTTP2.Client.settings = settings
        , HTTP2.Client.connectionWindowSize =
            fromIntegral $ http2ConnectionWindowSize http2Settings
        }

-- | Secure connection (using TLS)
--
-- Opens a TLS connection to @addr@. Once the HTTP2 session is established,
-- the attempt's state is set to 'ConnectionReady'. The callback given to
-- 'HTTP2.TLS.Client.runWithConfig' then blocks until 'attemptOutOfScope' is
-- filled. Presumably the underlying connection stays open for that long.
--
-- * Server certificate validation is enabled for 'ValidateServer' and
--   disabled for 'NoServerValidation'. The CA store is derived from the
--   'ServerValidation' argument.
-- * TLS session keys are logged as configured by 'SslKeyLog'.
-- * Concurrency and flow-control window sizes come from the connection's
--   'HTTP2Settings'.
connectSecure ::
     ConnParams
  -> Attempt
  -> ServerValidation
  -> SslKeyLog
  -> Address
  -> IO ()
connectSecure connParams attempt validation sslKeyLog addr = do
    keyLogger <- Util.TLS.keyLogger sslKeyLog
    caStore   <- Util.TLS.validationCAStore validation

    let settings :: HTTP2.TLS.Client.Settings
        settings = HTTP2.TLS.Client.defaultSettings {
              HTTP2.TLS.Client.settingsValidateCert =
                case validation of
                  ValidateServer _   -> True
                  NoServerValidation -> False

            , HTTP2.TLS.Client.settingsCAStore          = caStore
            , HTTP2.TLS.Client.settingsKeyLogger        = keyLogger
              -- NOTE(review): explicitly clears the http2-tls default
              -- address-info flags; the reason is not visible here.
            , HTTP2.TLS.Client.settingsAddrInfoFlags    = []

              -- Apply TCP-level options (NoDelay / Linger) from our settings
            , HTTP2.TLS.Client.settingsOpenClientSocket =
                openClientSocket connHTTP2Settings
            , HTTP2.TLS.Client.settingsConcurrentStreams = fromIntegral $
                http2MaxConcurrentStreams connHTTP2Settings
            , HTTP2.TLS.Client.settingsStreamWindowSize = fromIntegral $
                http2StreamWindowSize connHTTP2Settings
            , HTTP2.TLS.Client.settingsConnectionWindowSize = fromIntegral $
                http2ConnectionWindowSize connHTTP2Settings
            }

        -- Default http2-tls client config for this authority (see 'authority'
        -- for why the port is omitted), with user rate-limit overrides.
        clientConfig :: HTTP2.Client.ClientConfig
        clientConfig = overrideRateLimits connParams $
            HTTP2.TLS.Client.defaultClientConfig
              settings
              (authority addr)

    HTTP2.TLS.Client.runWithConfig
        clientConfig
        settings
        (addressHost addr)
        (addressPort addr)
      $ \sendRequest _aux -> do
          let conn = Session.ConnectionToServer sendRequest
          -- Publish the live connection to whoever is watching this attempt
          atomically $
            writeTVar (attemptState attempt) $
              ConnectionReady (attemptClosed attempt) conn
          -- Keep the callback running until the attempt goes out of scope
          takeMVar $ attemptOutOfScope attempt
  where
    ConnParams{connHTTP2Settings} = connParams

-- | Authority
--
-- We omit the port number in the authority, for compatibility with TLS
-- SNI as well as the gRPC spec (the HTTP2 spec says the port number is
-- optional in the authority).
--
-- An explicitly configured 'addressAuthority' takes precedence; otherwise
-- the 'addressHost' is used.
authority :: Address -> String
authority addr =
    case addressAuthority addr of
      Nothing   -> addressHost addr
      Just auth -> auth

-- | Override rate limits imposed by @http2@
--
-- This covers four limits in the client config's 'HTTP2.Client.Settings':
-- ping, empty-frame, SETTINGS and RST_STREAM. For each, if the corresponding
-- @http2Override…RateLimit@ field of the connection's 'HTTP2Settings' is
-- 'Just', that value replaces the limit. If it is 'Nothing', the existing
-- limit is kept. No other part of the config is changed.
overrideRateLimits ::
     ConnParams
  -> HTTP2.Client.ClientConfig -> HTTP2.Client.ClientConfig
overrideRateLimits connParams clientConfig = clientConfig {
      HTTP2.Client.settings = settings {
          HTTP2.Client.pingRateLimit =
            override http2OverridePingRateLimit
                     HTTP2.Client.pingRateLimit
        , HTTP2.Client.emptyFrameRateLimit =
            override http2OverrideEmptyFrameRateLimit
                     HTTP2.Client.emptyFrameRateLimit
        , HTTP2.Client.settingsRateLimit =
            override http2OverrideSettingsRateLimit
                     HTTP2.Client.settingsRateLimit
        , HTTP2.Client.rstRateLimit =
            override http2OverrideRstRateLimit
                     HTTP2.Client.rstRateLimit
        }
    }
  where
    -- The settings we start from (and fall back to)
    settings :: HTTP2.Client.Settings
    settings = HTTP2.Client.settings clientConfig

    -- Use the user-supplied override if present, else keep the current value
    override ::
         (HTTP2Settings -> Maybe Int)
      -> (HTTP2.Client.Settings -> Int)
      -> Int
    override getOverride getCurrent =
        case getOverride (connHTTP2Settings connParams) of
          Nothing    -> getCurrent settings
          Just limit -> limit

{-------------------------------------------------------------------------------
  Auxiliary http2
-------------------------------------------------------------------------------}

-- | Open the client socket, applying TCP options from 'HTTP2Settings'
--
-- * When 'http2TcpNoDelay' is set, 'NoDelay' (@TCP_NODELAY@) is enabled.
-- * When 'http2TcpAbortiveClose' is set, 'Linger' is enabled with a zero
--   timeout. Closing the socket then aborts the connection instead of
--   lingering.
--
-- Options whose flag is off are simply not passed to
-- 'Run.openClientSocketWithOpts'.
openClientSocket :: HTTP2Settings -> AddrInfo -> IO Socket
openClientSocket http2Settings =
    Run.openClientSocketWithOpts socketOptions
  where
    socketOptions :: [(SocketOption, SockOptValue)]
    socketOptions = concat [
          [ ( NoDelay
            , SockOptValue @Int 1
            )
          | http2TcpNoDelay http2Settings
          ]
        , [ ( Linger
            , SockOptValue $ StructLinger { sl_onoff = 1, sl_linger = 0 }
            )
          | http2TcpAbortiveClose http2Settings
          ]
        ]

-- | Write-buffer size
--
-- See docs of 'confBufferSize', but importantly: "this value is announced
-- via SETTINGS_MAX_FRAME_SIZE to the peer."
--
-- Value of 4KB is taken from the example code.
writeBufferSize :: HPACK.BufferSize
writeBufferSize = 4096