{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

module Network.QUIC.Server.Reader (
    Dispatch,
    newDispatch,
    clearDispatch,
    runDispatcher,
    tokenMgr,
    genStatelessReset,

    -- * Accepting
    Accept (..),

    -- * Receiving and reading
    RecvQ,
    recvServer,
    ServerState (..),
) where

import Control.Concurrent
import Control.Concurrent.STM
import qualified Control.Exception as E
import qualified Crypto.Token as CT
import qualified Data.ByteString as BS
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import qualified GHC.IO.Exception as E
import Network.ByteOrder
import Network.Control (LRUCacheRef, Rate, getRate, newRate)
import qualified Network.Control as LRUCache
import Network.Socket (Socket, waitReadSocketSTM)
import qualified Network.Socket.ByteString as NSB
import qualified System.IO.Error as E
import System.Random (getStdRandom, randomRIO, uniformByteString)

import Network.QUIC.Common
import Network.QUIC.Config
import Network.QUIC.Connection
import Network.QUIC.Connector
import Network.QUIC.Exception
import Network.QUIC.Imports
import Network.QUIC.Logger
import Network.QUIC.Packet
import Network.QUIC.Parameters
import Network.QUIC.Types
import Network.QUIC.Windows

----------------------------------------------------------------

-- | Shared state used by the dispatcher to route incoming packets.
data Dispatch = Dispatch
    { tokenMgr :: CT.TokenManager
    -- ^ Token manager used to encrypt and decrypt tokens.
    , dstTable :: IORef ConnectionDict
    -- ^ Established connections, looked up by destination CID.
    , srcTable :: RecvQDict
    -- ^ Receive queues of connections being accepted, looked up by
    --   the peer's source CID.
    , genStatelessReset :: CID -> StatelessResetToken
    -- ^ Produces the stateless reset token for a CID.
    , statelessResetRate :: Rate
    -- ^ Rate counter; presumably used to limit stateless resets
    --   (its use is not visible in this part of the file).
    }

-- | Limit associated with stateless resets.
--   NOTE(review): the consumer of this constant is outside the visible
--   part of the file; presumably compared against 'statelessResetRate'.
statelessResetLimit :: Int
statelessResetLimit = 20

-- | Build a fresh 'Dispatch'.
--
-- Spawns a token manager whose token lifetime is 'scTicketLifetime',
-- and allocates an empty connection table, a new receive-queue cache,
-- a stateless-reset token generator and a new rate counter.
-- Release the token manager with 'clearDispatch'.
newDispatch :: ServerConfig -> IO Dispatch
newDispatch ServerConfig{..} =
    Dispatch
        <$> CT.spawnTokenManager conf
        <*> newIORef emptyConnectionDict
        <*> newRecvQDict
        <*> makeGenStatelessReset
        <*> newRate
  where
    -- Token manager settings derived from the server configuration.
    conf =
        CT.defaultConfig
            { CT.tokenLifetime = scTicketLifetime
            , CT.threadName = "QUIC token manager"
            }

-- | Shut down the token manager owned by the given 'Dispatch'.
clearDispatch :: Dispatch -> IO ()
clearDispatch d = CT.killTokenManager $ tokenMgr d

----------------------------------------------------------------

-- | Table mapping a connection ID to its 'Connection'.
newtype ConnectionDict = ConnectionDict (Map CID Connection)

-- | A connection table with no entries.
emptyConnectionDict :: ConnectionDict
emptyConnectionDict = ConnectionDict M.empty

-- | Look up the connection registered under a CID, if any.
lookupConnectionDict :: IORef ConnectionDict -> CID -> IO (Maybe Connection)
lookupConnectionDict ref cid = do
    ConnectionDict tbl <- readIORef ref
    return $ M.lookup cid tbl

-- | Register a connection under a CID, replacing any previous entry
--   for that CID.
registerConnectionDict :: IORef ConnectionDict -> CID -> Connection -> IO ()
registerConnectionDict ref cid conn = atomicModifyIORef'' ref $
    \(ConnectionDict tbl) -> ConnectionDict $ M.insert cid conn tbl

-- | Remove the entry for a CID from the connection table.
--   Removing a CID that is not present leaves the table unchanged.
unregisterConnectionDict :: IORef ConnectionDict -> CID -> IO ()
unregisterConnectionDict ref cid = atomicModifyIORef'' ref $
    \(ConnectionDict tbl) -> ConnectionDict $ M.delete cid tbl

----------------------------------------------------------------

-- Source CID -> RecvQ
-- Initial and RTT0 packets are queued here before the Connection is created.
-- The table is an LRU cache (see 'recvQDictSize' for its capacity).
newtype RecvQDict = RecvQDict (LRUCacheRef CID RecvQ)

-- | Capacity passed to the LRU cache that backs 'RecvQDict'.
recvQDictSize :: Int
recvQDictSize = 100

-- | Create an empty 'RecvQDict' with capacity 'recvQDictSize'.
newRecvQDict :: IO RecvQDict
newRecvQDict = RecvQDict <$> LRUCache.newLRUCacheRef recvQDictSize

-- | Fetch the receive queue cached for a CID, creating a new queue with
--   'newRecvQ' when there is none.
--
--   The 'Bool' is the flag returned by 'LRUCache.cached'; the caller in
--   this file treats 'True' as "the queue already existed".
lookupRecvQDict :: RecvQDict -> CID -> IO (RecvQ, Bool)
lookupRecvQDict (RecvQDict ref) dcid = LRUCache.cached ref dcid newRecvQ

-- | Fetch the receive queue cached for a CID without creating one;
--   'Nothing' when no queue is cached for that CID.
lookupRecvQDict' :: RecvQDict -> CID -> IO (Maybe RecvQ)
lookupRecvQDict' (RecvQDict ref) dcid = LRUCache.cached' ref dcid

----------------------------------------------------------------

-- | Everything handed to the connection-accepting callback when the
--   dispatcher sees the first Initial packet of a new connection.
data Accept = Accept
    { accVersionInfo :: VersionInfo
    -- ^ The peer's chosen version together with the server's versions.
    , accMyAuthCIDs :: AuthCIDs
    -- ^ Authenticated CIDs on the server side.
    , accPeerAuthCIDs :: AuthCIDs
    -- ^ Authenticated CIDs on the peer side.
    , accMySocket :: Socket
    -- ^ The listening socket the packet arrived on.
    , accPeerInfo :: PeerInfo
    -- ^ The peer's address information.
    , accRecvQ :: RecvQ
    -- ^ Queue into which the dispatcher writes this peer's packets.
    , accPacketSize :: Int
    -- ^ Size in bytes of the datagram that carried the Initial packet.
    , accRegister :: CID -> Connection -> IO ()
    -- ^ Registers a connection under a CID in the dispatcher's table.
    , accUnregister :: CID -> IO ()
    -- ^ Removes a CID from the dispatcher's table.
    , accAddressValidated :: Bool
    -- ^ Whether the peer's address is considered validated.
    , accTime :: TimeMicrosecond
    -- ^ Time at which the datagram was received.
    }

----------------------------------------------------------------

-- | Fork a thread that runs 'dispatcher' on the given socket and return
--   its 'ThreadId'.
runDispatcher
    :: Dispatch
    -> ServerConfig
    -> TVar ServerState
    -> (Accept -> IO ())
    -> Socket
    -> IO ThreadId
runDispatcher d conf stvar forkConn mysock =
    forkIO $ dispatcher d conf stvar forkConn mysock

-- | Run state of the server; the dispatcher loop ends once it observes
--   'Stopped'.
data ServerState = Running | Stopped deriving (Eq, Show)

-- | Decide whether the dispatcher loop should continue.
--
-- Returns 'False' immediately when the server state is 'Stopped'.
-- Otherwise runs the supplied STM wait action and returns 'True'.
-- Both steps share one transaction, so while the wait action retries
-- a change of the state variable re-runs the check.
checkLoop :: TVar ServerState -> STM () -> IO Bool
checkLoop stvar waitsock = atomically $ do
    st <- readTVar stvar
    if st == Stopped
        then return False
        else do
            waitsock -- blocking is retry
            return True

-- | Main loop of the dispatcher thread.
--
-- Labels the thread, then repeatedly waits for the socket to become
-- readable, receives one datagram (up to 2048 bytes), decodes it into
-- packets and hands each packet to 'dispatch'.  The loop ends when
-- 'checkLoop' reports that the server has been stopped.
dispatcher
    :: Dispatch
    -> ServerConfig
    -> TVar ServerState
    -> (Accept -> IO ())
    -> Socket
    -> IO ()
dispatcher d conf stvar forkConnection mysock = do
    labelMe "QUIC dispatcher"
    wait <- waitReadSocketSTM mysock
    handleLogUnit logAction $ loop wait
  where
    loop wait = do
        cont <- checkLoop stvar wait
        when cont $ do
            (bs, peersa) <- safeRecv $ NSB.recvFrom mysock 2048
            now <- getTimeMicrosecond
            let -- Reply to the sender of this datagram on the same socket.
                send' b = void $ NSB.sendTo mysock b peersa
                -- cf: greaseQuicBit $ getMyParameters conn
                quicBit = greaseQuicBit $ scParameters conf
            cpckts <- decodeCryptPackets bs (not quicBit)
            let bytes = BS.length bs
                peerInfo = PeerInfo peersa
                switch =
                    dispatch d conf forkConnection logAction mysock peerInfo send' bytes now
            -- One datagram may decode into several packets.
            mapM_ switch cpckts
            loop wait

    -- Debug output is enabled only when a debug log is configured.
    doDebug = isJust $ scDebugLog conf
    logAction msg
        | doDebug = stdoutLogger ("dispatch(er): " <> msg)
        | otherwise = return ()

    -- Run a receive action with limited error recovery:
    -- asynchronous exceptions and 'E.InvalidArgument' I/O errors are
    -- re-thrown; any other exception is logged and the action is run
    -- once more, this time without protection.
    safeRecv rcv = do
        ex <- E.try $ windowsThreadBlockHack rcv
        case ex of
            Right x -> return x
            Left se | isAsyncException se -> E.throwIO (se :: E.SomeException)
            Left se -> case E.fromException se of
                Just e | E.ioeGetErrorType e == E.InvalidArgument -> E.throwIO se
                _ -> do
                    logAction $ "recv again: " <> bhow se
                    rcv

----------------------------------------------------------------

-- If client initial is fragmented into multiple packets,
-- there is no way to put all the packets into a single queue.
-- Rather, each fragment packet is put into its own queue.
-- For the first fragment, the handshake would succeed if the others
-- are retransmitted.
-- For the other fragments, the handshake will fail since its socket
-- cannot be connected.

-- | Route one decoded packet.
--
-- This clause handles Initial packets:
--
-- * datagrams smaller than 'defaultQUICPacketSize' are logged and dropped;
-- * an unsupported version is answered with a Version Negotiation packet;
-- * a packet whose destination CID belongs to a registered connection is
--   written to that connection's receive queue;
-- * otherwise a new connection is accepted or a Retry is sent, depending
--   on the token carried by the packet and on 'scRequireRetry'.
dispatch
    :: Dispatch
    -> ServerConfig
    -> (Accept -> IO ())
    -> DebugLogger
    -> Socket
    -> PeerInfo
    -> (ByteString -> IO ())
    -> Int
    -> TimeMicrosecond
    -> (CryptPacket, EncryptionLevel, Int)
    -> IO ()
dispatch
    Dispatch{..}
    ServerConfig{..}
    forkConnection
    logAction
    mysock
    peerInfo
    send'
    bytes
    tim
    (cpkt@(CryptPacket (Initial peerVer dCID sCID token) _), lvl, siz)
        | bytes < defaultQUICPacketSize = do
            logAction $ "too small " <> bhow bytes <> ", " <> bhow peerInfo
        | peerVer `notElem` myVersions = do
            -- Offer a greasing version different from the one the peer used.
            let offerVersions
                    | peerVer == GreasingVersion = GreasingVersion2 : myVersions
                    | otherwise = GreasingVersion : myVersions
            bss <-
                encodeVersionNegotiationPacket $
                    VersionNegotiationPacket sCID dCID offerVersions
            send' bss
        | token == "" = do
            mconn <- lookupConnectionDict dstTable dCID
            case mconn of
                Nothing -> acceptUnvalidated
                Just conn -> writeRecvQ (connRecvQ conn) $ mkReceivedPacket cpkt tim siz lvl
        | otherwise = do
            mconn <- lookupConnectionDict dstTable dCID
            case mconn of
                Nothing -> do
                    mct <- decryptToken tokenMgr token
                    case mct of
                        Just ct
                            | isRetryToken ct -> do
                                ok <- isRetryTokenValid ct
                                if ok then pushToAcceptRetried ct else sendRetry
                            -- A non-Retry token that our token manager can
                            -- decrypt is taken as address validation.
                            | otherwise -> pushToAcceptFirst True
                        -- The token cannot be decrypted: proceed as if the
                        -- client had sent no token (RFC 9000, Section 8.1.3)
                        -- instead of treating its address as validated.
                        Nothing -> acceptUnvalidated
                Just conn -> writeRecvQ (connRecvQ conn) $ mkReceivedPacket cpkt tim siz lvl
      where
        myVersions = scVersions
        -- Handling of a first Initial that carries no usable token:
        -- the peer's address is not validated.
        acceptUnvalidated
            | scRequireRetry = sendRetry
            | otherwise = pushToAcceptFirst False
        -- Queue the packet under the peer's source CID and, when the queue
        -- was newly created, hand an 'Accept' to the accepting callback.
        pushToAcceptQ myAuthCIDs peerAuthCIDs addrValid = do
            (q, exist) <- lookupRecvQDict srcTable sCID
            writeRecvQ q $ mkReceivedPacket cpkt tim siz lvl
            unless exist $ do
                let reg = registerConnectionDict dstTable
                    -- Unregistration goes through fire' with a 10,000,000
                    -- microsecond argument (presumably a delay).
                    unreg cid =
                        fire' (Microseconds 10000000) $ unregisterConnectionDict dstTable cid
                    acc =
                        Accept
                            { accVersionInfo = VersionInfo peerVer myVersions
                            , accMyAuthCIDs = myAuthCIDs
                            , accPeerAuthCIDs = peerAuthCIDs
                            , accMySocket = mysock
                            , accPeerInfo = peerInfo
                            , accRecvQ = q
                            , accPacketSize = bytes
                            , accRegister = reg
                            , accUnregister = unreg
                            , accAddressValidated = addrValid
                            , accTime = tim
                            }
                forkConnection acc
        -- Initial: DCID=S1, SCID=C1 ->
        --                                     <- Initial: DCID=C1, SCID=S2
        --                               ...
        -- 1-RTT: DCID=S2 ->
        --                                                <- 1-RTT: DCID=C1
        --
        -- initial_source_connection_id       = S2   (newdCID)
        -- original_destination_connection_id = S1   (dCID)
        -- retry_source_connection_id         = Nothing
        pushToAcceptFirst addrValid = do
            newdCID <- newCID
            let myAuthCIDs =
                    defaultAuthCIDs
                        { initSrcCID = Just newdCID
                        , origDstCID = Just dCID
                        }
                peerAuthCIDs =
                    defaultAuthCIDs
                        { initSrcCID = Just sCID
                        }
            pushToAcceptQ myAuthCIDs peerAuthCIDs addrValid
        -- Initial: DCID=S1, SCID=C1 ->
        --                                       <- Retry: DCID=C1, SCID=S2
        -- Initial: DCID=S2, SCID=C1 ->
        --                                     <- Initial: DCID=C1, SCID=S3
        --                               ...
        -- 1-RTT: DCID=S3 ->
        --                                                <- 1-RTT: DCID=C1
        --
        -- initial_source_connection_id       = S3   (dCID)  S2 in our server
        -- original_destination_connection_id = S1   (o)
        -- retry_source_connection_id         = S2   (dCID)
        pushToAcceptRetried (CryptoToken _ _ _ (Just (_, _, o))) = do
            let myAuthCIDs =
                    defaultAuthCIDs
                        { initSrcCID = Just dCID
                        , origDstCID = Just o
                        , retrySrcCID = Just dCID
                        }
                peerAuthCIDs =
                    defaultAuthCIDs
                        { initSrcCID = Just sCID
                        }
            pushToAcceptQ myAuthCIDs peerAuthCIDs True
        -- A token without CIDs cannot come from a Retry; ignore the packet.
        pushToAcceptRetried _ = return ()
        -- A Retry token is valid when it has not outlived its lifetime
        -- (given in seconds) and its CIDs and version match this packet.
        isRetryTokenValid (CryptoToken _tver life etim (Just (l, r, _))) = do
            diff <- getElapsedTimeMicrosecond etim
            return $
                diff <= Microseconds (fromIntegral life * 1000000)
                    && dCID == l
                    && sCID == r
                    -- Initial for ACK contains the retry token but
                    -- the version would be already version 2, sigh.
                    && _tver == peerVer
        isRetryTokenValid _ = return False
        -- Issue a Retry packet carrying a freshly encrypted retry token.
        -- Token encryption is bounded by a 100,000 microsecond timeout;
        -- on timeout only a log message is emitted.
        sendRetry = do
            newdCID <- newCID
            retryToken <- generateRetryToken peerVer scTicketLifetime newdCID sCID dCID
            mnewtoken <-
                timeout (Microseconds 100000) "sendRetry" $ encryptToken tokenMgr retryToken
            case mnewtoken of
                Nothing -> logAction "retry token stacked"
                Just newtoken -> do
                    bss <- encodeRetryPacket $ RetryPacket peerVer sCID newdCID newtoken (Left dCID)
                    send' bss
----------------------------------------------------------------
-- 0-RTT packets: these are routed by the peer's *source* connection ID.
-- The receive queue registered for that CID in 'srcTable' is looked up;
-- if one exists the packet is enqueued there, otherwise the packet is
-- dropped without any reply or log entry.
dispatch
    Dispatch{..}
    _
    _
    _
    _mysock
    _peerInfo
    _
    _
    tim
    (cpkt@(CryptPacket (RTT0 _ _dCID sCID) _), lvl, siz) =
        lookupRecvQDict' srcTable sCID >>= maybe (return ()) enqueue
      where
        -- Wrap the packet with its arrival time, size and encryption level
        -- and hand it to the queue that was found.
        enqueue q = writeRecvQ q (mkReceivedPacket cpkt tim siz lvl)
----------------------------------------------------------------
-- Short-header (1-RTT) packets: routed by destination connection ID via
-- 'dstTable'.
--
-- * Unknown CID: a stateless reset may be sent back through @send'@,
--   subject to a size check on the triggering datagram (@bytes@) and to
--   the 'statelessResetRate' limiter.
-- * Known CID: if the connection is still alive, its socket and peer
--   information are refreshed and the packet is enqueued on the
--   connection's receive queue. Packets for connections that are not
--   alive are dropped silently.
dispatch
    Dispatch{..}
    _
    _
    logAction
    mysock
    peerInfo
    send'
    bytes
    tim
    (cpkt@(CryptPacket (Short dCID) _), lvl, siz) = do
        mconn <- lookupConnectionDict dstTable dCID
        case mconn of
            Nothing -> do
                -- Three times rule for stateless reset.
                -- Our stateless reset is always 1280 bytes, so it is only
                -- sent when the triggering datagram is larger than 427
                -- bytes (3 * 428 = 1284 > 1280); the reply then stays
                -- below three times the received size.
                when (bytes > 427) $ do
                    srRate <- getRate statelessResetRate
                    -- fixme: hard coding
                    when (srRate < statelessResetLimit) $ do
                        -- The first byte must look like a short header:
                        -- header form bit 0 and fixed bit 1, i.e.
                        -- 0b01xxxxxx (RFC 9000, Section 10.3,
                        -- "Fixed Bits (2) = 1"). The remaining six bits
                        -- are random.
                        flag <- randomRIO (64, 127)
                        -- 1 flag byte + 1263 random bytes + the stateless
                        -- reset token; this adds up to the 1280 bytes
                        -- mentioned above assuming a 16-byte token.
                        body <- getStdRandom $ uniformByteString 1263
                        let srt = genStatelessReset dCID
                            statelessReset =
                                BS.concat
                                    [ BS.singleton flag
                                    , body
                                    , fromStatelessResetToken srt
                                    ]
                        send' statelessReset
                        logAction $ "Stateless reset is sent to " <> bhow peerInfo
            Just conn -> do
                alive <- getAlive conn
                when alive $ do
                    -- The value returned by 'setSocket' is not needed here.
                    void $ setSocket conn mysock
                    setPeerInfo conn peerInfo
                    writeRecvQ (connRecvQ conn) $ mkReceivedPacket cpkt tim siz lvl
----------------------------------------------------------------
-- Catch-all clause for every header not handled by an earlier clause.
-- The connection is looked up in 'dstTable' using the CID that
-- 'headerMyCID' extracts from the header. When no connection matches,
-- the CID and the peer are logged; otherwise the connection's socket and
-- peer information are updated and the packet is enqueued on the
-- connection's receive queue.
dispatch
    Dispatch{..}
    _
    _
    logAction
    mysock
    peerInfo
    _
    _
    tim
    (cpkt@(CryptPacket hdr _crypt), lvl, siz) =
        lookupConnectionDict dstTable myCID >>= maybe noMatch deliver
      where
        myCID = headerMyCID hdr
        -- No connection is registered for this CID: just log it.
        noMatch =
            logAction $ "CID no match: " <> bhow myCID <> ", " <> bhow peerInfo
        -- Refresh the connection's socket and peer, then queue the packet.
        deliver conn = do
            _ <- setSocket conn mysock
            setPeerInfo conn peerInfo
            writeRecvQ (connRecvQ conn) (mkReceivedPacket cpkt tim siz lvl)

-- | Read the next 'ReceivedPacket' from the given receive queue by
-- delegating to 'readRecvQ'.
recvServer :: RecvQ -> IO ReceivedPacket
recvServer q = readRecvQ q