-- |
-- Module      : Net.DNSBase.Decode.Internal.Domain
-- Description : Decoding of wire-form DNS domain names, with optional name decompression
-- Copyright   : (c) Viktor Dukhovni, 2026
-- License     : BSD-3-Clause
-- Maintainer  : ietf-dane@dukhovni.org
-- Stability   : unstable
module Net.DNSBase.Decode.Internal.Domain
    ( getDomain
    , getDomainNC
    ) where

import qualified Data.ByteString as B
import qualified Data.ByteString.Builder as B

import Net.DNSBase.Decode.Internal.State
import Net.DNSBase.Internal.Domain
import Net.DNSBase.Internal.Util

-- | Maximum wire-form length of a domain name, in octets.
--
-- The limit includes the length octet of the terminal empty (root) label, in
-- line with the 255-octet total of
-- [RFC 1035, section 2.3.4](https://tools.ietf.org/html/rfc1035#section-2.3.4).
-- The decoder counts that terminal zero octet when comparing against it.
maxWireLen :: Int
maxWireLen = 255

-- | Parse a wire-form domain with \"No Compression\", i.e. compression
--   pointer labels are treated as invalid and cause the parse to fail.
--
-- When defining the decoders for newly standardized RData types, it is
-- generally required to use this function to decode transparent domain fields,
-- as name compression is explicitly forbidden for domain fields of future RData
-- types (see 'getDomain' for reference).
getDomainNC :: SGet Domain
getDomainNC = do
    -- Decode from the current position, with pointers disallowed.
    (_, bldr) <- getDomain' False =<< getPosition
    case buildDomain (Just bldr) of
        Just dom -> pure dom
        Nothing  -> failSGet "Internal error"

-- | Parse a wire-form domain with name compression (pointer labels) allowed,
--   subject to the decoder's name-compression setting ('getNameComp').
--
-- This function should only be used when decoding the owner name of resource
-- records, as well as for fields of the initial set of RData types defined in
-- [RFC 1035](https://tools.ietf.org/html/rfc1035) and several others listed in
-- section 4 of [RFC 3597](https://tools.ietf.org/html/rfc3597#section-4),
-- which also states that future RData types MUST NOT use name compression.
getDomain :: SGet Domain
getDomain = do
    -- No name (de)compression if the input is only a message fragment.
    nc <- getNameComp
    (_, bldr) <- getDomain' nc =<< getPosition
    case buildDomain (Just bldr) of
        Just dom -> pure dom
        Nothing  -> failSGet "Internal error"

-- | Decode a wire-form domain name that starts at position @start@.
--
-- On success it returns the total wire length and a builder holding the
-- uncompressed wire form. The length includes the terminal zero octet and is
-- at most 'maxWireLen'. If @allowPtr@ is 'False', compression pointers cause
-- a parse failure.
--
-- The first octet of a label determines how the rest of the label is read:
--
--   * @11XX_XXXX@ is a 14-bit compression pointer, built from the low 6 bits
--     of that octet and all of the next octet.
--   * @00XX_XXXX@ is a standard label and encodes its length (<= 63).
--   * @01XX_XXXX@ was proposed for extended labels but remains experimental.
--   * @10XX_XXXX@ is presently undefined.
--
-- The last two forms are rejected. Latest status can be found in the
-- [IANA registry](https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-10).
getDomain' :: Bool -> Int -> SGet (Int, B.Builder)
getDomain' allowPtr start = do
    vl <- get8
    if | vl == 0 -> do
            -- Position is just past the terminal empty label's length octet,
            -- which is included in the slice.
            end <- getPosition
            getSlice start end
       | vl <= 63 -> do
            -- Ordinary label: skip its content and keep scanning.
            skipNBytes (fromIntegral vl)
            getDomain' allowPtr start
       | vl >= 0b1100_0000 -> do
            unless allowPtr $
                failSGet "domain name compression not allowed in current context"
            -- Position is just past the first pointer octet. Exclude that
            -- octet from the literal prefix.
            end <- getPosition
            (plen, prefix) <- getSlice start (end - 1)
            vl' <- get8
            let offset :: Word16
                offset = (fromIntegral (vl .&. 0x3f) `shiftL` 8)
                         .|. fromIntegral vl'
            -- Only pointers to positions strictly before the start of the
            -- current name segment are accepted. This rules out self/forward
            -- references and pointer loops.
            when (fromIntegral offset >= start) $
                failSGet "invalid compression pointer"
            (slen, suffix) <- getPtr offset
            let len = plen + slen
            when (len > maxWireLen) $
                failSGet "domain name too long"
            pure (len, prefix <> suffix)
       | otherwise -> failSGet "unsupported label type"
  where
    -- Decode the suffix at a compression-pointer target via 'seekSGet'.
    -- NOTE(review): assumes 'seekSGet' positions the parser at @off@ and
    -- restores the prior position afterwards; confirm against its definition.
    getPtr :: Word16 -> SGet (Int, B.Builder)
    getPtr off = seekSGet off $ getPosition >>= getDomain' allowPtr

    -- Slice of the packet covering positions [from, to), of length to - from.
    getSlice :: Int -> Int -> SGet (Int, B.Builder)
    getSlice from to
        | len < 0          = failSGet "negative-length domain name slice"
        | len > maxWireLen = failSGet "domain name too long"
        | otherwise        = do
            buf <- getPacket
            pure (len, B.byteString (B.take len (B.drop from buf)))
      where
        len = to - from