{-# LANGUAGE OverloadedStrings #-}

{-|
Module      : Keter.RateLimiter.RequestUtils
Description : Utility functions for extracting data from WAI requests
Copyright   : (c) 2025 Oleksandr Zhabenko
License     : MIT
Maintainer  : oleksandr.zhabenko@yahoo.com
Stability   : stable
Portability : portable

Utility helpers for extracting /stable textual keys/ from a WAI
'Network.Wai.Request'.  They are primarily intended for use with
rate-limiting middleware (see the @keter-rate-limiting-plugin@ package) but are fully
generic and can be employed anywhere you need a deterministic identifier that
ties a request to its origin (IP address, path, user-agent, …).

The helpers follow these rules:

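1.  No re-parsing: values already present in the request record (e.g.
    @rawPathInfo@ or @requestMethod@) are only decoded to 'Data.Text.Text';
    they are never normalised, percent-decoded or otherwise rewritten.
2.  No reverse DNS or other network round-trips -- the functions only inspect
    the request record, so they are cheap (even the ones returning 'IO'
    perform no actual I/O).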
3.  Header names are handled case-insensitively via the
    'Data.CaseInsensitive.CI' type.

== Quick example

@
{-# LANGUAGE OverloadedStrings #-}
import Control.Monad.IO.Class (liftIO)
import Network.Wai
import Network.Wai.Handler.Warp (run)
import Network.HTTP.Types (status200)
import Keter.RateLimiter.RequestUtils (byIPAndPath)
import qualified Data.Text.IO as TIO

logKey :: Request -> IO ()
logKey req = do
  mk <- byIPAndPath req
  case mk of
    Nothing  -> TIO.putStrLn "cannot build key"
    Just key -> TIO.putStrLn ("request key = " <> key)

app :: Application
app req respond = liftIO (logKey req) >> respond (responseLBS status200 [] "OK")

main :: IO ()
main = run 8080 app
@

== Converting sockets to text

Functions 'ipv4ToString' and 'ipv6ToString' perform a /lossless/ conversion of
binary socket addresses to text.  IPv4 addresses use the usual dotted-decimal
form.  IPv6 addresses are deliberately /not/ compressed: you always get eight
colon-separated groups, each zero-padded to four lowercase hex digits.  This is
the full form, not the RFC 5952 canonical form.

-}

module Keter.RateLimiter.RequestUtils
  ( -- * Low-level helpers
    ipv4ToString
  , ipv6ToString

    -- * Basic request information
  , getClientIP
  , getRequestPath
  , getRequestMethod
  , getRequestHost
  , getRequestUserAgent

    -- * Composite key builders
  , byIP
  , byIPAndPath
  , byIPAndUserAgent
  , byHeaderAndIP
  ) where

import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.Text.Encoding.Error as TEE
import Network.Wai (Request)
import qualified Network.Wai as WAI
import Network.Socket
       ( SockAddr(..)
       , HostAddress
       , HostAddress6
       , hostAddressToTuple
       )
import Network.HTTP.Types.Header (HeaderName, hHost, hUserAgent)
import Data.Bits ((.&.), shiftR)
import Data.CaseInsensitive (mk)
import Numeric (showHex)

----------------------------------------------------------------------
-- Low-level helpers
----------------------------------------------------------------------

-- | Convert an IPv4 'HostAddress' to dotted-decimal 'Text'.
--
-- A 'HostAddress' holds the address bytes in network byte order, so its
-- numeric value depends on the platform (127.0.0.1 is @0x0100007f@ on
-- little-endian hosts).  The octets are extracted with 'hostAddressToTuple',
-- which undoes that, so the rendered text is the same on every architecture.
-- Build inputs with 'Network.Socket.tupleToHostAddress' rather than with a
-- numeric literal.
--
-- ==== __Example__
--
-- >>> import Network.Socket (tupleToHostAddress)
-- >>> ipv4ToString (tupleToHostAddress (127, 0, 0, 1))
-- "127.0.0.1"
ipv4ToString :: HostAddress -> Text
ipv4ToString addr = T.intercalate "." (map (T.pack . show) [o1, o2, o3, o4])
  where
    (o1, o2, o3, o4) = hostAddressToTuple addr

-- | Render an IPv6 'HostAddress6' as eight colon-separated 16-bit groups in
-- lowercase hexadecimal, each zero-padded to four digits.
--
-- The output is the full, uncompressed form.  Runs of zero groups are not
-- collapsed to @::@ and leading zeros are kept, so this is /not/ the RFC 5952
-- canonical text form.  It is lossless and always exactly 39 characters
-- long, which makes it a stable key.
--
-- ==== __Example__
--
-- >>> ipv6ToString (0,0,0,1)
-- "0000:0000:0000:0000:0000:0000:0000:0001"
ipv6ToString :: HostAddress6 -> Text
ipv6ToString (w1, w2, w3, w4) =
  T.intercalate ":" (map hextet (concatMap halves [w1, w2, w3, w4]))
  where
    -- Split one 32-bit word into its high and low 16-bit halves, in that order.
    halves w = [w `shiftR` 16, w .&. 0xFFFF]
    -- A half is at most 0xFFFF, so 'showHex' yields 1-4 digits; pad to 4.
    hextet h = T.justifyRight 4 '0' (T.pack (showHex h ""))

----------------------------------------------------------------------
-- Basic request information
----------------------------------------------------------------------

-- | Best-effort client IP address detection.
--
-- The address is taken from the first of these sources that yields a
-- non-empty value:
--
-- 1.  @X-Forwarded-For@: the first entry of the comma-separated list, with
--     surrounding whitespace removed.
-- 2.  @X-Real-Ip@, with surrounding whitespace removed.
-- 3.  The 'Network.Wai.remoteHost' socket address of the 'Request'.
--
-- Header names are matched case-insensitively.  Header values are decoded
-- as UTF-8 leniently (invalid bytes become U+FFFD), so malformed input never
-- throws.  A header that is present but blank after trimming is treated as
-- absent, so it cannot collapse unrelated clients onto an empty key.
-- Examples are @X-Forwarded-For: @ or a value that starts with a comma.
--
-- Socket addresses are rendered with 'ipv4ToString' / 'ipv6ToString'.  Unix
-- sockets are represented by their file path.
--
-- __Security note:__ both headers are supplied by the client.  They can only
-- be trusted when every request passes through a reverse proxy that
-- overwrites them.  Otherwise a client can pick its own key and evade
-- IP-based limits.
--
-- No I/O is performed; the result is returned in 'IO' via 'pure'.
getClientIP :: Request -> IO Text
getClientIP req =
  pure $ case fromHeader (mk "x-forwarded-for") (T.takeWhile (/= ',')) of
    Just ip -> ip
    Nothing -> case fromHeader (mk "x-real-ip") id of
      Just ip -> ip
      Nothing -> socketIP (WAI.remoteHost req)
  where
    -- Look up a header, select the relevant part of its decoded value and
    -- trim it.  A blank result is reported as 'Nothing' so the caller falls
    -- through to the next source.
    fromHeader :: HeaderName -> (Text -> Text) -> Maybe Text
    fromHeader name select = do
      raw <- lookup name (WAI.requestHeaders req)
      let candidate =
            T.strip (select (TE.decodeUtf8With TEE.lenientDecode raw))
      if T.null candidate then Nothing else Just candidate

    -- Textual form of the peer's socket address.
    socketIP :: SockAddr -> Text
    socketIP (SockAddrInet _ addr)      = ipv4ToString addr
    socketIP (SockAddrInet6 _ _ addr _) = ipv6ToString addr
    socketIP (SockAddrUnix path)        = T.pack path

-- | The request path as 'Text': the bytes of 'WAI.rawPathInfo' decoded as
-- UTF-8, with invalid sequences replaced by U+FFFD instead of raising an
-- error.
--
-- The path is used exactly as received.  No percent-decoding or
-- normalisation is applied, so @\/a@ and @\/a\/@ are different values.
getRequestPath :: Request -> Text
getRequestPath req = TE.decodeUtf8With TEE.lenientDecode (WAI.rawPathInfo req)

-- | The HTTP request method (e.g. @"GET"@, @"POST"@) as 'Text'.
--
-- Method tokens are normally ASCII, but the bytes come from the client's
-- request line and WAI does not guarantee they are valid UTF-8.  They are
-- therefore decoded leniently (invalid sequences become U+FFFD) rather than
-- with the throwing 'TE.decodeUtf8'.  For any well-formed method the result
-- is unchanged.
getRequestMethod :: Request -> Text
getRequestMethod = TE.decodeUtf8With TEE.lenientDecode . WAI.requestMethod

-- | The value of the @Host@ header, decoded as lenient UTF-8, or 'Nothing'
-- when the request has no such header.  If the header is repeated, only the
-- first occurrence is used.
getRequestHost :: Request -> Maybe Text
getRequestHost req =
  TE.decodeUtf8With TEE.lenientDecode <$> lookup hHost (WAI.requestHeaders req)

-- | The value of the @User-Agent@ header, decoded as lenient UTF-8, or
-- 'Nothing' when the request has no such header.  If the header is repeated,
-- only the first occurrence is used.
getRequestUserAgent :: Request -> Maybe Text
getRequestUserAgent req =
  case lookup hUserAgent (WAI.requestHeaders req) of
    Nothing  -> Nothing
    Just raw -> Just (TE.decodeUtf8With TEE.lenientDecode raw)

----------------------------------------------------------------------
-- Composite key builders
----------------------------------------------------------------------

-- | A request key made of the client IP alone, exactly as returned by
-- 'getClientIP'.
--
-- The result is always 'Just'.  'getClientIP' falls back to the socket
-- address, and every WAI request has one (IPv4, IPv6 or Unix socket).
--
-- ==== __Example__
--
-- @
-- byIP req ⇨ pure (Just "127.0.0.1")
-- @
byIP :: Request -> IO (Maybe Text)
byIP = fmap Just . getClientIP

-- | A composite key of the client IP and the raw request path, joined by a
-- colon (@ip:path@).
--
-- Use this to rate-limit individual endpoints instead of every request a
-- client makes.  Both parts are always available, so the result is always
-- 'Just'.
--
-- Treat the key as opaque.  IPv6 addresses themselves contain colons, so the
-- key cannot be split back into its parts unambiguously.
--
-- ==== __Example__
--
-- @
-- -- For a request to \/api\/v1\/users from 192.168.1.10
-- byIPAndPath req ⇨ pure (Just "192.168.1.10:\/api\/v1\/users")
-- @
byIPAndPath :: Request -> IO (Maybe Text)
byIPAndPath req = do
  clientIP <- getClientIP req
  let key = T.concat [clientIP, ":", getRequestPath req]
  pure (Just key)

-- | A composite key of the client IP and the @User-Agent@ header, joined by a
-- colon (@ip:agent@).
--
-- Yields 'Nothing' when the request carries no @User-Agent@ header.
--
-- ==== __Example__
--
-- @
-- -- For a request from Googlebot at 8.8.8.8
-- byIPAndUserAgent req ⇨ pure (Just "8.8.8.8:Mozilla\/5.0 (compatible; Googlebot\/2.1)")
-- @
byIPAndUserAgent :: Request -> IO (Maybe Text)
byIPAndUserAgent req = do
  clientIP <- getClientIP req
  let withIP agent = clientIP <> ":" <> agent
  pure (withIP <$> getRequestUserAgent req)

-- | A composite key of the client IP and the value of an arbitrary header,
-- joined by a colon (@ip:value@).
--
-- The header is located case-insensitively (a 'HeaderName' is a
-- case-insensitive string).  Its value is decoded as lenient UTF-8.  If the
-- header is repeated, only the first occurrence is used.  Yields 'Nothing'
-- when the header is absent.
--
-- ==== __Example__
--
-- A typical use is rate-limiting by API key combined with the caller's IP.
--
-- @
-- -- Given a request with header "X-Api-Key: mysecret" from 1.2.3.4
-- byHeaderAndIP "x-api-key" req ⇨ pure (Just "1.2.3.4:mysecret")
-- @
byHeaderAndIP :: HeaderName -> Request -> IO (Maybe Text)
byHeaderAndIP headerName req = do
  clientIP <- getClientIP req
  pure $ case lookup headerName (WAI.requestHeaders req) of
    Nothing  -> Nothing
    Just raw ->
      Just (clientIP <> ":" <> TE.decodeUtf8With TEE.lenientDecode raw)