module Net.DNSBase.Decode.Internal.Domain
( getDomain
, getDomainNC
) where
import qualified Data.ByteString as B
import qualified Data.ByteString.Builder as B
import Net.DNSBase.Decode.Internal.State
import Net.DNSBase.Internal.Domain
import Net.DNSBase.Internal.Util
-- | Upper bound, in bytes, on the length of a decoded wire-form domain
-- name.  Names whose accumulated length exceeds this are rejected by the
-- decoder with a \"domain name too long\" error.
maxWireLen :: Int
maxWireLen = 255
-- | Decode a domain name starting at the current read position, with
-- compression pointers disallowed: the underlying decoder is always
-- invoked with its pointer flag set to 'False', so any compression
-- pointer encountered makes the parse fail.
--
-- The raw wire bytes collected by the decoder are handed to
-- 'buildDomain'; if that yields 'Nothing' the parse fails with
-- \"Internal error\".
getDomainNC :: SGet Domain
getDomainNC = do
    -- The length component of the result is not needed here, only the
    -- accumulated wire-form bytes.
    (_, bldr) <- getDomain' False =<< getPosition
    maybe (failSGet "Internal error") pure (buildDomain (Just bldr))
-- | Decode a domain name starting at the current read position.
--
-- Whether compression pointers are accepted is taken from the decoder
-- state via 'getNameComp' (presumably the \"name compression allowed\"
-- flag for the current context -- confirm against the state module).
--
-- The raw wire bytes collected by the decoder are handed to
-- 'buildDomain'; if that yields 'Nothing' the parse fails with
-- \"Internal error\".
getDomain :: SGet Domain
getDomain = do
    nc <- getNameComp
    -- The length component of the result is not needed here, only the
    -- accumulated wire-form bytes.
    (_, bldr) <- getDomain' nc =<< getPosition
    maybe (failSGet "Internal error") pure (buildDomain (Just bldr))
-- | Worker that walks the labels of a wire-form domain name and returns
-- the total byte length of the (decompressed) name together with a
-- builder for its bytes.
--
-- Arguments:
--
-- * @allowPtr@ -- whether a compression pointer may be followed; when
--   'False', encountering one fails the parse.
-- * @start@ -- the packet offset at which the current contiguous run of
--   labels began.  It is threaded unchanged through the recursion over
--   ordinary labels so that the whole run can be sliced out of the
--   packet in one go.
--
-- Each step reads one length/type byte and dispatches on it:
--
-- * @0@ -- the terminating root label: slice out everything from
--   @start@ up to and including this byte.
-- * @1..63@ -- an ordinary label: skip its content and continue.
-- * @>= 0b1100_0000@ -- a compression pointer: slice out the labels
--   seen so far, then decode the remainder at the pointer target.
-- * anything else -- rejected as an unsupported label type.
getDomain' :: Bool -> Int -> SGet (Int, B.Builder)
getDomain' allowPtr start = do
    vl <- get8
    if | vl == 0 -> do
           -- The position is now just past the zero byte; passing
           -- @end + 1@ makes the slice length @end - start@, i.e. it
           -- includes the terminator.
           end <- getPosition
           getSlice start (end + 1)
       | vl <= 63 -> do
           let len = fromIntegral vl
           skipNBytes len
           getDomain' allowPtr start
       | vl >= 0b1100_0000 -> do
           unless allowPtr $
               failSGet "domain name compression not allowed in current context"
           -- The position is just past the first pointer byte, so the
           -- slice length @end - start - 1@ excludes that byte.
           end <- getPosition
           (plen, prefix) <- getSlice start end
           vl' <- get8
           -- 14-bit target offset: low six bits of the first byte
           -- followed by the whole of the second byte.
           let offset :: Word16
               offset = (fromIntegral (vl .&. 0x3f) `shiftL` 8)
                        .|. fromIntegral vl'
           -- Only pointers to a strictly earlier offset are accepted.
           -- NOTE(review): together with 'getPtr' restarting at the
           -- target this appears to rule out pointer loops -- confirm
           -- that 'seekSGet' positions the reader at the given offset.
           when (fromIntegral offset >= start) $
               failSGet "invalid compression pointer"
           (slen, suffix) <- getPtr offset
           let len = plen + slen
           when (len > maxWireLen) $
               failSGet "domain name too long"
           pure (len, prefix <> suffix)
       | otherwise -> failSGet "unsupported label type"
  where
    -- Decode the rest of the name at the pointer target, using the
    -- position reported there as the new start of the label run.
    getPtr :: Word16 -> SGet (Int, B.Builder)
    getPtr off = seekSGet off $ getPosition >>= getDomain' allowPtr

    -- @getSlice off end@ extracts @end - (off + 1)@ bytes of the packet
    -- beginning at offset @off@, returning that length and a builder
    -- for the bytes.  Fails on a negative length or one that exceeds
    -- 'maxWireLen'.
    getSlice :: Int -> Int -> SGet (Int, B.Builder)
    getSlice off end
        | len < 0 = failSGet "negative-length domain name slice"
        | len > maxWireLen = failSGet "domain name too long"
        | otherwise = do
            buf <- getPacket
            let slice = B.take len $ B.drop off buf
            pure (len, B.byteString slice)
      where
        len = end - (off + 1)