{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE InterruptibleFFI #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.QUIC.Client.Run (
    run,
    migrate,
) where

import Control.Concurrent
import Control.Concurrent.Async
import qualified Control.Exception as E
import Foreign.C.Types
import qualified Network.Socket as NS

import Network.QUIC.Client.Reader
import Network.QUIC.Closer
import Network.QUIC.Common
import Network.QUIC.Config
import Network.QUIC.Connection
import Network.QUIC.Crypto
import Network.QUIC.Handshake
import Network.QUIC.Imports
import Network.QUIC.Logger
import Network.QUIC.Parameters
import Network.QUIC.QLogger
import Network.QUIC.Receiver
import Network.QUIC.Recovery
import Network.QUIC.Sender
import Network.QUIC.Types

----------------------------------------------------------------

-- | Running a QUIC client.
--   A UDP socket is created according to 'ccServerName' and 'ccPortName'.
--
--   The first attempt uses version information derived from the resumption
--   state when a session or token is present, and from 'ccVersion' /
--   'ccVersions' otherwise.  If that attempt is aborted with 'NextVersion',
--   the client retries once with the version information carried by the
--   exception.  If the carried value is 'brokenVersionInfo', it throws
--   'VersionNegotiationFailed' instead.  All other exceptions propagate
--   unchanged.
run :: ClientConfig -> (Connection -> IO a) -> IO a
-- Don't use handleLogUnit here because of a return value.
run conf client = do
    let resInfo = ccResumption conf
        verInfo = case resumptionSession resInfo of
            []
                | resumptionToken resInfo == emptyToken ->
                    -- Fresh connection: use the configured versions.
                    VersionInfo (ccVersion conf) (ccVersions conf)
            -- Resumption: stick to the version of the resumed session.
            _ -> let ver = resumptionVersion resInfo in VersionInfo ver [ver]
    -- Exceptions except NextVersion are passed through.
    ex <- E.try $ runClient conf client False verInfo
    case ex of
        Right v -> return v
        -- Other exceptions go through.
        Left (NextVersion nextVerInfo)
            -- Inspect the version info carried by the exception, not the
            -- one we started with: the initial 'verInfo' comes from the
            -- configuration and is what just failed.
            -- NOTE(review): 'brokenVersionInfo' presumably signals that no
            -- usable version was negotiated; confirm at the throw site.
            | nextVerInfo == brokenVersionInfo ->
                E.throwIO VersionNegotiationFailed
            | otherwise -> runClient conf client True nextVerInfo

-- | Runs one connection attempt with the given version information.
--
--   * @isICVN@ is handed to 'setIncompatibleVN' before the handshaker is
--     created.
--   * The connection comes from 'createClientConnection'.  'E.bracket'
--     guarantees the teardown (mark dead, free resources, kill readers)
--     once it has been acquired.
--   * The user action @client0@ is raced against the supporting threads
--     (handshaker, sender, receiver, resender, @ldccTimer@).  If the
--     supporters finish first, 'MustNotReached' is thrown.
--   * The race outcome, either a value or the exception that ended it, is
--     passed to 'closure' after 'sendFinal'.
runClient :: ClientConfig -> (Connection -> IO a) -> Bool -> VersionInfo -> IO a
runClient conf client0 isICVN verInfo =
    E.bracket (createClientConnection conf verInfo) release session
  where
    -- Bracket teardown: mark the connection dead, free its registered
    -- resources, then stop the reader threads.
    release connRes = do
        let c = connResConnection connRes
        setDead c
        freeResources c
        killReaders c

    -- The configuration with our version information filled into the
    -- transport parameters; only the handshaker is given this copy.
    handshakeConf =
        let params = ccParameters conf
         in conf{ccParameters = params{versionInformation = Just verInfo}}

    session (ConnRes conn myAuthCIDs reader) = do
        forkManaged conn reader
        setIncompatibleVN conn isICVN -- must be before handshaker
        setToken conn (resumptionToken (ccResumption conf))
        handshaker <- handshakeClient handshakeConf conn myAuthCIDs
        when (ccWatchDog conf) $ forkManaged conn (watchDog conn)
        outcome <- E.try (raceWorkers conn handshaker)
        sendFinal conn
        closure conn (connLDCC conn) outcome

    -- The user action wins the race in the normal case; the supporters
    -- are not expected to terminate on their own.
    raceWorkers conn handshaker = do
        let ldcc = connLDCC conn
            workers =
                foldr1
                    concurrently_
                    [ handshaker
                    , sender conn
                    , receiver conn
                    , resender ldcc
                    , ldccTimer ldcc
                    ]
        winner <- race workers (userAction conn)
        either (\() -> E.throwIO MustNotReached) return winner

    userAction conn = do
        -- For 0-RTT, the following variables should be initialized
        -- in advance.
        setTxMaxStreams conn (initialMaxStreamsBidi defaultParameters)
        setTxUniMaxStreams conn (initialMaxStreamsUni defaultParameters)
        if ccUse0RTT conf then wait0RTTReady conn else wait1RTTReady conn
        client0 conn

-- | Builds everything 'runClient' needs for one connection attempt:
--
--   * a UDP socket to 'ccServerName' / 'ccPortName', connected when
--     'ccSockConnected' is set;
--   * the send/receive actions;
--   * fresh connection IDs and the qlog sink;
--   * the 'Connection' itself, with Initial-level keys derived from the
--     temporary peer CID;
--   * the packet-size and initial congestion-window settings.
--
--   The returned 'ConnRes' also carries the reader action, which the caller
--   forks.
--
--   If any step after socket creation throws, the socket is closed before
--   the exception is re-thrown.  'E.bracket' in 'runClient' only runs its
--   release action for a 'ConnRes' that was fully built, so nothing else
--   would close the socket in that case.
createClientConnection :: ClientConfig -> VersionInfo -> IO ConnRes
createClientConnection conf@ClientConfig{..} verInfo = do
    (sock, peersa) <- clientSocket ccServerName ccPortName
    setup sock peersa `E.onException` NS.close sock
  where
    setup sock peersa = do
        when ccSockConnected $ NS.connect sock peersa
        q <- newRecvQ
        sref <- newIORef sock
        pathInfo <- newPathInfo peersa
        piref <- newIORef $ PeerInfo pathInfo Nothing
        -- A connected socket needs no destination address; otherwise send
        -- to the peer address currently stored in 'piref'.
        let send buf siz
                | ccSockConnected = do
                    s <- readIORef sref
                    void $ NS.sendBuf s buf siz
                | otherwise = do
                    s <- readIORef sref
                    PeerInfo pinfo _ <- readIORef piref
                    void $ NS.sendBufTo s buf siz $ peerSockAddr pinfo
            recv = recvClient q
        myCID <- newCID
        -- Creating peer's CIDDB with the temporary CID.  This is
        -- overridden by resetPeerCID later since no sequence number is
        -- assigned to the temporary CID by spec.
        peerCID <- newCID
        now <- getTimeMicrosecond
        -- NOTE(review): 'qclean' is only registered via 'addResource'
        -- below; a failure before that point does not run it.
        (qLog, qclean) <- dirQLogger ccQLog now peerCID "client"
        let debugLog msg
                | ccDebugLog = stdoutLogger msg
                | otherwise = return ()
        debugLog $ "Original CID: " <> bhow peerCID
        let myAuthCIDs = defaultAuthCIDs{initSrcCID = Just myCID}
            peerAuthCIDs =
                defaultAuthCIDs{initSrcCID = Just peerCID, origDstCID = Just peerCID}
        genSRT <- makeGenStatelessReset
        conn <-
            clientConnection
                conf
                verInfo
                myAuthCIDs
                peerAuthCIDs
                debugLog
                qLog
                ccHooks
                sref
                piref
                q
                send
                recv
                genSRT
        setSockConnected conn ccSockConnected
        addResource conn qclean
        modifytPeerParameters conn ccResumption
        let ver = chosenVersion verInfo
        initializeCoder conn InitialLevel $ initialSecrets ver peerCID
        setupCryptoStreams conn -- fixme: cleanup
        -- Packet size: the configured value (0 when unset), raised to at
        -- least the default for this peer address and capped at its maximum.
        let pktSiz0 = fromMaybe 0 ccPacketSize
            pktSiz =
                (defaultPacketSize peersa `max` pktSiz0)
                    `min` maximumPacketSize peersa
        setMaxPacketSize conn pktSiz
        setInitialCongestionWindow (connLDCC conn) pktSiz
        setAddressValidated pathInfo
        let reader = readerClient sock conn -- dies when the socket is closed.
        return $ ConnRes conn myAuthCIDs reader

-- | Creates a new socket and executes a path validation with a new
--   connection ID.  Typically, this is used for migration in the case where
--   'ccSockConnected' is 'True', but it can also be used when the value is
--   'False'.  It forwards 'ActiveMigration' to 'controlConnection' and
--   returns that call's result.
migrate :: Connection -> IO Bool
migrate = (`controlConnection` ActiveMigration)

-- | Background loop that blocks in the C helper @watch_socket@ and calls
--   'migrate' whenever the helper reports an event.  'runClient' starts it
--   when 'ccWatchDog' is set.  The helper descriptor is opened and closed
--   with 'E.bracket'.
--
--   Return codes, as interpreted here:
--
--   * @-1@: wait again immediately;
--   * @-2@: stop the loop;
--   * anything else: migrate, pause for 100 ms, then wait again.
--
--   NOTE(review): presumably the C side watches for local network changes.
--   Confirm against the C source.
watchDog :: Connection -> IO ()
watchDog conn = E.bracket c_open_socket c_close_socket waitForEvent
  where
    waitForEvent fd = c_watch_socket fd >>= react
      where
        react code
            | code == -1 = waitForEvent fd
            | code == -2 = pure ()
            | otherwise = do
                _ <- migrate conn
                -- prevent calling "migrate" frequently
                threadDelay 100000
                waitForEvent fd

-- | Opens the descriptor that 'c_watch_socket' waits on; 'watchDog' acquires
--   it via 'E.bracket'.
--   NOTE(review): the C implementation is not visible here.  If it can
--   return a negative value on failure, note that 'watchDog' passes the
--   result on without checking it.  Confirm against the C source.
foreign import ccall unsafe "open_socket"
    c_open_socket :: IO CInt

-- | Blocks until the descriptor reports something.  It is imported as
--   @interruptible@ (requires InterruptibleFFI), so GHC may interrupt this
--   blocking foreign call when the calling Haskell thread receives an
--   asynchronous exception.  'watchDog' treats @-1@ as "wait again", @-2@
--   as "stop" and any other value as a reason to migrate.
foreign import ccall interruptible "watch_socket"
    c_watch_socket :: CInt -> IO CInt

-- | Closes the descriptor.  It is used as the release action of
--   'E.bracket' in 'watchDog', so its result code is discarded.
foreign import ccall unsafe "close_socket"
    c_close_socket :: CInt -> IO CInt