{-|
Module      : Net.DNSBase.EDNS.Option.ECS
Description : EDNS Client Subnet option (RFC 7871)
Copyright   : (c) Viktor Dukhovni, 2026
License     : BSD-3-Clause
Maintainer  : ietf-dane@dukhovni.org
Stability   : unstable

The Client Subnet EDNS option lets a recursive resolver forward
a prefix of the original client's address to an authoritative
server, so that authority-side answers (CDNs and similar) can
be tailored to the client's network location.  The
specification, including privacy considerations, is
[RFC 7871](https://datatracker.ietf.org/doc/html/rfc7871).
-}

module Net.DNSBase.EDNS.Option.ECS
    ( O_ecs(..)
    ) where

import Net.DNSBase.Decode.Internal.State
import Net.DNSBase.EDNS.Internal.OptNum
import Net.DNSBase.EDNS.Internal.Option
import Net.DNSBase.Encode.Internal.State
import Net.DNSBase.Encode.Internal.Metric
import Net.DNSBase.Internal.Present
import Net.DNSBase.Internal.Util

-- | The Client Subnet EDNS option
-- ([RFC 7871 section 6](https://tools.ietf.org/html/rfc7871#section-6))
-- — three fields: a source prefix length, a scope prefix length,
-- and a (possibly truncated) IP address whose 16-bit @FAMILY@
-- field is implicit in the constructor's 'IP' value (@1@ for
-- 'IPv4', @2@ for 'IPv6').
--
-- >            +0 (MSB)                            +1 (LSB)
-- >  +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+
-- >  |                          OPTION-CODE                          |
-- >  +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+
-- >  |                         OPTION-LENGTH                         |
-- >  +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+
-- >  |                            FAMILY                             |
-- >  +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+
-- >  |     SOURCE PREFIX-LENGTH      |     SCOPE PREFIX-LENGTH       |
-- >  +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+
-- >  |                           ADDRESS...                          /
-- >  +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+
--
-- The address is masked and truncated to the source prefix length
-- on encode and zero-padded on decode.  A @FAMILY@ value other
-- than 1 or 2 fails the decoder (a future revision may decode
-- such values as an opaque option instead).
data O_ecs = O_ECS
    Word8  -- ^ @SOURCE PREFIX-LENGTH@, in bits.
    Word8  -- ^ @SCOPE PREFIX-LENGTH@, in bits.
    IP     -- ^ Client address; its constructor ('IPv4' or 'IPv6')
           --   determines the @FAMILY@ written on the wire.
    deriving (Eq, Show)

-- | Presentation form: the source prefix length, then the scope prefix
-- length, then the address, in that order.  The last two fields are
-- rendered with 'presentSp' (presumably space-separated — confirm
-- against "Net.DNSBase.Internal.Present").
instance Presentable O_ecs where
    present (O_ECS srcbits scopebits ip) =
        present srcbits
        . presentSp scopebits
        . presentSp ip

instance KnownEdnsOption O_ecs where
    optNum _ = ECS
    {-# INLINE optNum #-}

    -- Wire form: FAMILY (1 = IPv4, 2 = IPv6), SOURCE PREFIX-LENGTH,
    -- SCOPE PREFIX-LENGTH, then only as many address octets as are
    -- needed to hold the source prefix, with the unused low-order bits
    -- of the final octet cleared.  A source prefix length longer than
    -- the address family allows (32 or 128 bits) fails with
    -- 'CantEncode'.
    --
    -- In both branches @q@ is the number of address octets to emit
    -- (the prefix length rounded up to whole octets) and @r@ is
    -- @(srcbits + 7) `rem` 8@, from which 'encWord' derives the mask
    -- for the final octet.
    optEncode (O_ECS srcbits scopebits ip) = case ip of
        IPv4 ip4 -> do
            -- XXX: More precise error?
            -- (No lower-bound test: a 'Word8' is never negative.)
            when (srcbits > 32) $
                failWith CantEncode
            let (!q, !r) = (srcbits + 7) `quotRem` 8
                !w = fromIPv4w ip4
            putSizedBuilder $
                mbWord16 1
                <> mbWord8 srcbits
                <> mbWord8 scopebits
                <> encWord w q r
        IPv6 ip6 -> do
            -- XXX: More precise error?
            -- (No lower-bound test: a 'Word8' is never negative.)
            when (srcbits > 128) $
                failWith CantEncode
            let (!q, !r) = (srcbits + 7) `quotRem` 8
                (!w0, !w1, !w2, !w3) = fromIPv6w ip6
                -- Octets still owed once the first @k@ octets of the
                -- address have been accounted for.  The subtraction
                -- must saturate at zero: a plain @q - k@ wraps around
                -- in 'Word8' whenever @q < k@, which 'encWord' would
                -- read as "emit this whole word, unmasked", leaking
                -- address octets beyond the source prefix.
                remaining :: Word8 -> Word8
                remaining k
                    | q > k     = q - k
                    | otherwise = 0
            putSizedBuilder $
                mbWord16 2
                <> mbWord8 srcbits
                <> mbWord8 scopebits
                <> encWord w0 q r
                <> encWord w1 (remaining 4) r
                <> encWord w2 (remaining 8) r
                <> encWord w3 (remaining 12) r

    -- The extension value is ignored; decoding depends only on the
    -- OPTION-LENGTH passed through to 'getECS'.
    optDecode _ _ = getECS

-- | Emit the leading octets of one 32-bit word of an address, most
-- significant octet first.
--
-- * @w@ is the address word.
-- * @q@ is the number of address octets still to be emitted, counted
--   from the start of this word.  Values of four or more emit the whole
--   word, zero emits nothing.
-- * @r@ is @(prefix length + 7) `rem` 8@ as computed by the callers, so
--   it is expected to lie in @0..7@.
--
-- When the address ends inside this word (@q <= 4@), the @7 - r@
-- low-order bits of the last emitted octet are cleared; when more
-- octets follow in later words (@q > 4@) the word is emitted unmasked.
encWord :: Word32 -> Word8 -> Word8 -> SizedBuilder
encWord !w !q !r = case min 4 q of
    4 -> mbWord32 (w .&. mask)
    3 -> (mbWord16 . fromIntegral) (w `unsafeShiftR` 16) <>
         (mbWord8  . fromIntegral) ((w `unsafeShiftR` 8) .&. mask)
    2 -> (mbWord16 . fromIntegral) ((w `unsafeShiftR` 16) .&. mask)
    1 -> (mbWord8  . fromIntegral) ((w `unsafeShiftR` 24) .&. mask)
    _ -> mempty
  where
    -- All ones except for the @7 - r@ lowest bits.  It is always applied
    -- after the final octet has been shifted down to the low end, so
    -- only that octet's trailing bits are affected.
    mask :: Word32
    mask | q <= 4
         , s <- fromEnum $ 7 - r
         , s > 0     = (0xffff_ffff `unsafeShiftR` s) `unsafeShiftL` s
         | otherwise = 0xffff_ffff

-- | Decode the body of an EDNS Client Subnet (ECS) option.
--
-- The fixed part of the option (the 16-bit FAMILY followed by the 8-bit
-- SOURCE and SCOPE prefix lengths) takes four octets, so the remaining
-- @OPTION-LENGTH - 4@ octets are passed to the address reader as the
-- length to which the address has been truncated.
--
-- Values of the FAMILY field other than 1 (IPv4) or 2 (IPv6) are rejected
-- and cause the decoder to fail.
getECS :: Int -- ^ OPTION-LENGTH field
       -> SGet EdnsOption
getECS n = do
    ecs_family <- get16
    ecs_source <- get8
    ecs_scope  <- get8
    let -- Octets left for the address once the fixed fields are read.
        addrLen = n - 4
        -- Wrap a decoded address together with the prefix lengths.
        ecs = EdnsOption . O_ECS ecs_source ecs_scope
    case ecs_family of
        1 -> ecs . IPv4 <$> getIPv4Net addrLen
        2 -> ecs . IPv6 <$> getIPv6Net addrLen
        f -> failSGet $ "unsupported ECS family " ++ show f
        -- XXX : consider using alternate constructor instead of failure