{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeApplications #-}

{-|
Module      : Keter.RateLimiter.LeakyBucket
Description : Leaky bucket rate limiting algorithm implementation
Copyright   : (c) 2025 Oleksandr Zhabenko
License     : MIT
Maintainer  : oleksandr.zhabenko@yahoo.com
Stability   : stable
Portability : portable

This module implements the core logic for the Leaky Bucket rate-limiting algorithm. The primary goal of this algorithm is to smooth out bursts of requests into a steady, predictable flow. It is conceptually similar to a bucket with a hole in the bottom.

Incoming requests are like water being added to the bucket. The bucket has a finite capacity. If a request arrives when the bucket is full, it is rejected (it \"spills over\"). The hole in the bucket allows requests to be processed (or \"leak out\") at a constant leak rate.

This implementation uses a dedicated worker thread for each bucket (e.g., for each unique user or IP address) to process requests from a queue. This ensures that processing is serialized and state is managed safely under concurrent access. The first request for a given key will spawn the worker, which then serves all subsequent requests for that same key.
-}
module Keter.RateLimiter.LeakyBucket
  ( -- * Algorithm Logic
    allowRequest
  ) where

import Control.Concurrent.STM
import Control.Monad (when)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Text (Text)
import Data.Time.Clock.POSIX (getPOSIXTime)
import Keter.RateLimiter.Types (LeakyBucketState(..))
import Keter.RateLimiter.Cache
import Keter.RateLimiter.AutoPurge (LeakyBucketEntry(..))
import qualified Focus as F
import qualified StmContainers.Map as StmMap

-- | Determines whether a request should be allowed based on the state of its corresponding leaky bucket.
--
-- This function is the primary entry point for the leaky bucket algorithm. When a request arrives,
-- this function is called to check if it should be processed or rejected. It operates by finding
-- or creating a bucket associated with the user's key. Every request is added to a queue for its bucket.
--
-- A dedicated worker thread is responsible for processing the queue for each bucket. This function ensures
-- that a worker is started for a new bucket but avoids starting duplicate workers. The final result
-- ('True' or 'False') is communicated back to the caller once the worker has processed the request.
--
-- === Algorithm Flow
--
-- 1. __Validation__: Check if capacity is valid (> 0)
-- 2. __Key Construction__: Build composite cache key from throttle name, IP zone, and user key
-- 3. __Bucket Management__: Find existing bucket or create new one atomically
-- 4. __Request Queuing__: Add request to bucket's processing queue
-- 5. __Worker Management__: Start worker thread for new buckets (one-time operation)
-- 6. __Response__: Wait for worker to process request and return result
--
-- === Concurrency Model
--
-- * Each bucket has its own worker thread for serialized processing
-- * Multiple clients can safely call this function concurrently
-- * STM ensures atomic bucket creation and state management
-- * Workers are started lazily on first request per bucket
--
-- ==== __Examples__
--
-- @
-- -- Basic usage: 10 request capacity, 1 request per second leak rate
-- isAllowed <- allowRequest cache \"api-throttle\" \"us-east-1\" \"user-123\" 10 1.0
-- if isAllowed
--   then putStrLn \"Request is allowed.\"
--   else putStrLn \"Request is blocked (429 Too Many Requests).\"
-- @
--
-- @
-- -- High-throughput API with burst tolerance
-- let capacity = 100        -- Allow burst of up to 100 requests
--     leakRate = 10.0       -- Process 10 requests per second steadily
-- 
-- result <- allowRequest cache \"high-volume\" \"zone-premium\" \"client-456\" capacity leakRate
-- when result $ processApiRequest
-- @
--
-- @
-- -- Rate limiting for different service tiers
-- let (cap, rate) = case userTier of
--       Premium -> (50, 5.0)   -- 50 burst, 5\/sec sustained
--       Standard -> (20, 2.0)  -- 20 burst, 2\/sec sustained  
--       Basic -> (5, 0.5)      -- 5 burst, 1 per 2 seconds
-- 
-- allowed <- allowRequest cache \"tiered-api\" zone userId cap rate
-- @
--
-- /Thread Safety:/ This function is fully thread-safe and can be called
-- concurrently from multiple threads.
--
-- /Performance:/ New buckets incur a one-time worker thread creation cost.
-- Subsequent requests are queued with minimal overhead.
allowRequest
  :: MonadIO m
  => Cache (InMemoryStore 'LeakyBucket)
  -- ^ Leaky bucket cache instance
  -> Text
  -- ^ Throttle name (logical grouping identifier)
  -> Text
  -- ^ IP zone identifier for multi-tenant isolation
  -> Text
  -- ^ User key (unique client identifier)
  -> Int
  -- ^ Bucket capacity (maximum queued requests, must be > 0)
  -> Double
  -- ^ Leak rate in requests per second (must be positive)
  -> m Bool
  -- ^ 'True' if request is allowed, 'False' if bucket is full
allowRequest cache throttleName ipZone userKey capacity leakRate
  -- A non-positive capacity can never admit a request: reject immediately,
  -- without touching the cache or creating a bucket.
  | capacity <= 0 = pure False
  | otherwise = liftIO $ do
      -- Current POSIX time (seconds since the epoch) as a 'Double'.
      now <- realToFrac <$> getPOSIXTime
      let fullKey = makeCacheKey throttleName LeakyBucket ipZone userKey
          LeakyBucketStore tvBuckets = cacheStore cache

      -- One-shot mailbox through which the verdict for this request is
      -- delivered back to this caller.
      replyVar <- newEmptyTMVarIO

      -- A candidate entry has to be allocated in 'IO' before the transaction
      -- because it cannot be built inside 'STM'; it is only stored when the
      -- key is absent and is otherwise dropped.
      -- NOTE(review): the two 'LeakyBucketState' fields are presumably the
      -- initial fill level and the last-update timestamp -- confirm against
      -- "Keter.RateLimiter.Types".
      newEntry <- createLeakyBucketEntry (LeakyBucketState 0 now)

      -- Atomic find-or-create: insert 'newEntry' when the key is missing,
      -- otherwise leave the map untouched and hand back the existing entry.
      entry <- atomically $ do
        buckets <- readTVar tvBuckets
        StmMap.focus
          ( F.cases
              (newEntry, F.Set newEntry)
              (\existing -> (existing, F.Leave))
          )
          fullKey
          buckets

      -- Enqueue the reply slot on the bucket's request queue.
      atomically $ writeTQueue (lbeQueue entry) replyVar

      -- 'tryPutTMVar' yields 'True' only for the caller that finds the lock
      -- empty, so at most one caller per entry goes on to start the worker.
      -- NOTE(review): relies on 'createLeakyBucketEntry' creating
      -- 'lbeWorkerLock' empty -- confirm in "Keter.RateLimiter.AutoPurge".
      started <- atomically $ tryPutTMVar (lbeWorkerLock entry) ()
      -- NOTE(review): 'leakRate' is passed through unvalidated although the
      -- parameter docs require it to be positive; how the worker treats a
      -- non-positive rate is not visible from this module.
      when started $
        startLeakyBucketWorker
          (lbeState entry)
          (lbeQueue entry)
          capacity
          leakRate

      -- Block until the reply slot is filled (presumably by the bucket's
      -- worker) and return that verdict.
      atomically $ takeTMVar replyVar