{-# LANGUAGE CPP #-}

module Network.QUIC.Simple
  ( -- $intro

    -- * Basic wrappers
    runServer
  , runClient
    -- * CBOR/Serialise wrappers
  , runServerSimple
  , startClientSimple
    -- ** More flexible variants
  , runServerStateful
  , startClientAsync
  , Serialise
    -- * The rest of the QUIC API
  , module Network.QUIC
  ) where

import Control.Concurrent.STM
import Network.QUIC
import Network.QUIC.Simple.Stream

import Codec.Serialise (Serialise)
import Control.Concurrent (threadDelay)
import Control.Concurrent.Async (Async, async, cancel, link, link2, waitSTM)
import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)
import Control.Exception (SomeException, handle, onException)
import Control.Monad (forever)
import Data.IP (IP(..))
import Network.QUIC.Client (ClientConfig(..), defaultClientConfig)
import Network.QUIC.Client qualified as Client
import Network.QUIC.Server (ServerConfig(..), defaultServerConfig)
import Network.QUIC.Server qualified as Server
import Network.QUIC.Simple.Credentials (genCredentials)
import Network.Socket (HostName, PortNumber, ServiceName)

{- $intro
Check out the tests in the package's git repository for a cookbook of usage examples.

If you're unsure, start with the simplest wrapper.
If the wrapper's limitations bother you, copy its source code into your project and customize it to suit your needs.
Alternatively, switch to a lower-level implementation to gain more features.

Don't let wrappers dictate your code structure and protocols — they're just there to get you a QUIC-start!
-}

{- | Start a server that listens on every one of the given address-port pairs.

Every start generates a fresh set of credentials, just enough to get TLS running.
If clients should be able to pin the server identity after their first connection,
generate the credentials once with 'Network.QUIC.Simple.Credentials.genCredentials',
keep them, and use them in your own server configuration.

For each incoming connection, the server first accepts a stream. It then hands the
connection and that stream to the (stateless) connection handler.
-}
runServer :: [(IP, PortNumber)] -> (Connection -> Stream -> IO ()) -> IO ()
runServer scAddresses action =
  genCredentials >>= \scCredentials ->
    let
      config :: ServerConfig
      config = defaultServerConfig { scCredentials, scAddresses }
    in
      Server.run config \conn ->
        acceptStream conn >>= action conn

{- | Start a server on the given host and port and run a stateless, CBOR-encoded
request-response protocol: every decoded query is passed to the action, and its
result is sent back as the reply.

You could use 'myThreadId' to get some kind of connection identifier and attach
per-connection data to it. If you need that, use 'runServerStateful' instead.
-}
runServerSimple
  :: (Serialise q, Serialise r)
  => IP
  -> PortNumber
  -> (q -> IO r)
  -> IO ()
runServerSimple host port action =
  runServerStateful host port noSetup noTeardown respond
  where
    -- The per-connection state is just (), so there is nothing to set up or tear down.
    noSetup _conn _writeQueue = pure ()
    noTeardown _conn _state = pure ()

    -- Every query gets exactly one reply; the unit state passes through unchanged.
    respond st query =
      fmap (\reply -> (st, Just reply)) (action query)

{- | Start a server on the given host and port and run a stateful, CBOR-encoded
request-response protocol.

Each connection has its own state. The setup function builds the initial state from
the connection and the outgoing message queue; this is the same queue that replies
are written to.
For every incoming query, the handler returns the next connection state and,
optionally, a reply.

Throw an exception to terminate the current connection. The teardown function is then
called once, with the last state the handler returned, so it can do the clean-up.
Any exception ends the connection loop this way, including asynchronous ones such as
a failure of the stream's codec thread, which is linked to the connection thread.
The exception is not rethrown after teardown.
-}
runServerStateful
  :: (Serialise q, Serialise r)
  => IP -- ^ Address to listen on.
  -> PortNumber -- ^ Port to listen on.
  -> (Connection -> TBQueue r -> IO s) -- ^ Setup: build the initial connection state.
  -> (Connection -> s -> IO ()) -- ^ Teardown: clean up using the last committed state.
  -> (s -> q -> IO (s, Maybe r)) -- ^ Handler: next state and an optional reply.
  -> IO ()
runServerStateful host port setup teardown action =
  runServer [(host, port)] \conn stream0 -> do
    (codec, (writeQ, readQ)) <- streamSerialise stream0
    -- A codec failure is rethrown in this thread, which ends the loop below.
    link codec
    initial <- setup conn writeQ
    -- The latest committed state lives in a TVar, so a single handler installed once
    -- around the whole loop can give it to teardown. Installing the handler inside the
    -- recursion instead would leave one catch frame on the stack per message. It would
    -- also re-run teardown with stale states if teardown itself threw.
    -- Forcing here matches the previous strict (bang-pattern) state.
    stateVar <- newTVarIO $! initial
    let
      step = do
        s <- readTVarIO stateVar
        query <- atomically (readTBQueue readQ)
        (s', reply_) <- action s query
        mapM_ (atomically . writeTBQueue writeQ) reply_
        -- Commit the next state only after the reply has been queued.
        atomically (writeTVar stateVar $! s')
    handle
      (\(_ :: SomeException) -> readTVarIO stateVar >>= teardown conn)
      (forever step)

{- | Run a client that connects to the given host and port, requests a new stream on the
connection, and passes both the connection and the stream to the action.

Server validation is disabled (@ccValidate = False@).
If you want server authentication, you have to do it in your protocol handshake.

When built against the @quic@ library version 0.2.10 or newer, the client also sets
@ccSockConnected@ and @ccWatchDog@, which enables connection migration by default.
-}
runClient :: HostName -> ServiceName -> (Connection -> Stream -> IO ()) -> IO ()
runClient ccServerName ccPortName action =
  Client.run cc \conn -> do
    defaultStream <- stream conn
    action conn defaultStream
  where
    cc :: ClientConfig
    cc = defaultClientConfig
      { ccServerName
      , ccPortName
      , ccValidate = False
#if MIN_VERSION_quic(0,2,10)
      , ccSockConnected = True
      , ccWatchDog = True
#endif
      }

{- | Start a client and wait until it is connected.

Returns a pair:
an action that stops the client, and
a blocking call that sends one query and waits for the next reply.

Replies are not matched to queries. If several threads call concurrently, they may
receive each other's replies, so the client is not thread-safe. That fits the
one-reply-per-query protocol of 'runServerSimple'.

Use 'startClientAsync' if you need more functionality.
-}
startClientSimple
  :: (Serialise q, Serialise r)
  => HostName
  -> ServiceName
  -> IO (IO (), q -> IO r)
startClientSimple host port =
  startClientAsync host port >>= \(client, _conn, (writeQ, readQ)) ->
    let
      stop = cancel client
      call query = do
        atomically (writeTBQueue writeQ query)
        atomically (readTBQueue readQ)
    in
      pure (stop, call)

{- | Start a client in a worker thread and wait until it is connected.

Returns the worker thread, the connection, and the message queues for the stream the
client opened on that connection. A CBOR codec thread moves the data through the
queues. The codec thread and the worker are linked with 'link2', so an exception in
one is rethrown in the other.

Canceling the returned worker thread terminates the connection.
The returned connection can be used to request more streams.

If the worker fails before the connection is ready (for example, because the server is
unreachable), its exception is rethrown here. The caller is not left blocked.
-}
startClientAsync
  :: (Serialise q, Serialise r)
  => HostName
  -> ServiceName
  -> IO (Async (), Connection, MessageQueues q r)
startClientAsync host port = do
  ready <- newEmptyTMVarIO
  tid <- async $ runClient host port \conn stream0 -> do
    queues <- streamSerialise stream0
    atomically $ putTMVar ready (conn, queues)
    -- Keep the worker (and therefore the connection) alive until it is cancelled.
    forever (threadDelay maxBound)
  -- Wait for the connection or for the worker to exit, whichever happens first.
  -- If this thread is interrupted while waiting, stop the worker.
  (conn, (codec, queues)) <-
    atomically (takeTMVar ready `orElse` workerExited tid)
      `onException` cancel tid
  link2 codec tid
  pure
    ( tid
    , conn
    , queues
    )
  where
    -- Only succeeds in failing: once the worker has finished, 'waitSTM' rethrows
    -- its exception; if it returned normally, report that no connection was made.
    workerExited worker = do
      waitSTM worker
      throwSTM (userError "startClientAsync: client exited before the connection was ready")