{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.QUIC.Client.Run (
    run,
    migrate,
) where

import Control.Concurrent.Async
import qualified Control.Exception as E
import qualified Network.Socket as NS

import Network.QUIC.Client.Reader
import Network.QUIC.Closer
import Network.QUIC.Common
import Network.QUIC.Config
import Network.QUIC.Connection
import Network.QUIC.Crypto
import Network.QUIC.Handshake
import Network.QUIC.Imports
import Network.QUIC.Logger
import Network.QUIC.Parameters
import Network.QUIC.QLogger
import Network.QUIC.Receiver
import Network.QUIC.Recovery
import Network.QUIC.Sender
import Network.QUIC.Types

----------------------------------------------------------------

-- | Running a QUIC client.
--   A UDP socket is created according to 'ccServerName' and 'ccPortName'.
--
--   If 'ccAutoMigration' is 'True', an unconnected socket is made.
--   Otherwise, a connected socket is made.
--   Use the 'migrate' API for the connected socket.
--
--   NOTE(review): nothing in this module reads 'ccAutoMigration'; confirm
--   that the paragraph above still matches 'ClientConfig'.
--
--   The first attempt uses the version information derived from the
--   configuration (or from the resumption information, when a session or
--   a token is being resumed). If that attempt ends with a 'NextVersion'
--   exception, exactly one more attempt is made with the version
--   information carried by the exception, unless that information is
--   'brokenVersionInfo', in which case 'VersionNegotiationFailed' is
--   thrown.
run :: ClientConfig -> (Connection -> IO a) -> IO a
-- Don't use handleLogUnit here because of a return value.
run conf client = do
    let resInfo = ccResumption conf
        verInfo = case resumptionSession resInfo of
            -- Nothing to resume: neither a session nor a token.
            Nothing
                | resumptionToken resInfo == emptyToken ->
                    let ver = ccVersion conf
                        vers = ccVersions conf
                     in VersionInfo ver vers
            -- Resuming: stick to the version recorded in the resumption info.
            _ -> let ver = resumptionVersion resInfo in VersionInfo ver [ver]
    -- Exceptions except NextVersion are passed through.
    ex <- E.try $ runClient conf client False verInfo
    case ex of
        Right v -> return v
        -- Other exceptions go through.
        Left (NextVersion nextVerInfo)
            -- The check must look at the version information that was just
            -- received, not at the one already used for the first attempt.
            | nextVerInfo == brokenVersionInfo ->
                E.throwIO VersionNegotiationFailed
            | otherwise -> runClient conf client True nextVerInfo

-- | Runs one connection attempt with the given version information.
--
--   The connection is created by 'createClientConnection' and released
--   through 'E.bracket': on exit it is marked dead, its resources are
--   freed and its readers are killed.
--
--   The 'Bool' argument is forwarded to 'setIncompatibleVN'
--   (presumably "this attempt follows an incompatible version
--   negotiation"; confirm against the caller and the handshake code).
--
--   The user action runs only after 'wait0RTTReady' or 'wait1RTTReady'
--   returns, depending on 'ccUse0RTT'. It is raced against the supporting
--   threads (handshaker, sender, receiver, resender and LDCC timer).
--   Whatever the race produces, a result or an exception, is handed to
--   'closure' after 'sendFinal' has been called.
runClient :: ClientConfig -> (Connection -> IO a) -> Bool -> VersionInfo -> IO a
runClient conf client0 isICVN verInfo = do
    E.bracket open clse $ \(ConnRes conn myAuthCIDs reader) -> do
        forkManaged conn reader
        -- The handshake advertises the version information of this attempt.
        let conf' =
                conf
                    { ccParameters =
                        (ccParameters conf)
                            { versionInformation = Just verInfo
                            }
                    }
        setIncompatibleVN conn isICVN -- must be before handshaker
        setToken conn $ resumptionToken $ ccResumption conf
        handshaker <- handshakeClient conf' conn myAuthCIDs
        let client = do
                -- For 0-RTT, the following variables should be initialized
                -- in advance.
                setTxMaxStreams conn $ initialMaxStreamsBidi defaultParameters
                setTxUniMaxStreams conn $ initialMaxStreamsUni defaultParameters
                if ccUse0RTT conf
                    then wait0RTTReady conn
                    else wait1RTTReady conn
                client0 conn
            ldcc = connLDCC conn
            -- The list literal is non-empty, so 'foldr1' is safe here.
            supporters =
                foldr1
                    concurrently_
                    [ handshaker
                    , sender conn
                    , receiver conn
                    , resender ldcc
                    , ldccTimer ldcc
                    ]
            runThreads = do
                er <- race supporters client
                case er of
                    -- The supporters are not expected to finish first.
                    Left () -> E.throwIO MustNotReached
                    Right r -> return r
        ex <- E.try runThreads
        sendFinal conn
        closure conn ldcc ex
  where
    open = createClientConnection conf verInfo
    clse connRes = do
        let conn = connResConnection connRes
        setDead conn
        freeResources conn
        killReaders conn

-- | Builds everything one client connection attempt needs and returns it
--   as a 'ConnRes': the 'Connection', the client's own 'AuthCIDs' and the
--   reader action for the socket.
--
--   Steps, in order:
--
--   * open a socket towards 'ccServerName' / 'ccPortName';
--   * keep the socket and the peer address in 'IORef's, which the send
--     function reads on every call;
--   * generate two fresh connection IDs, one for this side and one used
--     as the peer's initial source and original destination CID;
--   * set up qlog and optional debug logging;
--   * create the 'Connection', install the initial secrets derived from
--     the chosen version and the peer CID, and set up the crypto streams;
--   * fix the packet size and the initial congestion window.
createClientConnection :: ClientConfig -> VersionInfo -> IO ConnRes
createClientConnection conf@ClientConfig{..} verInfo = do
    (sock, peersa) <- clientSocket ccServerName ccPortName
    q <- newRecvQ
    sref <- newIORef sock
    piref <- newIORef $ PeerInfo peersa
    let -- Socket and peer address are re-read on every send.
        send buf siz = do
            s <- readIORef sref
            PeerInfo sa <- readIORef piref
            void $ NS.sendBufTo s buf siz sa
        recv = recvClient q
    myCID <- newCID
    peerCID <- newCID
    now <- getTimeMicrosecond
    (qLog, qclean) <- dirQLogger ccQLog now peerCID "client"
    let debugLog msg
            | ccDebugLog = stdoutLogger msg
            | otherwise = return ()
    debugLog $ "Original CID: " <> bhow peerCID
    let myAuthCIDs = defaultAuthCIDs{initSrcCID = Just myCID}
        peerAuthCIDs =
            defaultAuthCIDs
                { initSrcCID = Just peerCID
                , origDstCID = Just peerCID
                }
    genSRT <- makeGenStatelessReset
    conn <-
        clientConnection
            conf
            verInfo
            myAuthCIDs
            peerAuthCIDs
            debugLog
            qLog
            ccHooks
            sref
            piref
            q
            send
            recv
            genSRT
    -- The qlog cleanup action is registered as a connection resource.
    addResource conn qclean
    let ver = chosenVersion verInfo
    initializeCoder conn InitialLevel $ initialSecrets ver peerCID
    setupCryptoStreams conn -- fixme: cleanup
    let -- A missing 'ccPacketSize' counts as 0, so the default size wins.
        pktSiz0 = fromMaybe 0 ccPacketSize
        -- At least the default size, at most the maximum for this peer.
        pktSiz =
            (defaultPacketSize peersa `max` pktSiz0)
                `min` maximumPacketSize peersa
    setMaxPacketSize conn pktSiz
    setInitialCongestionWindow (connLDCC conn) pktSiz
    setAddressValidated conn
    let reader = readerClient sock conn -- dies when s0 is closed.
    return $ ConnRes conn myAuthCIDs reader

-- | Creates a new socket and executes a path validation
--   with a new connection ID.
--
--   Implemented by sending 'ActiveMigration' to 'controlConnection';
--   the returned 'Bool' is whatever that call reports.
migrate :: Connection -> IO Bool
migrate conn = controlConnection conn ActiveMigration