-- |
-- Module      : Net.DNSBase.Decode.Internal.Message
-- Description : Decoder for the DNS message envelope (header, sections)
-- Copyright   : (c) IIJ Innovation Institute Inc., 2009
--               (c) Viktor Dukhovni, 2020-2026
-- License     : BSD-3-Clause
-- Maintainer  : ietf-dane@dukhovni.org
-- Stability   : unstable
{-# LANGUAGE RecordWildCards #-}

module Net.DNSBase.Decode.Internal.Message
      ( getMessage
      ) where

import Data.List (partition)

import Net.DNSBase.Decode.Internal.Domain
import Net.DNSBase.Decode.Internal.Option
import Net.DNSBase.Decode.Internal.RData
import Net.DNSBase.Decode.Internal.State
import Net.DNSBase.Internal.Domain
import Net.DNSBase.Internal.EDNS
import Net.DNSBase.Internal.Error
import Net.DNSBase.Internal.Flags
import Net.DNSBase.Internal.Message
import Net.DNSBase.Internal.Opcode
import Net.DNSBase.Internal.RCODE
import Net.DNSBase.Internal.RData
import Net.DNSBase.Internal.RR
import Net.DNSBase.Internal.RRCLASS
import Net.DNSBase.Internal.RRTYPE
import Net.DNSBase.Internal.Util

-- | Decode a complete 'DNSMessage', folding in the EDNS pseudo-header
-- carried by the additional section's OPT record when one is present.
--
-- The 'RDataMap' selects RData decoders for every RR section; the
-- 'OptionMap' is handed to 'getRRs' only for the additional section.
-- Decoding fails when the additional section holds more than one OPT
-- record, or a single OPT record whose owner is not the root domain.
-- When the TC flag is set, only the header and question section are
-- decoded.
getMessage :: RDataMap -> OptionMap -> SGet DNSMessage
getMessage dm om = local (setDecodeSection DnsHeaderSection) $ do
    hdr     <- getPartialHeader
    qdCount <- getInt16
    anCount <- getInt16
    nsCount <- getInt16
    arCount <- getInt16
    questions <- inSection DnsQuestionSection $ getQueries qdCount
    -- RRs of a truncated message are skipped: they would not be used, and
    -- they may be cut short in ways that make the RR decoder fail.
    if hasAnyFlags TCflag (p_dnsMsgFl hdr)
        then pure $ mkMsg hdr No questions [] [] []
        else do
            -- NOTE(review): records the first question's name via
            -- 'setLastOwner'; presumably consulted while decoding the RRs
            -- that follow -- confirm against Decode.Internal.State.
            case questions of
                q : _ -> () <$ setLastOwner (dnsTripleName q)
                []    -> pure ()
            answers    <- inSection DnsAnswerSection $
                              getRRs dm Nothing anCount
            authority  <- inSection DnsAuthoritySection $
                              getRRs dm Nothing nsCount
            additional <- inSection DnsAdditionalSection $
                              getRRs dm (Just om) arCount
            case partition isOpt additional of
                -- No OPT record: header RCODE and flags stand as read.
                ([], rest) ->
                    pure $ mkMsg hdr No questions answers authority rest
                -- Exactly one OPT record, owned by the root domain.
                ([opt], rest)
                    | RootDomain <- rrOwner opt ->
                        pure $ mkMsg hdr (getEDNS opt) questions answers authority rest
                _ -> inSection DnsEDNSSection $
                         failSGet "Multiple or bad additional section OPT records"
  where
    -- Run a sub-decoder with the given section recorded in the environment.
    inSection sec = local (setDecodeSection sec)

    -- An RR is an OPT pseudo-RR when its RData type is 'OPT'.
    isOpt :: RR -> Bool
    isOpt rr = rdataType (rrData rr) == OPT

    -- Package the OPT record's contribution to the header, if it decodes.
    -- NOTE(review): 'isOpt' already selected this RR by RData type, so a
    -- 'Nothing' from 'optEDNS' is not expected; if it does occur the message
    -- is accepted with no EDNS data rather than rejected.
    getEDNS :: RR -> EDNSData
    getEDNS rr = case optEDNS rr of
        Just (opt, rc, fl) -> Yes { ext_fl = fl, ext_rc = rc, edns = opt }
        Nothing            -> No

    -- Split an OPT RR into its EDNS parameters plus the extended RCODE
    -- octet (TTL bits 24-31) and extended flags (TTL bits 0-15).  The
    -- EDNS version comes from TTL bits 16-23 and the CLASS value is
    -- carried over unchanged.
    optEDNS :: RR -> Maybe (EDNS, Word8, Word16)
    optEDNS (RR _ vcl vttl rd) = do
        T_OPT opts <- fromRData rd
        let octet :: Int -> Word8
            octet k = fromIntegral $ (vttl `shiftR` k) .&. 0xff
        pure ( EDNS (octet 16) (coerce vcl) opts
             , octet 24
             , fromIntegral (vttl .&. 0xffff)
             )

-- | Decode the question section of a DNS message.  The 'Int' argument is
-- the message's QDCOUNT, which in practice should not exceed 1; that is
-- neither checked nor enforced here, and exactly as many questions as
-- requested are decoded.
getQueries :: Int -> SGet [DnsTriple]
getQueries count = replicateM count question
  where
    -- One question: owner name, then the 16-bit type and class codes.
    question :: SGet DnsTriple
    question =
        DnsTriple <$> getDomain
                  <*> fmap RRTYPE get16
                  <*> fmap RRCLASS get16

-- | Decode exactly the given number of consecutive resource records with
-- 'getRR', passing the RData decoder map and the optional 'OptionMap'
-- through unchanged.
getRRs :: RDataMap -> Maybe OptionMap -> Int -> SGet [RR]
getRRs dm om count = replicateM count oneRR
  where
    oneRR = getRR dm om

-- | Decode the leading header words of a DNS message into a
-- 'PartialHeader': the 16-bit message ID, then a 16-bit word from which
-- 'extractOpcode', 'extractRCODE' and 'makeDNSFlags' obtain the opcode,
-- the header RCODE and the flags.  The section counts that follow are
-- left for the caller.
getPartialHeader :: SGet PartialHeader
getPartialHeader = do
    msgId <- get16
    bits  <- get16
    pure $ PartialHeader msgId
                         (extractOpcode bits)
                         (extractRCODE  bits)
                         (makeDNSFlags  bits)


-- | An RCODE as read from the fixed header, before 'extendRCODE' merges
-- in any extended bits from an OPT record.
type PartialRCODE = RCODE

-- | DNS flags as read from the fixed header, before 'extendFlags' merges
-- in any extended flags from an OPT record.
type PartialDNSFlags = DNSFlags


-- | What the additional section's OPT pseudo-RR, if any, contributes to
-- the message.
--
-- * 'No': no usable OPT record was found.
-- * 'Yes': the extended flags, the extended-RCODE bits, and the
--   remaining EDNS parameters taken from the OPT record.
data EDNSData
    = No
    | Yes
        { ext_fl :: Word16  -- ^ extended flags (low 16 bits of the OPT TTL)
        , ext_rc :: Word8   -- ^ extended RCODE bits (top octet of the OPT TTL)
        , edns   :: EDNS    -- ^ EDNS version, OPT CLASS value and options
        }
    deriving (Eq)

-- | The part of a DNS message header read directly from its leading
-- bytes, before any EDNS OPT pseudo-RR has been seen.  The RCODE and
-- flags held here lack the extended bits an OPT record may supply.
data PartialHeader = PartialHeader
    { p_dnsMsgId :: QueryID          -- ^ message ID
    , p_dnsMsgOp :: Opcode           -- ^ opcode
    , p_dnsMsgRC :: PartialRCODE     -- ^ header RCODE, no extended bits
    , p_dnsMsgFl :: PartialDNSFlags  -- ^ header flags, no extended flags
    }
    deriving (Eq, Show)

-- | Build a 'DNSMessage' from the 'PartialHeader' decoded at the start of
-- the message, the (possibly absent) 'EDNSData', the questions, and the
-- answer, authority and additional RR lists.
--
-- The OPT pseudo-RR travels at the end of the message, in the additional
-- section, but it is logically part of the header.  When present, its
-- bits are merged into the RCODE ('extendRCODE') and flags
-- ('extendFlags'), and its EDNS parameters become the message's EDNS
-- data.  Without it, the header RCODE and flags are used as read, and
-- there is no EDNS data.
mkMsg :: PartialHeader
      -> EDNSData
      -> [DnsTriple]
      -> [RR] -> [RR] -> [RR]
      -> DNSMessage
mkMsg (PartialHeader mid op rc fl) ednsData qd an ns ar =
    case ednsData of
        No ->
            assemble rc fl Nothing
        Yes { ext_fl = xfl, ext_rc = xrc, edns = opt } ->
            assemble (extendRCODE rc xrc) (extendFlags fl xfl) (Just opt)
  where
    -- Every field except RCODE, flags and EDNS is identical in both cases.
    assemble msgRC msgFl msgEx = DNSMessage
        { dnsMsgId = mid
        , dnsMsgOp = op
        , dnsMsgRC = msgRC
        , dnsMsgFl = msgFl
        , dnsMsgEx = msgEx
        , dnsMsgQu = qd
        , dnsMsgAn = an
        , dnsMsgNs = ns
        , dnsMsgAr = ar
        }