-- |
-- Module      : Net.DNSBase.Decode.Internal.State
-- Description : Decoder state monad and wire-format primitives
-- Copyright   : (c) IIJ Innovation Institute Inc., 2009
--               (c) Viktor Dukhovni, 2020-2026
-- License     : BSD-3-Clause
-- Maintainer  : ietf-dane@dukhovni.org
-- Stability   : unstable
{-# LANGUAGE RecordWildCards #-}

module Net.DNSBase.Decode.Internal.State
    (
    -- * DNS message element parser
      SGet
    -- * Internal state accessors
    , getPosition
    , getPacket
    , getChrono
    , getNameComp
    -- ** Deduplication support
    , getLastOwner
    , getLastCname
    , setLastOwner
    , setLastCname
    -- ** Setting a non-default error context
    , setDecodeSection
    , setDecodeTriple
    , setDecodeSource
    , local
    -- * Generic low-level decoders
    , get8
    , get16
    , get32
    , get64
    , getInt8
    , getInt16
    -- * DNS-specific low-level decoders
    , getIPv4
    , getIPv4Net
    , getIPv6
    , getIPv6Net
    , getDnsTime
    -- * Octet-string decoders
    , skipNBytes
    , getNBytes
    , getShortByteString
    , getShortNByteString
    , getShortByteStringLen8
    , getShortByteStringLen16
    , getUtf8Text
    , getUtf8TextLen8
    , getUtf8TextLen16
    -- * Sequence decoders
    , getVarWidthSequence
    , getFixedWidthSequence
    -- * Decoder sandboxing
    , seekSGet
    , fitSGet
    -- * Decoder failure
    , failSGet
    -- * Decoder driver
    , decodeAtWith
    ) where

import qualified Data.ByteString as B
import qualified Data.ByteString.Short as SB
import qualified Data.ByteString.Unsafe as B
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Data.ByteString.Internal (ByteString(..))

import Net.DNSBase.Internal.Domain
import Net.DNSBase.Internal.Error
import Net.DNSBase.Internal.Util

-----------

-- | Read-only environment of the 'SGet' monad.  It is fixed when the
-- decoder is started and is never modified by the primitives in this
-- module, which only read it.
data SGetEnv = SGetEnv
    { psPacket   :: ByteString
      -- ^ The whole input buffer, as returned by 'getPacket'.
    , psChrono   :: Int64
      -- ^ Epoch-relative reference time, as returned by 'getChrono'.
    , psNameComp :: Bool
      -- ^ Whether name (de)compression applies to the input buffer.
    , psSection  :: DnsSection
      -- ^ Message section reported by 'failSGet'.
    , psTriple   :: Maybe DnsTriple
      -- ^ RR triple reported by 'failSGet', when one has been set.
    , psSource   :: Maybe MessageSource
      -- ^ Message source reported by 'failSGet', when one has been set.
    }

-- | Mutable state threaded through the 'SGet' monad.
data SGetState = SGetState
    { psOffset    :: Int
      -- ^ Current read position, relative to the start of the buffer.
    , psLength    :: Int
      -- ^ Number of bytes that may still be consumed from 'psOffset'.
    , psLastOwner :: Domain
      -- ^ Domain most recently stored with 'setLastOwner'.
    , psLastCname :: Domain
      -- ^ Domain most recently stored with 'setLastCname'.
    }

-- | Terminate decoding with a 'DecodeError'.
--
-- The error carries the supplied diagnostic text together with a
-- 'DecodeContext' assembled from the reader environment: the message
-- section, the RR triple (if any) and the message source (if any).
failSGet :: String -> SGet a
failSGet msg = do
    env <- ask
    -- Bind the context fields under the names the 'DecodeContext'
    -- record wildcard picks up.
    let decodeSection = psSection env
        decodeTriple  = psTriple env
        decodeSource  = psSource env
    throw $ DecodeError DecodeContext {..} msg

-------------

-- | Consume @n@ bytes from the buffer and return them as a
-- 'B.ByteString' that shares the packet's memory (no copy is made).
--
-- * @n == 0@ yields the empty string without touching the state.
-- * @n < 0@ fails with a "negative bytecount" diagnostic.
-- * @n@ larger than the remaining length fails with an over-run
--   diagnostic.
getNByteString :: Int -> SGet ByteString
getNByteString n = case compare n 0 of
    EQ -> pure B.empty
    LT -> failSGet "negative bytecount requested"
    GT -> do
        st <- get
        when (psLength st < n) $
            failSGet "requested bytecount exceeds available"
        -- Advance past the bytes about to be returned.
        modify' $ \cur -> cur { psOffset = psOffset st + n
                              , psLength = psLength st - n }
        BS base _ <- asks psPacket
        -- Slice the packet at the pre-advance offset.
        pure $! BS (base `plusForeignPtr` psOffset st) n
{-# INLINE getNByteString #-}

-- | Advance the read position by @n@ bytes, discarding them.
--
-- Fails when @n@ is negative or exceeds the number of bytes still
-- available.  Skipping zero bytes leaves the state untouched.
skipNBytes :: Int -> SGet ()
skipNBytes n
    | n < 0     = failSGet "negative skip bytecount requested"
    | otherwise = do
        st <- get
        when (psLength st < n) $
            failSGet "requested skip bytecount exceeds available"
        when (n > 0) $
            modify' $ \cur -> cur { psOffset = psOffset st + n
                                  , psLength = psLength st - n }
{-# INLINE skipNBytes #-}

-- | The current read offset, counted from the start of the internal
-- buffer.
getPosition :: SGet Int
getPosition = psOffset <$> get
{-# INLINE getPosition #-}

-- | The complete internal buffer, regardless of the current position.
getPacket :: SGet ByteString
getPacket = psPacket <$> ask
{-# INLINE getPacket #-}

-- | The epoch-relative reference time that was handed to
-- 'decodeAtWith'.
getChrono :: SGet Int64
getChrono = psChrono <$> ask
{-# INLINE getChrono #-}

-- | Whether name (de)compression applies to the input buffer.
--
-- This is generally 'True' for complete DNS messages and 'False' for
-- data blobs that were encoded in isolation.
getNameComp :: SGet Bool
getNameComp = psNameComp <$> ask
{-# INLINE getNameComp #-}

-- | Return the domain most recently recorded with 'setLastOwner'.
getLastOwner :: SGet Domain
getLastOwner = psLastOwner <$> get
{-# INLINE getLastOwner #-}

-- | Return the domain most recently recorded with 'setLastCname'.
getLastCname :: SGet Domain
getLastCname = psLastCname <$> get
{-# INLINE getLastCname #-}

-- | Record the given domain as the last owner name and hand the same
-- domain back to the caller.
setLastOwner :: Domain -> SGet Domain
setLastOwner d = do
    modify' $ \st -> st { psLastOwner = d }
    pure d
{-# INLINE setLastOwner #-}

-- | Record the given domain as the last CNAME target and hand the same
-- domain back to the caller.
setLastCname :: Domain -> SGet Domain
setLastCname d = do
    modify' $ \st -> st { psLastCname = d }
    pure d
{-# INLINE setLastCname #-}

-- | Environment modifier that replaces the message section reported
-- in decode errors.
setDecodeSection :: DnsSection -> SGetEnv -> SGetEnv
setDecodeSection section = \env -> env { psSection = section }

-- | Environment modifier that records the current RRSet name, type and
-- class for inclusion in decode errors.
setDecodeTriple :: DnsTriple -> SGetEnv -> SGetEnv
setDecodeTriple triple = \env -> env { psTriple = Just triple }

-- | Environment modifier that records the message source for inclusion
-- in decode errors.
setDecodeSource :: MessageSource -> SGetEnv -> SGetEnv
setDecodeSource source = \env -> env { psSource = Just source }

--------------------------------

-- | Consume a single octet and return it as a 'Word8'.
--
-- The unchecked index is only evaluated after 'skipNBytes' has
-- confirmed that at least one byte is available at the offset.
get8 :: SGet Word8
get8 = do
    pkt <- asks psPacket
    off <- gets psOffset
    skipNBytes 1
    pure (B.unsafeIndex pkt off)
{-# INLINE get8 #-}

-- | Consume two octets and decode them as a big-endian 'Word16'.
get16 :: SGet Word16
get16 = fmap word16be (getNByteString 2)
{-# INLINE get16 #-}

-- | Consume four octets and decode them as a big-endian 'Word32'.
get32 :: SGet Word32
get32 = fmap word32be (getNByteString 4)
{-# INLINE get32 #-}

-- | Consume eight octets and decode them as a big-endian 'Word64'.
get64 :: SGet Word64
get64 = fmap word64be (getNByteString 8)
{-# INLINE get64 #-}

-- | Consume a single octet and widen it to an 'Int'.
getInt8 :: SGet Int
getInt8 = fmap fromIntegral get8
{-# INLINE getInt8 #-}

-- | Consume two octets in network byte order and widen the resulting
-- 16-bit value to an 'Int'.
getInt16 :: SGet Int
getInt16 = fmap fromIntegral get16
{-# INLINE getInt16 #-}

--  Not implemented, risks sign overflow on 32-bit systems.
--
--  -- | Consumes four octets and returns them as an 'Int'
--  -- computed using network byte order
--  getInt32 :: SGet Int
--  getInt32 = fromIntegral <$> get32
--  {-# INLINE getInt32 #-}

----

-- | Consume four octets and return them as an 'IPv4' address.
getIPv4 :: SGet IPv4
getIPv4 = fmap toIPv4w get32
{-# INLINE getIPv4 #-}

-- | Consume sixteen octets, read as four consecutive big-endian 32-bit
-- words, and return them as an 'IPv6' address.
getIPv6 :: SGet IPv6
getIPv6 = do
    w0 <- get32
    w1 <- get32
    w2 <- get32
    w3 <- get32
    pure $ toIPv6w (w0, w1, w2, w3)
{-# INLINE getIPv6 #-}

-- | Consume @n@ octets (@0 <= n <= 4@) and return them as an 'IPv4'
-- address whose missing trailing octets are zero.
--
-- Any other @n@ fails with "invalid IPv4 prefix length".
getIPv4Net :: Int -> SGet IPv4
getIPv4Net n
    | n < 0 || n > 4 = failSGet "invalid IPv4 prefix length"
    | otherwise = do
        BS fp _ <- getNByteString n
        -- The action only reads the @n@ bytes just consumed and writes
        -- to a private scratch buffer, so its result depends solely on
        -- the input bytes.
        pure $! unsafePerformFPIO fp $ \src ->
            allocaBytesAligned 4 4 $ \buf -> do
                -- Zero the scratch word, then overlay the @n@ leading
                -- octets read from the packet.
                fillBytes buf 0 4
                copyBytes buf src n
                -- NOTE(review): 'toBE' presumably applies the swap only
                -- on little-endian hosts; confirm against its definition.
                toIPv4w . toBE byteSwap32 <$> peek (castPtr buf)

-- | Consume @n@ octets (@0 <= n <= 16@) and return them as an 'IPv6'
-- address whose missing trailing octets are zero.
--
-- Any other @n@ fails with "invalid IPv6 prefix length".
getIPv6Net :: Int -> SGet IPv6
getIPv6Net n
    | n < 0 || n > 16 = failSGet "invalid IPv6 prefix length"
    | otherwise = do
        BS fp _ <- getNByteString n
        -- The action only reads the @n@ bytes just consumed and writes
        -- to a private scratch buffer, so its result depends solely on
        -- the input bytes.
        pure $! unsafePerformFPIO fp $ \src ->
            allocaBytesAligned 16 4 $ \buf -> do
                -- Zero the 16-byte scratch area, then overlay the @n@
                -- leading octets read from the packet.
                fillBytes buf 0 16
                copyBytes buf src n
                -- Read the i-th 32-bit word of the scratch buffer.
                -- NOTE(review): 'toBE' presumably applies the swap only
                -- on little-endian hosts; confirm against its definition.
                let word i = toBE byteSwap32 <$> peekElemOff (castPtr buf) i
                w0 <- word 0
                w1 <- word 1
                w2 <- word 2
                w3 <- word 3
                pure $ toIPv6w (w0, w1, w2, w3)

-- | Consume a 32-bit circle-arithmetic DNS time and widen it to an
-- absolute 64-bit timestamp, choosing the value that lies within a
-- 31-bit band of the decoder's reference time ('getChrono').
getDnsTime :: SGet Int64
getDnsTime = do
    tdns <- get32
    dnsTime tdns <$> getChrono
  where
    dnsTime :: Word32 -- ^ DNS circle-arithmetic timestamp
            -> Int64  -- ^ reference epoch time
            -> Int64  -- ^ absolute DNS timestamp
    dnsTime tdns tnow
        -- More than half-way around the circle: the timestamp lies in
        -- the past relative to the reference time.
        | delta > 0x7FFF_FFFF = tnow - (0x1_0000_0000 - fromIntegral delta)
        | otherwise           = tnow + fromIntegral delta
      where
        -- Distance from the reference time, modulo 2^32.
        delta :: Word32
        delta = tdns - fromIntegral tnow
{-# INLINE getDnsTime #-}

----------------------------------------

-- | Consume @n@ bytes from the buffer and return them as a list of
-- octets.
getNBytes :: Int -> SGet [Word8]
getNBytes n = fmap B.unpack (getNByteString n)
{-# INLINE getNBytes #-}

-- | Decodes a sequence of values with a fixed wire-form byte-width.
--
-- The total length must be an exact multiple of the element width;
-- the element decoder is then run @len \`quot\` wdth@ times inside a
-- 'fitSGet' sandbox of exactly @len@ bytes, so the elements together
-- must consume precisely that many bytes.
--
-- A non-positive element width is reported as a decode failure rather
-- than being passed to 'quotRem', where a zero width would raise a
-- divide-by-zero exception outside the decoder's error channel.
getFixedWidthSequence :: Int    -- ^ Number of octets to encode one value
                      -> SGet a -- ^ Decoder for a single value
                      -> Int    -- ^ Total number of octets in the sequence
                      -> SGet [a]
getFixedWidthSequence wdth getOne len
    | wdth <= 0 =
        failSGet "non-positive sequence element size"
    | (cnt, 0) <- len `quotRem` wdth =
        fitSGet len $ replicateM cnt getOne
    | otherwise =
        failSGet "sequence length not multiple of element size"
{-# INLINE getFixedWidthSequence #-}

-- | Decodes a sequence of values with a variable wire-form byte-width.
--
-- The element decoder is run repeatedly inside a 'fitSGet' sandbox of
-- the given size until exactly that many bytes have been consumed.
--
-- An element decoder that succeeds without consuming any input would
-- otherwise make no progress and loop forever; that case is reported
-- as a decode failure.
getVarWidthSequence :: SGet a   -- ^ Decoder for a single value
                    -> Int      -- ^ Total number of octets in the sequence
                    -> SGet [a]
getVarWidthSequence getOne = \len -> fitSGet len (go len)
  where
    -- @go n@ decodes elements until the remaining byte budget @n@ is
    -- exhausted.
    go n
        | n > 0 = do
            pos0 <- getPosition
            x    <- getOne
            used <- subtract pos0 <$> getPosition
            -- Guard against non-termination on zero-width elements.
            when (used <= 0) $
                failSGet "sequence element consumed no input"
            (x :) <$> go (n - used)
        | n == 0    = pure []
        | otherwise = failSGet "last sequence element read past limit"
{-# INLINE getVarWidthSequence #-}

-- | Consume every remaining byte of the buffer and return them as a
-- 'SB.ShortByteString'.
getShortByteString :: SGet ShortByteString
getShortByteString = do
    remaining <- gets psLength
    SB.toShort <$> getNByteString remaining
{-# INLINE getShortByteString #-}

-- | Consume @n@ bytes from the buffer and return them as a
-- 'SB.ShortByteString'.
getShortNByteString :: Int -> SGet ShortByteString
getShortNByteString n = fmap SB.toShort (getNByteString n)
{-# INLINE getShortNByteString #-}

-- | Read a one-octet length prefix, then that many bytes as a
-- 'SB.ShortByteString'.
getShortByteStringLen8 :: SGet ShortByteString
getShortByteStringLen8 = do
    len <- getInt8
    getShortNByteString len
{-# INLINE getShortByteStringLen8 #-}

-- | Read a two-octet (network byte order) length prefix, then that
-- many bytes as a 'SB.ShortByteString'.
getShortByteStringLen16 :: SGet ShortByteString
getShortByteStringLen16 = do
    len <- getInt16
    getShortNByteString len
{-# INLINE getShortByteStringLen16 #-}

-- | Consume @len@ bytes and decode them as UTF-8 text.
--
-- Invalid UTF-8 aborts the decoder; the diagnostic is the 'show'n
-- decoding exception.
getUtf8Text :: Int -> SGet T.Text
getUtf8Text len = do
    raw <- getNByteString len
    either (failSGet . show) pure (T.decodeUtf8' raw)
{-# INLINE getUtf8Text #-}

-- | Read a one-octet length prefix, then decode that many bytes as
-- UTF-8 text.  Invalid UTF-8 aborts the decoder with the 'show'n
-- decoding exception as the diagnostic.
getUtf8TextLen8 :: SGet T.Text
getUtf8TextLen8 = do
    len <- getInt8
    raw <- getNByteString len
    either (failSGet . show) pure (T.decodeUtf8' raw)
{-# INLINE getUtf8TextLen8 #-}

-- | Read a two-octet (network byte order) length prefix, then decode
-- that many bytes as UTF-8 text.  Invalid UTF-8 aborts the decoder
-- with the 'show'n decoding exception as the diagnostic.
getUtf8TextLen16 :: SGet T.Text
getUtf8TextLen16 = do
    len <- getInt16
    raw <- getNByteString len
    either (failSGet . show) pure (T.decodeUtf8' raw)
{-# INLINE getUtf8TextLen16 #-}

-- | Run a parser starting at the given absolute offset into the packet,
-- with everything from that offset to the end of the packet available
-- to it.  The caller's own state (position, remaining length, and the
-- rest of the decoder state) is left unchanged; only the parser's
-- result or its error is propagated.
--
-- Fails if the offset lies beyond the end of the buffer.  Used
-- exclusively for decoding DNS message name compression.
seekSGet :: Word16 -> SGet a -> SGet a
seekSGet pos parser = do
    total <- B.length <$> getPacket
    when (off > total) $
        failSGet "seek attempt beyond end of buffer"
    env    <- ask
    -- Start from a copy of the caller's state, repositioned at the
    -- seek target; the caller's state itself is not modified.
    target <- gets $ \st -> st { psOffset = off
                               , psLength = total - off }
    case runSGet parser env target of
        Left err       -> throw err
        Right (ret, _) -> pure ret
  where
    off :: Int
    off = fromIntegral pos

-- | Runs a parser on an initial segment of the unread input.
-- Consumes exactly the specified number of bytes or fails.
--
-- The outer position is advanced by @len@ bytes up front.  The parser
-- then runs on a copy of the prior state whose remaining length is
-- clamped to @len@; of the parser's final state only the remaining
-- length is inspected, and it must be zero.
--
-- Fails when @len@ is negative, when fewer than @len@ bytes remain,
-- when the parser itself fails, or when the parser leaves part of the
-- segment unread.
fitSGet :: Int -> SGet a -> SGet a
fitSGet len parser
    | len < 0   = failSGet "negative sandbox buffer size"
    | otherwise = do
        outer <- get
        when (psLength outer < len) $
            failSGet "requested skip bytecount exceeds available"
        -- Step the caller past the sandboxed segment.
        when (len > 0) $
            modify' $ \cur -> cur { psOffset = psOffset outer + len
                                  , psLength = psLength outer - len }
        env <- ask
        -- Run the parser from the pre-advance position, limited to
        -- exactly @len@ bytes.
        case runSGet parser env outer { psLength = len } of
            Left err -> throw err
            Right (ret, inner)
                | psLength inner == 0 -> pure $! ret
                | otherwise ->
                    failSGet "element shorter than indicated size"
{-# INLINE fitSGet #-}

--------------------------

-- | Run a decoder over the given input with the supplied reference
-- time.
--
-- The decoder starts at offset 0 with the whole input available, the
-- error context set to 'DnsNonSection' with no RR triple or message
-- source, and both the last owner and last CNAME set to 'RootDomain'.
-- Only the decoder's result is returned; the final state is dropped.
decodeAtWith :: Int64       -- ^ Current absolute offset from epoch
             -> Bool        -- ^ Support name compression?
             -> SGet a      -- ^ Decoder to run
             -> ByteString  -- ^ Buffer to run decoder over
             -> Either DNSError a
decodeAtWith t nc parser inp = evalSGet parser env st
  where
    env = SGetEnv
        { psPacket   = inp
        , psChrono   = t
        , psNameComp = nc
        , psSection  = DnsNonSection
        , psTriple   = Nothing
        , psSource   = Nothing
        }
    st = SGetState
        { psOffset    = 0
        , psLength    = B.length inp
        , psLastOwner = RootDomain
        , psLastCname = RootDomain
        }

--------------------------

-- | Minimal Reader + State + Except monad.
--
-- Running a decoder takes the read-only environment and the current
-- state, and yields either a 'DNSError' or the result paired with the
-- updated state.
newtype SGet a = SGet
    { runSGet :: SGetEnv
              -> SGetState
              -> Either DNSError (a, SGetState)
    }

-- | Run a decoder from the given environment and initial state,
-- keeping only its result and discarding the final state.
evalSGet :: SGet a -> SGetEnv -> SGetState -> Either DNSError a
evalSGet m = \env st -> fmap fst (runSGet m env st)
{-# INLINE evalSGet #-}

-- | Maps over the decoder's result; the state and any error pass
-- through unchanged.
instance Functor SGet where
    fmap f m = SGet $ \env st ->
        case runSGet m env st of
            Left err       -> Left err
            Right (a, st') -> Right (f a, st')
    {-# INLINE fmap #-}

-- | Sequential composition of decoders.  Every method runs its left
-- operand first, feeds the resulting state to the right operand, and
-- gives both operands the same environment.  The first 'Left' aborts the
-- remainder, because the bodies run in the 'Either' monad.
instance Applicative SGet where
    -- Succeed with @a@ without reading the environment or changing state.
    pure a = SGet $ \_ s -> pure (a, s)
    {-# INLINE pure #-}

    -- Decode a function, then its argument, and apply one to the other.
    mf <*> ma = SGet $ \r s -> do
        (f, t) <- runSGet mf r s
        (a, u) <- runSGet ma r t
        pure (f a, u)
    {-# INLINE (<*>) #-}

    -- Decode two values in order and combine them with @f@.
    liftA2 f ma mb = SGet $ \r s -> do
        (a, t) <- runSGet ma r s
        (b, u) <- runSGet mb r t
        pure (f a b, u)
    {-# INLINE liftA2 #-}

    -- Run both decoders, keeping only the second result.
    ma *> mb = SGet $ \r s -> do
        (_, t) <- runSGet ma r s
        runSGet mb r t
    {-# INLINE (*>) #-}

    -- Run both decoders, keeping only the first result (but the state
    -- left by the second).
    ma <* mb = SGet $ \r s -> do
        (a, t) <- runSGet ma r s
        (_, u) <- runSGet mb r t
        pure (a, u)
    {-# INLINE (<*) #-}

-- | Dependent sequencing: run the first decoder, then run the decoder
-- chosen by its result, starting from the state the first one left
-- behind.  Both see the same environment; a 'Left' from the first
-- decoder skips the continuation.
instance Monad SGet where
    ma >>= f = SGet $ \r s -> do
        (a, t) <- runSGet ma r s
        runSGet (f a) r t
    {-# INLINE (>>=) #-}

-- | Return the whole reader environment, leaving the state unchanged.
ask :: SGet SGetEnv
ask = SGet $ \r s -> pure (r, s)
{-# INLINE ask #-}

-- | Return a projection of the reader environment, leaving the state
-- unchanged.
asks :: (SGetEnv -> a) -> SGet a
asks f = SGet $ \r s -> pure (f r, s)
{-# INLINE asks #-}

-- | Return the current decoder state, leaving it unchanged.  The
-- environment is not consulted.
get :: SGet SGetState
get = SGet $ \_ s -> pure (s, s)
{-# INLINE get #-}

-- | Return a projection of the current decoder state, leaving the state
-- unchanged.  The environment is not consulted.
gets :: (SGetState -> a) -> SGet a
gets f = SGet $ \_ s -> pure (f s, s)
{-# INLINE gets #-}

-- | Run a decoder under a modified environment.  The function @f@ is
-- applied to the current environment for the duration of @m@ only; the
-- state is threaded into and out of @m@ as usual.
local :: (SGetEnv -> SGetEnv) -> SGet a -> SGet a
local f m = SGet $ \r s -> runSGet m (f r) s
{-# INLINE local #-}

-- | Replace the decoder state with the result of applying @f@ to it.
-- The bang pattern forces the new state to weak head normal form before
-- the result is produced, so no thunk of pending updates is carried in
-- the state.
modify' :: (SGetState -> SGetState) -> SGet ()
modify' f = SGet $ \_ s -> let !t = f s in pure ((), t)
{-# INLINE modify' #-}

-- | Fail with the given error.  Neither the environment nor the state is
-- examined; the decoder's result is simply @'Left' e@.
throw :: DNSError -> SGet a
throw e = SGet $ \_ _ -> Left e
{-# INLINE throw #-}