{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE DeriveGeneric #-}

{-|
Module      : Keter.RateLimiter.WAI
Description : WAI-compatible, plugin-friendly rate limiting middleware with IP-zone support
License     : MIT
Maintainer  : oleksandr.zhabenko@yahoo.com
Copyright   : (c) 2025 Oleksandr Zhabenko
Stability   : stable
Portability : portable

This file is a Haskell port, with some simplifications, of rack-attack
<https://github.com/rack/rack-attack/blob/main/lib/rack/attack.rb>
and is based on the structure of the original code of
rack-attack, Copyright (c) 2016 by Kickstarter, PBC, under the MIT License.
Oleksandr Zhabenko added several more rate-limiting algorithms (TinyLRU, sliding window, token bucket and leaky bucket) alongside the initial fixed-window counting algorithm, using AI chatbots.
IP zone support was added so that each IP zone gets its own set of caches.

Overview
========

This module provides WAI middleware for declarative, IP-zone-aware rate limiting with
multiple algorithms:

- Fixed Window
- Sliding Window
- Token Bucket
- Leaky Bucket
- TinyLRU

Key points
----------

- Plugin-friendly construction: build an environment once ('Env') from 'RateLimiterConfig'
  and produce a pure WAI 'Middleware'. This matches common WAI patterns and avoids
  per-request setup or global mutable state.

- Concurrency model: all shared structures inside 'Env' use STM 'TVar', not 'IORef'.
  This ensures thread-safe updates under GHC's lightweight (green) threads.

- Zone-specific caches: per-IP-zone caches are stored in a HashMap keyed by zone
  identifiers. Zones are derived from a configurable strategy ('ZoneBy'), with a default.

- No global caches in Keter: you can build one 'Env' per compiled middleware chain
  and cache that chain externally (e.g., per-vhost + middleware-list), preserving
  counters/windows across requests.

Quick start
-----------

1) Declarative configuration (e.g., parsed from JSON/YAML):

@
let cfg = RateLimiterConfig
      { rlZoneBy = ZoneDefault
      , rlThrottles =
          [ RLThrottle "api"   1000 3600 FixedWindow IdIP Nothing
          , RLThrottle "login" 5    300  TokenBucket IdIP (Just 600)
          ]
      }
@

2) Build 'Env' once and obtain a pure 'Middleware':

@
env <- buildEnvFromConfig cfg
let mw  = buildRateLimiterWithEnv env
    app = mw baseApplication
@

Alternatively:

@
mw <- buildRateLimiter cfg  -- convenience: Env creation + Middleware
let app = mw baseApplication
@

Usage patterns
--------------

__Declarative approach (recommended):__

@
import Keter.RateLimiter.WAI
import Keter.RateLimiter.Cache (Algorithm(..))

main = do
  let config = RateLimiterConfig
        { rlZoneBy = ZoneIP
        , rlThrottles = 
            [ RLThrottle "api" 100 3600 FixedWindow IdIP Nothing
            ]
        }
  middleware <- buildRateLimiter config
  let app = middleware baseApp
  run 8080 app
@

__Programmatic approach (advanced):__

@
import Keter.RateLimiter.WAI
import Keter.RateLimiter.Cache (Algorithm(..))

main = do
  env <- initConfig (\\req -> "zone1")
  let throttleConfig = ThrottleConfig
        { throttleLimit = 100
        , throttlePeriod = 3600
        , throttleAlgorithm = FixedWindow
        , throttleIdentifierBy = IdIP
        , throttleTokenBucketTTL = Nothing
        }
  env' <- addThrottle env "api" throttleConfig
  let middleware = buildRateLimiterWithEnv env'
      app = middleware baseApp
  run 8080 app
@

Configuration reference
-----------------------

__Client identification strategies ('IdentifierBy'):__

- 'IdIP' - Identify by client IP address
- 'IdIPAndPath' - Identify by IP address and request path
- 'IdIPAndUA' - Identify by IP address and User-Agent header
- @'IdHeader' headerName@ - Identify by custom header value
- @'IdCookie' cookieName@ - Identify by cookie value
- @'IdHeaderAndIP' headerName@ - Identify by header value combined with IP

__Zone derivation strategies ('ZoneBy'):__

- 'ZoneDefault' - All requests use the same cache (no zone separation)
- 'ZoneIP' - Separate zones by client IP address
- @'ZoneHeader' headerName@ - Separate zones by custom header value

__Rate limiting algorithms:__

- 'FixedWindow' - Traditional fixed-window counting
- 'SlidingWindow' - Precise sliding-window with timestamp tracking
- 'TokenBucket' - Allow bursts up to capacity, refill over time
- 'LeakyBucket' - Smooth rate limiting with configurable leak rate
- 'TinyLRU' - Least-recently-used eviction for memory efficiency

-}

module Keter.RateLimiter.WAI
  ( -- * Environment & Configuration
    Env(..)
  , ThrottleConfig(..)
  , IdentifierBy(..)
  , ZoneBy(..)
  , RLThrottle(..)
  , RateLimiterConfig(..)
  , initConfig
  , addThrottle

    -- * Middleware
  , attackMiddleware         -- ^ Low-level: apply throttling with an existing 'Env'
  , buildRateLimiter         -- ^ Convenience: build 'Env' from config, return 'Middleware'
  , buildRateLimiterWithEnv  -- ^ Preferred: pure 'Middleware' from a pre-built 'Env'
  , buildEnvFromConfig       -- ^ Build 'Env' once from 'RateLimiterConfig'

    -- * Manual Control & Inspection
  , instrument
  , cacheResetAll

    -- * Helpers for configuration
  , registerThrottle
  , mkIdentifier
  , mkZoneFn
  , getClientIPPure
  , hdr
  , fromHeaderName
  ) where

import Control.Concurrent.STM
import Data.Aeson hiding (pairs)
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as LBS
import Data.CaseInsensitive (mk, original)
import Data.Foldable (asum)
import Data.Hashable (Hashable(..))
import qualified Data.HashMap.Strict as HM
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import qualified Data.Text as Tx
import qualified Data.Text.Encoding as TE
import qualified Data.Text.Encoding.Error as TEE
import GHC.Generics
import Network.HTTP.Types (HeaderName, hCookie, status429)
import Network.Socket (SockAddr (..))
import Network.Wai
import qualified Web.Cookie as WC

-- Import Cache with hiding Algorithm to avoid conflict, then import Algorithm explicitly
import Keter.RateLimiter.Cache hiding (Algorithm)
import Keter.RateLimiter.Cache (Algorithm(..))
import Keter.RateLimiter.CacheWithZone (allowFixedWindowRequest)
import Keter.RateLimiter.IPZones
  ( IPZoneIdentifier
  , ZoneSpecificCaches(..)
  , createZoneCaches
  , defaultIPZone
  )
import qualified Keter.RateLimiter.LeakyBucket as LeakyBucket
import qualified Keter.RateLimiter.RequestUtils as RU
import qualified Keter.RateLimiter.SlidingWindow as SlidingWindow
import qualified Keter.RateLimiter.TokenBucket as TokenBucket
import Data.TinyLRU (allowRequestTinyLRU)
import System.Clock (Clock (Monotonic), getTime)
import Data.Time.Clock.POSIX (getPOSIXTime)

--------------------------------------------------------------------------------
-- Configuration and Environment

-- | Runtime parameters for one named throttle.
--
-- Usually assembled from the declarative 'RLThrottle' description, but it can
-- also be built by hand and registered with 'addThrottle'.
data ThrottleConfig = ThrottleConfig
  { -- | Maximum number of requests allowed within one period.
    throttleLimit          :: !Int
    -- | Length of a period, in seconds.
  , throttlePeriod         :: !Int
    -- | Rate-limiting algorithm applied by this throttle.
  , throttleAlgorithm      :: !Algorithm
    -- | Declarative description of how the client identifier is extracted
    -- (IP, header, cookie, ...). At runtime the extractor is obtained via
    -- 'mkIdentifier' and evaluated at most once per request for each distinct
    -- 'IdentifierBy'. When extraction yields 'Nothing', the throttle does not
    -- apply to that request.
  , throttleIdentifierBy   :: !IdentifierBy
    -- | Optional TTL, in seconds, for 'TokenBucket' entries.
  , throttleTokenBucketTTL :: !(Maybe Int)
  }
  deriving (Eq, Show, Generic)

-- | Shared rate-limiter state.
--
-- Both mutable maps are held in STM 'TVar's, so handlers running on separate
-- (green) threads can read and update them safely. Nothing here is global:
-- build one 'Env' while wiring the application and reuse it for every request.
data Env = Env
  { -- | Algorithm caches, one set per IP zone.
    envZoneCachesMap    :: TVar (HM.HashMap IPZoneIdentifier ZoneSpecificCaches)
    -- | Throttle configurations, keyed by throttle name.
  , envThrottles        :: TVar (HM.HashMap Text ThrottleConfig)
    -- | Maps an incoming request to the IP zone whose caches it uses.
  , envGetRequestIPZone :: Request -> IPZoneIdentifier
  }

-- | Create an 'Env' with no throttles registered.
--
-- Caches for 'defaultIPZone' are allocated eagerly. Caches for any other zone
-- are created on demand by 'getOrCreateZoneCaches'.
initConfig
  :: (Request -> IPZoneIdentifier)  -- ^ Request -> zone label
  -> IO Env
initConfig getIPZone =
  Env
    <$> (createZoneCaches >>= newTVarIO . HM.singleton defaultIPZone)
    <*> newTVarIO HM.empty
    <*> pure getIPZone

-- | Register a throttle under the given name, replacing any throttle already
-- stored under that name.
--
-- The update is a single STM transaction on 'envThrottles'. The same 'Env' is
-- returned; its state is shared through 'TVar's, so calls can be chained.
addThrottle
  :: Env
  -> Text
  -> ThrottleConfig
  -> IO Env
addThrottle env name config =
  env <$ atomically (modifyTVar' (envThrottles env) (HM.insert name config))

--------------------------------------------------------------------------------
-- Middleware (application of throttles)

-- | Low-level middleware that applies throttling with an existing 'Env'.
--
-- The request is first checked with 'instrument'. If any throttle blocks it,
-- a plain-text @429 Too Many Requests@ response is sent and the wrapped
-- application is never called; otherwise the request is passed on unchanged.
attackMiddleware
  :: Env
  -> Application
  -> Application
attackMiddleware env app req respond =
  instrument env req >>= \blocked ->
    if blocked
      then respond rejection
      else app req respond
  where
    rejection :: Response
    rejection =
      responseLBS
        status429
        [("Content-Type", "text/plain; charset=utf-8")]
        (LBS.fromStrict (TE.encodeUtf8 "Too Many Requests"))

-- | Evaluate every registered throttle in 'Env' against a request.
--
-- Returns 'True' when at least one throttle decides the request must be
-- blocked. When no throttles are registered, it returns 'False' without
-- touching any zone caches.
instrument :: Env -> Request -> IO Bool
instrument env req = readTVarIO (envThrottles env) >>= decide
  where
    decide throttles
      | HM.null throttles = pure False
      | otherwise = do
          let zone = envGetRequestIPZone env req
          caches <- getOrCreateZoneCaches env zone
          anyMHashMap (checkGroup caches zone) (groupByIdentifier throttles)

    -- Every throttle in a group shares one 'IdentifierBy', so the identifier
    -- is extracted once and reused for all members. If extraction fails, the
    -- whole group is skipped for this request.
    checkGroup _ _ _ [] = pure False
    checkGroup caches zone idBy members = do
      mIdent <- mkIdentifier idBy req
      case mIdent of
        Nothing    -> pure False
        Just ident ->
          anyMList
            (\(name, cfg) ->
               checkThrottleWithIdent caches zone req name cfg (Just ident))
            members

-- | Check one throttle for a request whose identifier is already known.
--
-- Returns 'True' when the request must be blocked and 'False' when it is
-- allowed. A 'Nothing' identifier means the throttle does not apply, so the
-- request is allowed.
--
-- Each algorithm's state is keyed by throttle name as well as identifier, so
-- two throttles that derive the same identifier (for example two 'IdIP'
-- rules) never share a counter. The caches are already per zone, so the zone
-- does not need to be part of the TinyLRU key.
checkThrottleWithIdent
  :: ZoneSpecificCaches
  -> Text                 -- ^ zone
  -> Request
  -> Text                 -- ^ throttle name
  -> ThrottleConfig
  -> Maybe Text           -- ^ precomputed identifier
  -> IO Bool
checkThrottleWithIdent _ _ _ _ _ Nothing = pure False
checkThrottleWithIdent caches zone _req throttleName cfg (Just ident) =
  case throttleAlgorithm cfg of
    FixedWindow ->
      -- allowFixedWindowRequest cache throttleName zone ident limit period
      not <$> allowFixedWindowRequest
                (zscCounterCache caches) throttleName zone ident limit period

    SlidingWindow -> case zscTimestampCache caches of
      Cache { cacheStore = TimestampStore tvar } ->
        -- SlidingWindow.allowRequest timeNow tvar throttleName zone ident window limit
        not <$> SlidingWindow.allowRequest
                  (realToFrac <$> getPOSIXTime)
                  tvar throttleName zone ident period limit

    TokenBucket ->
      -- TokenBucket.allowRequest cache throttleName zone ident capacity refill expires
      not <$> TokenBucket.allowRequest
                (zscTokenBucketCache caches)
                throttleName zone ident limit perSecond tokenBucketTTL

    LeakyBucket ->
      -- LeakyBucket.allowRequest cache throttleName zone ident capacity leakRate
      not <$> LeakyBucket.allowRequest
                (zscLeakyBucketCache caches)
                throttleName zone ident limit perSecond

    TinyLRU -> do
      now <- getTime Monotonic
      case cacheStore (zscTinyLRUCache caches) of
        TinyLRUStore tvar ->
          -- Read the cache and update it in one transaction, so a concurrent
          -- reset that replaces the cache cannot make this request count
          -- against a cache that has already been discarded.
          fmap not . atomically $ do
            cache <- readTVar tvar
            -- allowRequestTinyLRU now cache key capacity periodSecs
            allowRequestTinyLRU now cache tinyLRUKey limit period
  where
    limit  = throttleLimit cfg
    period = throttlePeriod cfg

    -- Refill rate (TokenBucket) or leak rate (LeakyBucket) in requests per
    -- second. A non-positive period yields 0 rather than dividing by zero.
    perSecond :: Double
    perSecond
      | period > 0 = fromIntegral limit / fromIntegral period
      | otherwise  = 0.0

    -- NOTE(review): falls back to a 2-second TTL when none is configured;
    -- confirm this default is intended for long periods.
    tokenBucketTTL = fromMaybe 2 (throttleTokenBucketTTL cfg)

    -- The TinyLRU cache is shared by every TinyLRU throttle in this zone, so
    -- the throttle name is prefixed to keep their counters apart.
    tinyLRUKey = throttleName <> ":" <> ident

-- | Check a single throttle, extracting the client identifier on the fly.
--
-- The identifier is derived with 'mkIdentifier' from the throttle's own
-- 'throttleIdentifierBy' and handed to 'checkThrottleWithIdent', which does
-- the actual work. Kept for backward compatibility; it is not exported from
-- this module.
checkThrottle
  :: ZoneSpecificCaches -> Text -> Request -> Text -> ThrottleConfig -> IO Bool
checkThrottle caches zone req throttleName cfg =
  mkIdentifier (throttleIdentifierBy cfg) req
    >>= checkThrottleWithIdent caches zone req throttleName cfg

-- | Reset every algorithm cache of every zone currently known to the 'Env'.
--
-- Zones stay registered; only their caches are cleared. The zone map is read
-- once, so a zone created concurrently after that read is not reset. Useful
-- in tests or administrative endpoints.
cacheResetAll :: Env -> IO ()
cacheResetAll env =
  readTVarIO (envZoneCachesMap env) >>= mapM_ resetZone . HM.elems
  where
    -- Clear each per-algorithm cache of one zone, in a fixed order.
    resetZone :: ZoneSpecificCaches -> IO ()
    resetZone zc = do
      cacheReset (zscCounterCache zc)
      cacheReset (zscTimestampCache zc)
      cacheReset (zscTokenBucketCache zc)
      cacheReset (zscLeakyBucketCache zc)
      cacheReset (zscTinyLRUCache zc)

-- | Look up the caches for an IP zone, creating and registering them if the
-- zone has not been seen before.
--
-- The fast path is a plain 'readTVarIO'. On a miss, a candidate set of
-- caches is allocated outside STM, then inserted only if no other thread
-- registered the zone first. Several concurrent callers may each allocate a
-- candidate, but exactly one is stored and every caller receives that one.
--
-- NOTE(review): zones are never removed here. With a per-IP zone strategy
-- the map grows with the number of distinct zones seen; confirm whether
-- eviction happens elsewhere.
getOrCreateZoneCaches
  :: Env
  -> IPZoneIdentifier
  -> IO ZoneSpecificCaches
getOrCreateZoneCaches env zone =
  HM.lookup zone <$> readTVarIO zoneVar >>= maybe insertFresh pure
  where
    zoneVar = envZoneCachesMap env

    insertFresh = do
      candidate <- createZoneCaches
      atomically $ do
        current <- readTVar zoneVar
        case HM.lookup zone current of
          Just winner -> pure winner
          Nothing -> do
            writeTVar zoneVar (HM.insert zone candidate current)
            pure candidate

--------------------------------------------------------------------------------
-- Declarative configuration types

-- | Strategy for deriving the client identifier a throttle counts against.
data IdentifierBy
  = IdIP
    -- ^ Client IP address.
  | IdHeader !HeaderName
    -- ^ Value of the named request header.
  | IdCookie !Text
    -- ^ Value of the named cookie.
  | IdIPAndPath
    -- ^ Client IP address combined with the request path.
  | IdIPAndUA
    -- ^ Client IP address combined with the User-Agent header.
  | IdHeaderAndIP !HeaderName
    -- ^ Value of the named header combined with the client IP address.
  deriving (Eq, Show, Generic)

-- | 'Hashable' instance that agrees with the derived 'Eq'.
--
-- Every constructor mixes in its own small integer tag, so constructors
-- without payload (e.g. 'IdIP' vs 'IdIPAndPath') hash differently.
--
-- Header names are hashed via the 'Hashable' instance of 'CI', which hashes
-- the case-folded bytes. This is required for correctness: 'Eq' on 'CI'
-- compares case-folded values, so hashing the 'original' bytes would make
-- @IdHeader "X-Key" == IdHeader "x-key"@ true while giving the two values
-- different hashes, violating the law @a == b ==> hash a == hash b@ that
-- 'HM.HashMap' (used by 'groupByIdentifier') depends on.
instance Hashable IdentifierBy where
  hashWithSalt s IdIP              = hashWithSalt s (0 :: Int)
  hashWithSalt s (IdHeader h)      = hashWithSalt s (1 :: Int, h)
  hashWithSalt s (IdCookie t)      = hashWithSalt s (2 :: Int, t)
  hashWithSalt s IdIPAndPath       = hashWithSalt s (3 :: Int)
  hashWithSalt s IdIPAndUA         = hashWithSalt s (4 :: Int)
  hashWithSalt s (IdHeaderAndIP h) = hashWithSalt s (5 :: Int, h)

-- | Strategy for assigning each request to an IP zone; interpreted by
-- 'mkZoneFn'.
data ZoneBy
  = ZoneDefault
    -- ^ Every request is placed in 'defaultIPZone'.
  | ZoneIP
    -- ^ One zone per client address, as resolved by 'getClientIPPure'.
  | ZoneHeader !HeaderName
    -- ^ Zone taken from the named request header, falling back to
    -- 'defaultIPZone' when the header is absent.
  deriving (Show, Eq, Generic)

-- | A single declarative throttle rule, usually decoded from JSON/YAML and
-- later turned into a 'ThrottleConfig' by 'registerThrottle'.
data RLThrottle = RLThrottle
  { rlName           :: !Text
    -- ^ Rule name (JSON key @name@); used as the throttle name in 'addThrottle'.
  , rlLimit          :: !Int
    -- ^ JSON key @limit@; copied to 'throttleLimit'.
  , rlPeriod         :: !Int
    -- ^ JSON key @period@; copied to 'throttlePeriod'.
  , rlAlgo           :: !Algorithm
    -- ^ JSON key @algorithm@; read with 'parseAlgoText', written with 'algoToText'.
  , rlIdBy           :: !IdentifierBy
    -- ^ JSON key @identifier_by@; how requests are keyed for this rule.
  , rlTokenBucketTTL :: !(Maybe Int)
    -- ^ Optional JSON key @token_bucket_ttl@; copied to 'throttleTokenBucketTTL'.
  } deriving (Show, Eq, Generic)

-- | Complete declarative rate-limiter configuration.
data RateLimiterConfig = RateLimiterConfig
  { rlZoneBy    :: !ZoneBy
    -- ^ How requests are partitioned into IP zones.
  , rlThrottles :: ![RLThrottle]
    -- ^ Throttle rules; 'buildEnvFromConfig' registers them in list order.
  } deriving (Show, Eq, Generic)

-- | Accepted encodings:
--
-- * the strings @"ip"@, @"ip+path"@ and @"ip+ua"@;
-- * an object with one of the keys @header@, @cookie@ or @header+ip@ whose
--   value is a string.
--
-- For objects the keys are tried in the order header, cookie, header+ip and
-- the first one that parses wins. If none parses, a single descriptive error
-- is reported instead of the "key not found" error of the last alternative,
-- which used to hide what was actually expected.
instance FromJSON IdentifierBy where
  parseJSON (String "ip")      = pure IdIP
  parseJSON (String "ip+path") = pure IdIPAndPath
  parseJSON (String "ip+ua")   = pure IdIPAndUA
  parseJSON (Object o) =
    asum [ IdHeader      . hdr <$> o .: "header"
         , IdCookie            <$> o .: "cookie"
         , IdHeaderAndIP . hdr <$> o .: "header+ip"
         , fail "identifier_by: object must have a string \"header\", \"cookie\" or \"header+ip\" key"
         ]
  parseJSON _ =
    fail "identifier_by: 'ip' | 'ip+path' | 'ip+ua' | {header} | {cookie} | {header+ip}"

-- | Inverse of the 'FromJSON' instance.
--
-- Header names are rendered from their original bytes with lenient UTF-8
-- decoding (invalid bytes become U+FFFD), so serialising a value whose header
-- name is not valid UTF-8 cannot throw, unlike the partial 'TE.decodeUtf8'.
instance ToJSON IdentifierBy where
  toJSON idBy = case idBy of
      IdIP            -> String "ip"
      IdIPAndPath     -> String "ip+path"
      IdIPAndUA       -> String "ip+ua"
      IdHeader h      -> object ["header"    .= headerText h]
      IdCookie c      -> object ["cookie"    .= c]
      IdHeaderAndIP h -> object ["header+ip" .= headerText h]
    where
      headerText = TE.decodeUtf8With TEE.lenientDecode . fromHeaderName

-- | Accepts @"default"@, @"ip"@, or an object @{"header": NAME}@.
instance FromJSON ZoneBy where
  parseJSON value = case value of
    String "default" -> pure ZoneDefault
    String "ip"      -> pure ZoneIP
    Object o         -> ZoneHeader . hdr <$> o .: "header"
    _                -> fail "zone_by: 'default' | 'ip' | {header}"

-- | Inverse of the 'FromJSON' instance.
--
-- The header name is rendered from its original bytes with lenient UTF-8
-- decoding, so a non-UTF-8 header name cannot make serialisation throw
-- (the previous 'TE.decodeUtf8' is partial).
instance ToJSON ZoneBy where
  toJSON ZoneDefault    = String "default"
  toJSON ZoneIP         = String "ip"
  toJSON (ZoneHeader h) =
    object ["header" .= TE.decodeUtf8With TEE.lenientDecode (fromHeaderName h)]

-- | Decodes a throttle object. Required keys: @name@, @limit@, @period@,
-- @algorithm@ (validated by 'parseAlgoText') and @identifier_by@; the key
-- @token_bucket_ttl@ is optional.
instance FromJSON RLThrottle where
  parseJSON = withObject "throttle" $ \o ->
    RLThrottle
      <$> o .:  "name"
      <*> o .:  "limit"
      <*> o .:  "period"
      <*> (o .: "algorithm" >>= parseAlgoText)
      <*> o .:  "identifier_by"
      <*> o .:? "token_bucket_ttl"

-- | Encodes a throttle using the same keys the 'FromJSON' instance reads;
-- @token_bucket_ttl@ is always emitted (as @null@ when absent).
instance ToJSON RLThrottle where
  toJSON rule = object
    [ "name"             .= rlName rule
    , "limit"            .= rlLimit rule
    , "period"           .= rlPeriod rule
    , "algorithm"        .= algoToText (rlAlgo rule)
    , "identifier_by"    .= rlIdBy rule
    , "token_bucket_ttl" .= rlTokenBucketTTL rule
    ]

-- | Decodes the top-level object; @zone_by@ is optional and defaults to
-- 'ZoneDefault', @throttles@ is required.
instance FromJSON RateLimiterConfig where
  parseJSON = withObject "rate-limiter" $ \o -> do
    zoneBy    <- o .:? "zone_by" .!= ZoneDefault
    throttles <- o .:  "throttles"
    pure (RateLimiterConfig zoneBy throttles)

-- | Encodes both fields, always including @zone_by@.
instance ToJSON RateLimiterConfig where
  toJSON cfg = object
    [ "zone_by"   .= rlZoneBy cfg
    , "throttles" .= rlThrottles cfg
    ]

--------------------------------------------------------------------------------
-- Public builders (preferred wiring API)

-- | Construct an 'Env' from a declarative 'RateLimiterConfig'.
--
-- The zone function comes from 'mkZoneFn' applied to the configured 'ZoneBy';
-- each throttle rule is then registered in list order. Intended to be called
-- once at wiring time, with the resulting 'Env' reused across requests.
buildEnvFromConfig :: RateLimiterConfig -> IO Env
buildEnvFromConfig (RateLimiterConfig zoneBy throttles) =
  initConfig (mkZoneFn zoneBy) >>= \env ->
    -- NOTE(review): the Env returned by 'registerThrottle' is discarded and the
    -- initial 'env' is returned; this presumes 'addThrottle' updates 'env' in
    -- place (e.g. through its TVars) — confirm.
    env <$ mapM_ (registerThrottle env) throttles

-- | Turn an already-built 'Env' into a WAI 'Middleware'.
--
-- Preferred integration point for WAI/Keter: all mutable state lives in the
-- 'Env', so the returned middleware is an ordinary pure value that simply
-- delegates to 'attackMiddleware'.
buildRateLimiterWithEnv :: Env -> Middleware
buildRateLimiterWithEnv env = attackMiddleware env

-- | One-shot helper: build an 'Env' from the config and wrap it as a
-- 'Middleware'.
--
-- Use 'buildEnvFromConfig' plus 'buildRateLimiterWithEnv' instead when the
-- 'Env' must stay reachable (e.g. for administrative operations).
buildRateLimiter :: RateLimiterConfig -> IO Middleware
buildRateLimiter cfg = do
  env <- buildEnvFromConfig cfg
  pure (buildRateLimiterWithEnv env)

--------------------------------------------------------------------------------
-- Helper functions for configuration

-- | Translate one declarative 'RLThrottle' into a 'ThrottleConfig' and add it
-- to the 'Env' under the rule's name via 'addThrottle'.
registerThrottle :: Env -> RLThrottle -> IO Env
registerThrottle env rule@RLThrottle{} =
  addThrottle env (rlName rule) throttleCfg
  where
    throttleCfg = ThrottleConfig
      { throttleLimit          = rlLimit rule
      , throttlePeriod         = rlPeriod rule
      , throttleAlgorithm      = rlAlgo rule
      , throttleIdentifierBy   = rlIdBy rule
      , throttleTokenBucketTTL = rlTokenBucketTTL rule
      }

-- | Turn a declarative 'IdentifierBy' into the per-request key extractor.
--
-- 'Nothing' means no identifier could be derived for the request. Header
-- values are decoded leniently (invalid UTF-8 bytes become U+FFFD); cookie
-- lookups go through 'cookieLookupText'.
mkIdentifier :: IdentifierBy -> Request -> IO (Maybe Text)
mkIdentifier idBy = case idBy of
  IdIP            -> RU.byIP
  IdIPAndPath     -> RU.byIPAndPath
  IdIPAndUA       -> RU.byIPAndUserAgent
  IdHeaderAndIP h -> RU.byHeaderAndIP h
  IdCookie name   -> pure . cookieLookupText name
  IdHeader h      -> \req ->
    pure (TE.decodeUtf8With TEE.lenientDecode <$> lookup h (requestHeaders req))

-- | Find the named cookie in the request's first @Cookie@ header, parsed with
-- 'WC.parseCookies'.
--
-- Returns 'Nothing' when the header or the cookie is missing, or when the
-- cookie value is empty. Present values are decoded leniently as UTF-8.
cookieLookupText :: Text -> Request -> Maybe Text
cookieLookupText name req =
  lookup hCookie (requestHeaders req) >>= \cookieHeader ->
    lookup (TE.encodeUtf8 name) (WC.parseCookies cookieHeader) >>= nonEmptyText
  where
    nonEmptyText bytes
      | S.null bytes = Nothing
      | otherwise    = Just (TE.decodeUtf8With TEE.lenientDecode bytes)

-- | Build the request-to-zone function for a declarative 'ZoneBy'.
--
-- * 'ZoneDefault': every request maps to 'defaultIPZone'.
-- * 'ZoneIP': the client address from 'getClientIPPure'.
-- * @'ZoneHeader' h@: the value of header @h@, decoded leniently as UTF-8 and
--   stripped of surrounding whitespace. An absent header or one that is blank
--   after stripping maps to 'defaultIPZone'. Without this, a blank header
--   would create a separate empty-string zone.
mkZoneFn :: ZoneBy -> (Request -> IPZoneIdentifier)
mkZoneFn ZoneDefault    = const defaultIPZone
mkZoneFn ZoneIP         = getClientIPPure
mkZoneFn (ZoneHeader h) = \req ->
  case Tx.strip . TE.decodeUtf8With TEE.lenientDecode <$> lookup h (requestHeaders req) of
    Just zone | not (Tx.null zone) -> zone
    _                              -> defaultIPZone

-- | Best-effort client address, used for IP-based zoning.
--
-- Sources, in order of precedence:
--
-- 1. the first hop of @X-Forwarded-For@ (text before the first comma);
-- 2. @X-Real-IP@;
-- 3. the socket peer address from 'remoteHost'.
--
-- Header values are decoded leniently as UTF-8 and stripped of surrounding
-- whitespace. A header that is blank after stripping counts as absent and
-- falls through to the next source, so it cannot yield an empty-string
-- address. Examples: an empty @X-Forwarded-For@ or one starting with a comma.
--
-- NOTE: both headers are client-controlled unless a trusted reverse proxy
-- overwrites them. Only rely on them when such a proxy sits in front.
getClientIPPure :: Request -> IPZoneIdentifier
getClientIPPure req =
  case nonBlank (Tx.takeWhile (/= ',') <$> headerText "x-forwarded-for") of
    Just ip -> ip
    Nothing ->
      case nonBlank (headerText "x-real-ip") of
        Just ip -> ip
        Nothing -> socketText (remoteHost req)
  where
    -- Lenient decoding: malformed bytes become U+FFFD instead of throwing.
    headerText name =
      TE.decodeUtf8With TEE.lenientDecode <$> lookup (mk name) (requestHeaders req)

    -- Strip whitespace and treat an empty result as "header absent".
    nonBlank mValue = case Tx.strip <$> mValue of
      Just v | not (Tx.null v) -> Just v
      _                        -> Nothing

    socketText (SockAddrInet _ addr)      = RU.ipv4ToString addr
    socketText (SockAddrInet6 _ _ addr _) = RU.ipv6ToString addr
    socketText (SockAddrUnix path)        = Tx.pack path

-- | Make a case-insensitive 'HeaderName' from 'Text' by UTF-8 encoding it.
hdr :: Text -> HeaderName
hdr name = mk (TE.encodeUtf8 name)

-- | The bytes of a 'HeaderName' exactly as originally supplied, without case
-- folding.
fromHeaderName :: HeaderName -> S.ByteString
fromHeaderName name = original name

--------------------------------------------------------------------------------
-- Internal helpers: grouping and traversal (to avoid duplicate work)

-- | Name under which a throttle rule is registered.
type ThrottleName = Text
-- | Throttle rules bucketed by the 'IdentifierBy' they use, so that each
-- distinct identifier only has to be computed once per request.
type Grouped = HM.HashMap IdentifierBy [(ThrottleName, ThrottleConfig)]

-- | Bucket throttle rules by their 'throttleIdentifierBy'.
--
-- Callers can then compute each distinct request identifier once and reuse
-- it for every rule in that bucket. Within a bucket, a rule visited later in
-- 'HM.toList' order is placed before earlier ones.
groupByIdentifier :: HM.HashMap ThrottleName ThrottleConfig -> Grouped
groupByIdentifier throttles =
  HM.fromListWith (++)
    [ (throttleIdentifierBy cfg, [(name, cfg)])
    | (name, cfg) <- HM.toList throttles
    ]

-- | Monadic 'any' over a list.
--
-- Runs the predicate on elements from left to right and stops at the first
-- 'True'; actions for later elements are never executed. An empty list
-- yields 'False'.
anyMList :: (a -> IO Bool) -> [a] -> IO Bool
anyMList p = foldr step (pure False)
  where
    step x rest = do
      hit <- p x
      if hit then pure True else rest

-- | Monadic 'any' over the key/value pairs of a 'HM.HashMap'.
--
-- Pairs are visited in 'HM.toList' order, which is unspecified. Evaluation
-- stops at the first pair for which the predicate returns 'True'.
anyMHashMap :: (k -> v -> IO Bool) -> HM.HashMap k v -> IO Bool
anyMHashMap p m = anyMList (\(k, v) -> p k v) (HM.toList m)