{-# LANGUAGE OverloadedStrings #-}

module Network.QUIC.Closer (closure) where

import Control.Concurrent
import qualified Control.Exception as E
import Foreign.Marshal.Alloc
import Foreign.Ptr
import qualified Network.Socket as NS

import Network.QUIC.Common
import Network.QUIC.Config
import Network.QUIC.Connection
import Network.QUIC.Connector
import Network.QUIC.Imports
import Network.QUIC.Logger
import Network.QUIC.Packet
import Network.QUIC.Recovery
import Network.QUIC.Sender
import Network.QUIC.Types

-- | Finish a connection according to the outcome of its main action.
--
-- * @Right x@: a @ConnectionClose NoError@ frame is sent through
--   @closure'@ and @x@ is returned.
-- * @Left se@ carrying 'TransportErrorIsSent' or
--   'ApplicationProtocolErrorIsSent': the matching close frame is sent
--   and the same exception is re-thrown.
-- * @Left se@ carrying 'Abort': an application close frame is sent and
--   'ApplicationProtocolErrorIsSent' is thrown in its place.
-- * @Left se@ carrying 'VerNego': nothing is sent; 'NextVersion' is
--   thrown with the carried version information.
-- * Any other exception is re-thrown unchanged without sending a frame.
closure :: Connection -> LDCC -> Either E.SomeException a -> IO a
closure conn ldcc (Right x) = do
    closure' conn ldcc $ ConnectionClose NoError 0 ""
    return x
closure conn ldcc (Left se)
    | Just e@(TransportErrorIsSent err desc) <- E.fromException se = do
        closure' conn ldcc $ ConnectionClose err 0 desc
        E.throwIO e
    | Just e@(ApplicationProtocolErrorIsSent err desc) <- E.fromException se = do
        closure' conn ldcc $ ConnectionCloseApp err desc
        E.throwIO e
    | Just (Abort err desc) <- E.fromException se = do
        closure' conn ldcc $ ConnectionCloseApp err desc
        -- Callers observe the abort as an application protocol error.
        E.throwIO $ ApplicationProtocolErrorIsSent err desc
    | Just (VerNego vers) <- E.fromException se =
        E.throwIO $ NextVersion vers
    | otherwise = E.throwIO se -- including asynchronous exceptions

-- | Encode @frame@ into a freshly allocated send buffer and fork the
-- 'closer' thread, which repeatedly sends that buffer and watches for
-- incoming packets.
--
-- When the forked thread ends (normally or with an exception, which is
-- written to the connection's debug log), the buffers allocated here are
-- freed and, on the client side, the socket is closed.
closure' :: Connection -> LDCC -> Frame -> IO ()
closure' conn ldcc frame = do
    sock <- getSocket conn
    peersa <- peerSockAddr <$> getPathInfo conn
    connected <- getSockConnected conn
    -- send
    let bufsiz = maximumUdpPayloadSize
    sendbuf <- mallocBytes bufsiz
    -- Until the closer thread is forked, nothing else releases sendbuf,
    -- so free it here if the remaining setup throws.
    (siz, recv, freeBufs, clos) <- flip E.onException (free sendbuf) $ do
        -- This must be called before freeResources in runClient.
        siz' <- encodeCC conn sendbuf bufsiz frame
        -- recv and clos
        killReaders conn -- client only
        if isServer conn
            then
                -- Server: receive through the connection; only the send
                -- buffer needs freeing and there is no socket to close here.
                return (siz', void $ connRecv conn, free sendbuf, return ())
            else do
                recvbuf <- mallocBytes bufsiz
                let recv'
                        | connected = void $ NS.recvBuf sock recvbuf bufsiz
                        | otherwise = do
                            (_, sa) <- NS.recvBufFrom sock recvbuf bufsiz
                            -- Keep reading until a datagram from the peer arrives.
                            when (sa /= peersa) recv'
                    free' = free recvbuf >> free sendbuf
                    clos' = do
                        NS.close sock
                        -- This is just in case.
                        getSocket conn >>= NS.close
                return (siz', recv', free', clos')
    let send
            | connected = void $ NS.sendBuf sock sendbuf siz
            | otherwise = void $ NS.sendBufTo sock sendbuf siz peersa
        -- hook
        hook = onCloseCompleted $ connHooks conn
    pto <- getPTO ldcc
    void $ forkFinally (closer conn pto send recv hook) $ \e -> do
        case e of
            Left e' -> connDebugLog conn $ "closure' " <> bhow e'
            Right _ -> return ()
        freeBufs
        clos

-- | Encode @frame@ as one or two packets into the given buffer and
-- return the total number of bytes written.
--
-- The encryption level is taken from the connection, with 'RTT0Level'
-- replaced by 'InitialLevel'. At 'HandshakeLevel' two packets are
-- written back to back: an Initial packet followed by a Handshake
-- packet, each carrying the same frame. At any other level a single
-- packet is written.
encodeCC :: Connection -> Buffer -> BufferSize -> Frame -> IO Int
encodeCC conn sendbuf0 bufsiz0 frame = do
    lvl0 <- getEncryptionLevel conn
    let lvl
            | lvl0 == RTT0Level = InitialLevel
            | otherwise = lvl0
    if lvl == HandshakeLevel
        then do
            siz0 <- encCC sendbuf0 bufsiz0 InitialLevel
            -- The second packet starts right after the first one.
            let sendbuf1 = sendbuf0 `plusPtr` siz0
                bufsiz1 = bufsiz0 - siz0
            siz1 <- encCC sendbuf1 bufsiz1 HandshakeLevel
            return (siz0 + siz1)
        else
            encCC sendbuf0 bufsiz0 lvl
  where
    -- Encode a single packet holding only @frame@ at the given level.
    -- A negative size from 'encodePlainPacket' is reported as 0 bytes
    -- and is not logged to qlog.
    encCC sendbuf bufsiz lvl = do
        header <- mkHeader conn lvl
        mypn <- nextPacketNumber conn
        let plain = Plain (Flags 0) mypn [frame] 0
            ppkt = PlainPacket header plain
            res = SizedBuffer sendbuf bufsiz
        siz <- fst <$> encodePlainPacket conn res ppkt Nothing
        if siz >= 0
            then do
                now <- getTimeMicrosecond
                qlogSent conn ppkt now
                return siz
            else
                return 0

-- | Body of the thread forked by @closure'@.
--
-- Runs at most three rounds. Each round runs @send@, then keeps running
-- @recv@ for one PTO ('skip'), then waits for one more @recv@ with a
-- timeout of @pto !>>. 1@ (presumably half a PTO -- confirm against the
-- definition of '!>>.'). If that last wait times out, @hook@ is run and
-- the loop stops; otherwise the next round starts.
--
-- NOTE(review): when all three rounds end with a received packet, the
-- loop finishes without running @hook@ -- confirm that this is intended.
closer :: Connection -> Microseconds -> IO () -> IO () -> IO () -> IO ()
closer _conn (Microseconds pto) send recv hook = do
    labelMe "QUIC closer"
    loop (3 :: Int)
  where
    loop :: Int -> IO ()
    loop 0 = return ()
    loop n = do
        send
        getTimeMicrosecond >>= skip (Microseconds pto)
        mx <- timeout (Microseconds (pto !>>. 1)) "closer 1" recv
        case mx of
            Nothing -> hook
            Just () -> loop (n - 1)

    -- Keep receiving until @total@ microseconds have passed since @base@.
    -- The remaining time is always derived from the original @total@, so
    -- the time already elapsed is subtracted exactly once. Waiting stops
    -- early once less than 5000 microseconds remain.
    skip :: Microseconds -> TimeMicrosecond -> IO ()
    skip (Microseconds total) base = go total
      where
        go remaining = do
            mx <- timeout (Microseconds remaining) "closer 2" recv
            case mx of
                Nothing -> return ()
                Just () -> do
                    Microseconds elapsed <- getElapsedTimeMicrosecond base
                    let remaining' = total - elapsed
                    when (remaining' >= 5000) $ go remaining'