-- |
-- Module      : Net.DNSBase.Internal.SockIO
-- Description : Low-level UDP/TCP socket I/O for DNS messages
-- Copyright   : (c) IIJ Innovation Institute Inc., 2009
--               (c) Viktor Dukhovni, 2020-2026
-- License     : BSD-3-Clause
-- Maintainer  : ietf-dane@dukhovni.org
-- Stability   : unstable
{-# LANGUAGE RecordWildCards #-}

module Net.DNSBase.Internal.SockIO (
    -- * Receiving DNS messages
    receiveUDP
  , receiveTCP
    -- * Sending pre-encoded DNS messages
  , sendUDP
  , sendTCP
  ) where

import qualified Data.ByteString as B
import qualified Network.Socket.ByteString as Socket
import Network.Socket (Socket)
import Network.Socket.ByteString (recv)
import System.IO.Error (tryIOError, mkIOError, eofErrorType)

import Net.DNSBase.Internal.Error
import Net.DNSBase.Internal.Util
import Net.DNSBase.Resolver.Internal.Types

----------------------------------------------------------------

-- | Receive a single 'Net.DNSBase.Message.DNSMessage' over a UDP 'Socket'.
-- Messages longer than 'Net.DNSBase.Resolver.maxUdpSize' are silently
-- truncated, but this should not occur in practice, since we cap the
-- advertised EDNS UDP buffer size limit at the same value.  A 'DNSError' is
-- raised if the I/O operation fails.
--
-- The first argument is the receive buffer size in bytes; at most that many
-- bytes are requested from the socket.
--
receiveUDP :: Word16 -> Socket -> DNSIO B.ByteString
receiveUDP maxudp sock = withExceptT wrapError $ recv' sock bufsiz
  where
    -- The underlying receive call takes an 'Int' byte count.
    bufsiz :: Int
    bufsiz = fromIntegral maxudp

    -- Surface any 'IOError' from the socket as a network-failure 'DNSError'.
    wrapError :: IOError -> DNSError
    wrapError = NetworkError . NetworkFailure

-- | Issue a single receive call for at most @bufsiz@ bytes on the socket.
-- An 'IOError' thrown by the call is captured with 'tryIOError' and becomes
-- the error result of the 'ExceptT' action, rather than propagating as an
-- exception.
--
recv' :: Socket -> Int -> ExceptT IOError IO ByteString
recv' sock bufsiz = ExceptT $ tryIOError $ recv sock bufsiz

-- | Receive a single DNS message over a virtual-circuit (TCP) connection.  It
-- is up to the caller to implement any desired timeout.  A 'DNSError' is
-- raised if the I/O operation fails.
--
-- Two bytes are read first and converted to a length with 'word16be'; that
-- many further bytes are then read and returned.  The length prefix itself
-- is not part of the result.
--
receiveTCP :: Socket -> DNSIO B.ByteString
receiveTCP sock = recvDNS sock 2 >>= recvDNS sock . toLen
  where
    -- Turn the two-byte prefix into the byte count for the second read.
    toLen :: ByteString -> Int
    toLen = fromIntegral . word16be

-- | Read @len@ bytes from the socket, issuing further receive calls until
-- the accumulated data has length @len@.  A receive that yields no data is
-- reported as an end-of-file 'IOError' (\"connection terminated\").  Every
-- 'IOError' is surfaced as a network-failure 'DNSError'.
--
-- NOTE(review): a @len@ of 0 is passed straight to the receive call, and an
-- empty result is always treated as end-of-file, so a zero-length request
-- cannot succeed.  Confirm that is the intended handling of a zero length
-- prefix.
--
recvDNS :: Socket -> Int -> DNSIO ByteString
recvDNS sock len = withExceptT wrapError recv1
  where
    -- Surface any 'IOError' from the socket as a network-failure 'DNSError'.
    wrapError :: IOError -> DNSError
    wrapError = NetworkError . NetworkFailure

    -- Initial read: done if it delivered everything, otherwise keep reading.
    -- (Presumably 'cond' applies its first continuation when the predicate
    -- holds and the second otherwise.)
    recv1 :: ExceptT IOError IO ByteString
    recv1 = recvCore len >>= cond (B.length .= len) return loop

    -- Request only the bytes still missing and append them to what has been
    -- received so far; the concatenation is forced before the length check.
    loop :: ByteString -> ExceptT IOError IO ByteString
    loop bs0 = do
        let left = len - B.length bs0
        bs1 <- recvCore left
        cond (B.length .= len) return loop $! bs0 <> bs1

    -- Error raised when a receive returns no data.
    eofE :: IOError
    eofE = mkIOError eofErrorType "connection terminated" Nothing Nothing

    -- One receive of at most @len0@ bytes, failing on an empty result.
    recvCore :: Int -> ExceptT IOError IO ByteString
    recvCore len0 = recv' sock len0
                >>= cond B.null (const $ throwE eofE) return

----------------------------------------------------------------

-- | Send an encoded 'Net.DNSBase.Message.DNSMessage' datagram over UDP.  The
-- socket must be explicitly connected to the destination nameserver.  The
-- message length is implicit in the size of the UDP datagram.  With TCP you
-- must use 'sendTCP', because TCP does not have message boundaries, and each
-- message needs to be prepended with an explicit length.
--
-- The byte count returned by the underlying send is discarded.  The send is
-- simply lifted into 'DNSIO': an exception thrown by it is not converted to
-- a 'DNSError' here.
--
sendUDP :: Socket -> ByteString -> DNSIO ()
sendUDP sock = lift . void . Socket.send sock

-- | Send one or more encoded 'Net.DNSBase.Message.DNSMessage' buffers over
-- TCP, each already encapsulated with an explicit length prefix and then
-- concatenated into a single buffer.  DO NOT use 'sendTCP' with UDP.
--
-- The buffer is handed to the socket's @sendAll@ as is; no length prefix is
-- added here.  The send is simply lifted into 'DNSIO': an exception thrown
-- by it is not converted to a 'DNSError' here.
--
sendTCP :: Socket -> ByteString -> DNSIO ()
sendTCP vc = lift . Socket.sendAll vc