{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.QUIC.Client.Reader (
    readerClient,
    recvClient,
    ConnectionControl (..),
    controlConnection,
    clientSocket,
) where

import Control.Concurrent
import qualified Control.Exception as E
import Data.List (intersect)
import Network.Socket (Socket, close, getSocketName)
import qualified Network.Socket.ByteString as NSB

import Network.QUIC.Common
import Network.QUIC.Connection
import Network.QUIC.Connector
import Network.QUIC.Crypto
import Network.QUIC.Exception
import Network.QUIC.Imports
import Network.QUIC.Packet
import Network.QUIC.Parameters
import Network.QUIC.Qlog
import Network.QUIC.Recovery
import Network.QUIC.Socket
import Network.QUIC.Types

-- | Receive loop for a client connection reading from socket @s0@.
--
-- It first waits until @s0@ reports a local address, polling
-- 'getSocketName' and yielding in between. It then repeatedly receives
-- datagrams of at most 2048 bytes. Each receive is bounded by the
-- connection's minimum idle timeout. When 'timeout' returns 'Nothing',
-- the socket is closed and the loop stops. Datagrams whose source address
-- differs from the current peer address are dropped. The others are
-- decoded, and each packet is dispatched by @putQ@.
--
-- readerClient dies when the socket is closed.
readerClient :: Socket -> Connection -> IO ()
readerClient s0 conn = handleLogUnit logAction $ do
    labelMe "readerClient"
    wait
    loop
  where
    -- Spin (with 'yield') until 'getSocketName' succeeds on @s0@.
    -- Exceptions from 'getSocketName' are mapped to "not bound yet" via
    -- 'throughAsync' (presumably re-throwing asynchronous exceptions --
    -- confirm against its definition).
    wait :: IO ()
    wait = do
        bound <- E.handle (throughAsync (return False)) $ do
            _ <- getSocketName s0
            return True
        unless bound $ do
            yield
            wait

    loop :: IO ()
    loop = do
        ito <- readMinIdleTimeout conn
        mbs <- timeout ito "readerClient" $ NSB.recvFrom s0 2048
        case mbs of
            -- Nothing arrived within the idle timeout: stop reading.
            Nothing -> close s0
            Just (bs, peersa) -> do
                PeerInfo peersa' <- getPeerInfo conn
                -- Only datagrams from the current peer address are processed.
                when (peersa == peersa') $ do
                    now <- getTimeMicrosecond
                    let quicBit = greaseQuicBit $ getMyParameters conn
                    pkts <- decodePackets bs (not quicBit)
                    mapM_ (putQ now) pkts
                loop

    -- Prefixes every debug message emitted under 'handleLogUnit'.
    logAction msg = connDebugLog conn ("debug: readerClient: " <> msg)

    -- Broken packets are silently dropped.
    putQ _ (PacketIB BrokenPacket _) = return ()
    -- Version Negotiation: abort the main thread with 'VerNego'.
    putQ t (PacketIV pkt@(VersionNegotiationPacket dCID sCID peerVers)) = do
        qlogReceived conn pkt t
        myVerInfo <- getVersionInfo conn
        let myVer = chosenVersion myVerInfo
            myVers0 = otherVersions myVerInfo
        -- Ignore the VN packet if it lists the version we chose
        -- (or lists 'Negotiation').
        when (myVer `notElem` peerVers && Negotiation `notElem` peerVers) $ do
            ok <- checkCIDs conn dCID (Left sCID)
            let myVers = filter (not . isGreasingVersion) myVers0
                -- If the CIDs do not match, or no non-greasing version is
                -- shared with the peer, the connection is aborted with
                -- 'brokenVersionInfo'.
                nextVerInfo = case myVers `intersect` peerVers of
                    vers@(ver : _) | ok -> VersionInfo ver vers
                    _ -> brokenVersionInfo
            E.throwTo (mainThreadId conn) $ VerNego nextVerInfo
    -- Protected packet: queue it if it is addressed to one of our CIDs,
    -- otherwise check whether it is a stateless reset.
    putQ t (PacketIC pkt@(CryptPacket hdr crypt) lvl siz) = do
        let cid = headerMyCID hdr
        included <- myCIDsInclude conn cid
        case included of
            Just _ -> writeRecvQ (connRecvQ conn) $ mkReceivedPacket pkt t siz lvl
            Nothing -> case decodeStatelessResetToken (cryptPacket crypt) of
                Just token -> do
                    isStatelessReset <- isStatelessRestTokenValid conn token
                    -- Our client does not send a stateless reset:
                    -- 1) Stateless reset token is not generated for
                    --    my first CID.
                    -- 2) It's unlikely that QUIC packets are delivered
                    --    to a new UDP port when our client is rebooted.
                    when isStatelessReset $ do
                        qlogReceived conn StatelessReset t
                        connDebugLog conn "debug: connection is reset statelessly"
                        E.throwTo (mainThreadId conn) ConnectionIsReset
                _ -> return () -- really invalid, just ignore
    -- Retry: switch to the server-chosen CID, re-key the Initial level,
    -- remember the token and re-send what was released by the retry.
    putQ t (PacketIR pkt@(RetryPacket ver dCID sCID token ex)) = do
        qlogReceived conn pkt t
        ok <- checkCIDs conn dCID ex
        -- RFC 9000 Sec 17.2.5.2: a client MUST discard a Retry packet
        -- with a zero-length Retry Token field.
        when (ok && token /= "") $ do
            resetPeerCID conn sCID
            setPeerAuthCIDs conn $ \auth -> auth{retrySrcCID = Just sCID}
            initializeCoder conn InitialLevel $ initialSecrets ver sCID
            setToken conn token
            setRetried conn True
            releaseByRetry (connLDCC conn) >>= mapM_ put
      where
        put ppkt = putOutput conn $ OutRetrans ppkt

-- | Validate the connection IDs of a Version Negotiation or Retry packet.
--
-- In both cases the packet's destination CID must equal our current CID.
--
-- * @Left sCID@ (Version Negotiation): the source CID must also equal the
--   peer's current CID.
-- * @Right (pseudo0, tag)@ (Retry): the integrity tag computed over
--   @pseudo0@ with the current version and the peer's current CID must
--   equal @tag@.
checkCIDs :: Connection -> CID -> Either CID (ByteString, ByteString) -> IO Bool
checkCIDs conn dCID ex = case ex of
    Left sCID -> do
        (myCID, peerCID) <- currentCIDs
        pure $ dCID == myCID && sCID == peerCID
    Right (pseudo0, tag) -> do
        (myCID, peerCID) <- currentCIDs
        ver <- getVersion conn
        pure $ dCID == myCID && calculateIntegrityTag ver peerCID pseudo0 == tag
  where
    -- Our CID first, then the peer's, in that order.
    currentCIDs = (,) <$> getMyCID conn <*> getPeerCID conn

-- | Take the next received packet from the client's receive queue.
recvClient :: RecvQ -> IO ReceivedPacket
recvClient queue = readRecvQ queue

----------------------------------------------------------------

-- | How to control a connection.
data ConnectionControl
    = -- | 'controlConnection'' waits up to one second on 'waitPeerCID' and
      -- sends a 'RetireConnectionID' frame for that CID's sequence number.
      ChangeServerCID
    | -- | 'controlConnection'' obtains a new CID of ours via 'getNewMyCID'
      -- and sends it in a 'NewConnectionID' frame.
      ChangeClientCID
    | -- | 'controlConnection'' switches to a fresh socket and closes the
      -- old one after 5 milliseconds.
      NATRebinding
    | -- | 'controlConnection'' waits for a peer CID, switches to a fresh
      -- socket (closing the old one after 5 seconds) and validates the path.
      ActiveMigration
    deriving (Show, Eq)

-- | Apply a 'ConnectionControl' action to a client connection.
--
-- On a client, this blocks until the connection is established and then
-- returns whether the action was carried out. On a non-client connection
-- it does nothing and returns 'False'.
controlConnection :: Connection -> ConnectionControl -> IO Bool
controlConnection conn typ
    | not (isClient conn) = pure False
    | otherwise = waitEstablished conn >> controlConnection' conn typ

-- | Carry out one 'ConnectionControl' action on an established client
-- connection. Returns 'False' only when a one-second wait for a peer CID
-- times out; otherwise returns 'True'.
controlConnection' :: Connection -> ConnectionControl -> IO Bool
controlConnection' conn ChangeServerCID =
    -- fixme
    timeout (Microseconds 1000000) "controlConnection' 1" (waitPeerCID conn)
        >>= maybe (return False) retireIt
  where
    retireIt info = do
        sendFrames conn RTT1Level [RetireConnectionID (cidInfoSeq info)]
        return True
controlConnection' conn ChangeClientCID = do
    newCID <- getNewMyCID conn
    -- One past our current CID sequence number; presumably used as the
    -- frame's "retire prior to" value -- confirm against 'NewConnectionID'.
    retirePriorTo <- fmap (+ 1) (getMyCIDSeqNum conn)
    sendFrames conn RTT1Level [NewConnectionID newCID retirePriorTo]
    return True
controlConnection' conn NATRebinding =
    rebind conn (Microseconds 5000) >> return True -- nearly 0
controlConnection' conn ActiveMigration = do
    -- fixme
    found <- timeout (Microseconds 1000000) "controlConnection' 2" (waitPeerCID conn)
    case found of
        Just _ -> do
            -- The old socket stays open for 5 seconds after the switch.
            rebind conn (Microseconds 5000000)
            validatePath conn found
            return True
        Nothing -> return False

-- | Move the connection onto a fresh socket created by 'natRebinding' for
-- the current peer address. The socket is installed with 'setSocket', and
-- a new 'readerClient' is forked on it. The previously installed socket is
-- closed via 'fire' after the given delay.
rebind :: Connection -> Microseconds -> IO ()
rebind conn delay = do
    PeerInfo peerAddr <- getPeerInfo conn
    freshSock <- natRebinding peerAddr
    staleSock <- setSocket conn freshSock
    forkManaged conn (readerClient freshSock conn)
    fire conn delay (close staleSock)