-- |
-- Module      : Net.DNSBase.Internal.Util
-- Description : Common re-exports and small shared internal helpers
-- Copyright   : (c) Viktor Dukhovni, 2026
-- License     : BSD-3-Clause
-- Maintainer  : ietf-dane@dukhovni.org
-- Stability   : unstable
module Net.DNSBase.Internal.Util
    ( (&), (.=), (<.>), (<$.>)
    , bool, cond
    , compose4
    , ByteArray(..), baToShortByteString, modifyArray
    , sbsToByteArray, sbsToMutableByteArray
    , Down(..), comparing
    , (.|.), (.&.), clearBit, countLeadingZeros, complement, setBit
    , shiftL, shiftR, testBit, unsafeShiftL, unsafeShiftR
    , (<|>), (>=>), forM, forM_, guard, join, mzero, replicateM, unless, void, when
    , lift, ExceptT(ExceptT), throwE, catchE, runExceptT, withExceptT
    , ByteString, Builder, ShortByteString(..), Text
    , Coercible, coerce
    , Int8, Int16, Int32, Int64
    , Word8, Word16, Word32, Word64, word16be, word32be, word64be, toBE
    , IP(..), IPv4, IPv6, fromIPv4w, fromIPv6b, fromIPv6w, toIPv4w, toIPv6b, toIPv6w
    , All(..), Sum(..)
    , catMaybes, fromMaybe, isJust, isNothing, listToMaybe
    , NonEmpty(..)
    , shows', showsP
    , Type, Typeable, (:~:)(..), Proxy(..), cast, teq
    , allocaBytesAligned, castPtr, copyBytes, byteSwap32
    , fillBytes, minusPtr, peek, peekElemOff, plusForeignPtr
    , unsafePerformFPIO
    ) where

import qualified Data.Primitive.ByteArray as A
import qualified Data.ByteString.Short as SB
import Control.Applicative ((<|>))
import Control.Monad ( (>=>), forM, forM_, guard, join, mzero, replicateM )
import Control.Monad ( unless, void, when )
import Control.Monad.ST (ST)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Except (ExceptT(ExceptT), throwE, catchE, runExceptT, withExceptT)
import Data.Array.Byte (ByteArray(..), MutableByteArray(..))
import Data.Bits ((.|.), (.&.), clearBit, countLeadingZeros, complement)
import Data.Bits (setBit, shiftL, shiftR, testBit, unsafeShiftL, unsafeShiftR)
import Data.Bool (bool)
import Data.ByteString (ByteString)
import Data.ByteString.Builder (Builder)
import Data.ByteString.Internal (ByteString(..), accursedUnutterablePerformIO)
import Data.ByteString.Short (ShortByteString(SBS))
import Data.Coerce (Coercible, coerce)
import Data.Function ((&))
import Data.IP (IP(..), IPv4, IPv6)
import Data.IP (fromIPv4w, fromIPv6b, fromIPv6w, toIPv4w, toIPv6b, toIPv6w)
import Data.Int (Int64, Int32, Int16, Int8)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty(..))
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, listToMaybe)
import Data.Monoid (All(..), Sum(..))
import Data.Ord (Down(..), comparing)
import Data.Proxy (Proxy(..))
import Data.Text (Text)
import Data.Type.Equality ((:~:)(..), testEquality)
import Data.Typeable (Typeable, cast)
import Data.Word (Word8, Word16, Word32, Word64, byteSwap16, byteSwap32, byteSwap64)
import Foreign (ForeignPtr, Ptr, allocaBytesAligned, castPtr, copyBytes)
import Foreign (fillBytes, minusPtr, peek, peekElemOff, plusForeignPtr)
import GHC.ByteOrder (ByteOrder(..), targetByteOrder)
import GHC.ForeignPtr (unsafeWithForeignPtr)
import Type.Reflection (TypeRep, pattern TypeRep)

-- | Predicate comparing a projection of its argument with a fixed value:
-- @(f .= x) y == (f y == x)@.
--
-- The comparison value is forced (to WHNF) when the predicate itself is
-- evaluated.  The fixity is @infix 9@, so @length .= 3 $ xs@ needs no
-- parentheses, while combining it with '(.)' does.
(.=) :: Eq b => (a -> b) -> b -> (a -> Bool)
f .= !x = (== x) . f
{-# INLINE (.=) #-}
infix 9 .=

-- | Map over a functor after composition: @(f <.> g) x == fmap f (g x)@.
--
-- Fixity is @infixr 8@, just below that of '(.)' (@infixr 9@), so
-- @f . g <.> h@ parses as @(f . g) <.> h@, and chains such as
-- @f <.> g <.> h@ associate to the right.
(<.>) :: Functor m => (b -> c) -> (a -> m b) -> a -> m c
f <.> g = fmap f . g
{-# INLINE (<.>) #-}
infixr 8 <.>

-- | Right-associative synonym for 'fmap' with reduced priority
-- (@infixr 2@, versus @infixl 4@ for '<$>'), so
-- @f <$.> g <$.> m@ parses as @f <$.> (g <$.> m)@.
(<$.>) :: Functor m => (a -> b) -> m a -> m b
(<$.>) = fmap
{-# INLINE (<$.>) #-}
infixr 2 <$.>

-- | Post-compose a unary function onto a four-argument function:
-- @(f \`compose4\` g) a b c d == f (g a b c d)@.
compose4 :: (e -> f) -> (a -> b -> c -> d -> e) -> (a -> b -> c -> d -> f)
f `compose4` g = \a b c d -> f $ g a b c d
{-# INLINE compose4 #-}

-- | Choose between two functions by a predicate on the argument:
-- @cond p f g x@ is @f x@ when @p x@ holds, and @g x@ otherwise.
cond :: (a -> Bool) -> (a -> b) -> (a -> b) -> (a -> b)
cond p f g = \x -> bool g f (p x) x
{-# INLINE cond #-}

-- | Precedence of function application, for use with 'showsPrec'.
app_prec :: Int
app_prec = 10

-- | Show a constructor or function argument, i.e. at precedence
-- @app_prec + 1@, so that (by the usual 'Show' conventions) non-atomic
-- values get parenthesised.
shows' :: Show a => a -> ShowS
shows' = showsPrec (app_prec + 1)

-- | Show a constructor with arguments, wrapping the output in parentheses
-- when the context precedence exceeds 'app_prec'.  For example:
--
-- > showsPrec d (Con x) = showsP d $ showString "Con " . shows' x
showsP :: Int -> ShowS -> ShowS
showsP = showParen . (> app_prec)

-- | Convert between host and big-endian (network) byte order, given the
-- byte-swap function for the type (e.g. 'byteSwap32').
--
-- On little-endian hosts the swap is applied, and on big-endian hosts the
-- value is returned unchanged.  Since byte swapping is its own inverse,
-- the same call converts in either direction.  The value is forced to WHNF
-- in both cases.
toBE :: (a -> a) -> a -> a
toBE swap !x =
    case targetByteOrder of
        LittleEndian -> swap x
        BigEndian    -> x
{-# INLINE toBE #-}

-- | Run an 'IO' action on the pointer underlying a 'ForeignPtr' and return
-- its result as a pure value.
--
-- Extremely unsafe, uses 'accursedUnutterablePerformIO' from
-- "Data.ByteString.Internal" and comes with all the associated caveats;
-- its documentation warns that it can alias mutable buffers, so keep the
-- action to reading the pointed-to memory.  The pointer is obtained with
-- 'unsafeWithForeignPtr', whose documentation additionally requires that
-- the action not diverge (e.g. loop forever or throw an exception).
unsafePerformFPIO :: ForeignPtr a -> (Ptr a -> IO b) -> b
unsafePerformFPIO fp = accursedUnutterablePerformIO . unsafeWithForeignPtr fp
{-# INLINE unsafePerformFPIO #-}

-- | Decode a 2-byte big-endian (network byte order) 'Word16'.
--
-- Caller must ensure the input is exactly 2-bytes long; any other length
-- is a fatal 'error'.
--
-- The bytes are copied into a freshly allocated 2-byte-aligned buffer, so
-- the 'peek' is aligned regardless of where the input starts.  The result
-- is then converted from big-endian to host order with 'toBE'.
--
-- NOTE(review): the scratch buffer is allocated inside 'unsafePerformFPIO'
-- (i.e. under 'accursedUnutterablePerformIO', which the bytestring docs
-- warn can alias mutable buffers); building the result byte-by-byte with
-- 'peekElemOff' would avoid that allocation -- worth considering.
word16be :: ByteString -> Word16
word16be (BS fp 2) = unsafePerformFPIO fp $ \ptr ->
    allocaBytesAligned 2 2 $ \buf -> do
        copyBytes buf ptr 2
        w16 <- peek $ castPtr buf
        pure $ toBE byteSwap16 w16
word16be _ = error "word16be invalid input"
{-# INLINE word16be #-}

-- | Decode a 4-byte big-endian (network byte order) 'Word32'.
--
-- Caller must ensure the input is exactly 4-bytes long; any other length
-- is a fatal 'error'.
--
-- The bytes are copied into a freshly allocated 4-byte-aligned buffer, so
-- the 'peek' is aligned regardless of where the input starts.  The result
-- is then converted from big-endian to host order with 'toBE'.
--
-- NOTE(review): the scratch buffer is allocated inside 'unsafePerformFPIO'
-- (i.e. under 'accursedUnutterablePerformIO', which the bytestring docs
-- warn can alias mutable buffers); building the result byte-by-byte with
-- 'peekElemOff' would avoid that allocation -- worth considering.
word32be :: ByteString -> Word32
word32be (BS fp 4) = unsafePerformFPIO fp $ \ptr ->
    allocaBytesAligned 4 4 $ \buf -> do
        copyBytes buf ptr 4
        w32 <- peek $ castPtr buf
        pure $ toBE byteSwap32 w32
word32be _ = error "word32be invalid input"
{-# INLINE word32be #-}

-- | Decode an 8-byte big-endian (network byte order) 'Word64'.
--
-- Caller must ensure the input is exactly 8-bytes long; any other length
-- is a fatal 'error'.
--
-- The bytes are copied into a freshly allocated 8-byte-aligned buffer, so
-- the 'peek' is aligned regardless of where the input starts.  The result
-- is then converted from big-endian to host order with 'toBE'.
--
-- NOTE(review): the scratch buffer is allocated inside 'unsafePerformFPIO'
-- (i.e. under 'accursedUnutterablePerformIO', which the bytestring docs
-- warn can alias mutable buffers); building the result byte-by-byte with
-- 'peekElemOff' would avoid that allocation -- worth considering.
word64be :: ByteString -> Word64
word64be (BS fp 8) = unsafePerformFPIO fp $ \ptr ->
    allocaBytesAligned 8 8 $ \buf -> do
        copyBytes buf ptr 8
        w64 <- peek $ castPtr buf
        pure $ toBE byteSwap64 w64
word64be _ = error "word64be invalid input"
{-# INLINE word64be #-}

----- Type equality

-- | Compare two 'Typeable' types, passed as required (visible) type
-- arguments, returning a proof of equality when they are the same type:
-- @teq Int Int == Just Refl@ and @teq Int Bool == Nothing@.
teq :: forall a -> forall b -> (Typeable a, Typeable b) => Maybe (a :~: b)
teq a b = testEquality (rep a) (rep b)
  where
    -- Runtime type representation of a required type argument.
    rep :: forall c -> Typeable c => TypeRep c
    rep _ = TypeRep
{-# INLINE teq #-}

----- Wrappers around "primitive" API

-- | Re-wrap a 'ByteArray' as a 'ShortByteString'.  The underlying
-- @ByteArray#@ is shared, not copied.
baToShortByteString :: ByteArray -> ShortByteString
baToShortByteString (ByteArray ba) = SBS ba

-- | Replace the byte at index @i@ of a mutable byte array with the result of
-- applying @f@ to it, in place.
--
-- No bounds check is done here.  Presumably the underlying
-- 'A.readByteArray' / 'A.writeByteArray' do not check either, so callers
-- must keep @i@ within the array -- TODO confirm.
modifyArray :: MutableByteArray s -> Int -> (Word8 -> Word8) -> ST s ()
modifyArray marr i f = A.readByteArray marr i >>= A.writeByteArray marr i . f

-- | Re-wrap a 'ShortByteString' as a 'ByteArray'.  The underlying
-- @ByteArray#@ is shared, not copied.
sbsToByteArray :: ShortByteString -> ByteArray
sbsToByteArray (SBS ba) = ByteArray ba

-- | Copy the full contents of a 'ShortByteString' into a new mutable byte
-- array of the same length.  'A.thawByteArray' copies the requested slice
-- (here the whole array), so the input is left untouched.
sbsToMutableByteArray :: ShortByteString -> ST s (MutableByteArray s)
sbsToMutableByteArray sb@(SBS ba) =
    A.thawByteArray (ByteArray ba) 0 (SB.length sb)