-- |
-- Module      : Net.DNSBase.Internal.NameComp
-- Description : Mutable tree of DNS labels and message offsets for name compression
-- Copyright   : (c) Viktor Dukhovni, 2026
-- License     : BSD-3-Clause
-- Maintainer  : ietf-dane@dukhovni.org
-- Stability   : unstable
module Net.DNSBase.Internal.NameComp
    ( NCTree
    , empty
    , insert
    , lookup
    ) where

import Prelude hiding (lookup)
import Control.Monad.ST as ST
import qualified Data.ByteString as B
import qualified Data.HashTable.ST.Basic as LH
import qualified Data.HashTable.Class as H

-- | A domain name as a list of label byte strings.  'lookup' documents its
-- input as the reversed list of wire-form labels; 'insert' presumably expects
-- the same order (confirm against callers).
type Path = [B.ByteString]
-- | Mutable 'ST' hash table keyed by label.
type Map s = LH.HashTable s B.ByteString
-- | Tree node: a table of child nodes keyed by label, together with an 'Int'
-- value (a message offset, see 'insert' and 'lookup').
data NCTree s = NCTree (Map s (NCTree s)) Int

-- | Create a node with no children that carries the given value.  Used both
-- for the root of a tree and, by 'insert', for each newly added label node.
empty :: Int -> ST.ST s (NCTree s)
empty n = flip NCTree n <$> H.new

-- | Insert a domain with the given labels, with the last label ending at the
-- given offset (if stored uncompressed).  The caller MUST NOT insert any paths
-- whose tail lies beyond the first 16k of the DNS message.  That is, the end
-- offset must not exceed @0x3fff@.
--
-- Nodes that already exist keep the value they were created with; only the
-- missing nodes along the path are added.  An empty path is a no-op.
insert :: Path -> Int -> NCTree s -> ST.ST s ()
insert = go
  where
    -- Descend one label: create or update the child keyed by the head label.
    -- The child's start offset is the current end offset less the label's
    -- payload size and its one-byte length prefix.
    go :: Path -> Int -> NCTree s -> ST.ST s ()
    go (!l:ls) !end !(NCTree m _) =
        H.mutateST m l $ alter ls $! end - B.length l - 1
    go _ _ _ = pure ()

    -- Alter (create or update) the node with given index and start position
    alter :: Path -> Int -> Maybe (NCTree s) -> ST.ST s (Maybe (NCTree s), ())
    alter !ls !start !old = case old of
        -- At existing intermediate nodes recurse to store the rest of the path
        Just  n | null ls   -> pure $ node n
                | otherwise -> node n <$ go ls start n
        -- In new intermediate nodes store the tip offset + distance from tip
        Nothing | null ls   -> node <$> empty start
                | otherwise -> do
                    e <- empty start
                    node e <$ go ls start e
      where
        -- Result shape expected by 'H.mutateST': keep/store the node, no
        -- auxiliary result.
        node n = (Just n, ())

-- | Return the length of the longest stored path prefix (domain suffix) and
-- the corresponding offset for the input path (reversed list of wire-form
-- labels).  The length counts both the length byte (1) and the payload size
-- of each matched label, not including the terminal NUL label.
--
-- When no label matches, the result is a length of @0@ paired with the value
-- stored in the node passed in.
lookup :: Path -> (NCTree s) -> ST.ST s (Int, Int)
lookup labels root = go labels root 0
  where
    -- Walk down the tree while labels keep matching, accumulating the wire
    -- length of the matched labels; stop at the first miss or when the path
    -- is exhausted and report the value of the deepest node reached.
    go :: Path -> NCTree s -> Int -> ST.ST s (Int, Int)
    go (!l:ls) !(NCTree m off) !slen = do
        mn <- H.lookup m l
        case mn of
            Just n  -> go ls n $! slen + B.length l + 1
            _       -> pure (slen, off)
    go _ (NCTree _ off) slen = pure (slen, off)