{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DataKinds #-}

-- |
-- Module      : Keter.RateLimiter.TokenBucket
-- Description : Token bucket rate limiting algorithm implementation
-- Copyright   : (c) 2025 Oleksandr Zhabenko
-- License     : MIT
-- Maintainer  : oleksandr.zhabenko@yahoo.com
-- Stability   : stable
-- Portability : portable
--
-- This module provides a rate limiter based on the /Token Bucket/ algorithm.
-- It integrates with the "Keter.RateLimiter.Cache" infrastructure and uses STM
-- and worker threads to manage refill and request allowance.
--
-- The token bucket algorithm allows for a configurable burst size (capacity)
-- and replenishes tokens over time at a fixed rate. If a request is made and
-- a token is available, the request is allowed and a token is consumed.
-- Otherwise, the request is denied.
--
-- == Example usage
--
-- @
-- import Keter.RateLimiter.TokenBucket (allowRequest)
--
-- allowed <- allowRequest cache \"api\" \"zone1\" \"user123\" 10 2.5 60
-- when allowed $ doSomething
-- @
--
-- This call checks whether a request by @\"user123\"@ in @\"zone1\"@, under the
-- throttle @\"api\"@, is allowed, given a bucket with a capacity of 10, a refill
-- rate of 2.5 tokens\/second, and a TTL of 60 seconds.

module Keter.RateLimiter.TokenBucket
  ( -- * Request Evaluation
    allowRequest
  ) where

import Control.Concurrent.MVar
import Control.Concurrent.STM
import Control.Monad.IO.Class   (MonadIO, liftIO)
import Data.Text                (Text)
import Data.Time.Clock.POSIX    (getPOSIXTime)

import Keter.RateLimiter.Cache
import Keter.RateLimiter.Types          (TokenBucketState (..))
import Keter.RateLimiter.AutoPurge      (TokenBucketEntry (..))
import qualified Focus                  as F
import qualified StmContainers.Map      as StmMap

------------------------------------------------------------------------------

-- | Minimum TTL (in seconds) accepted by 'allowRequest'.
--
-- Calls whose TTL is below this threshold are denied outright, without
-- touching the bucket store.
-- NOTE(review): the reason for this floor (e.g. interaction with the
-- auto-purge logic) is not visible from this module — confirm before changing.
minTTL :: Int
minTTL = 2

-- | Check whether a request may pass through the token-bucket limiter.
--
-- Buckets are keyed by the throttle name, the cache's algorithm, the IP zone
-- and the user key, combined via 'makeCacheKey'. Each call will:
--
-- * Return 'False' immediately if the TTL is below 'minTTL'.
-- * If no live bucket exists for the key, install a fresh one. The fresh
--   bucket is initialised with @capacity - 1@ tokens, i.e. this request
--   consumes one. Then, if the capacity is positive, start its worker thread
--   and return 'True'. Otherwise return 'False'.
-- * Otherwise, enqueue the request on the existing bucket's worker and block
--   until the worker replies with the verdict.
--
-- An existing bucket whose worker lock is empty is treated as having no
-- running worker and is replaced by a fresh one.
--
-- The token bucket is defined by:
--
-- * /capacity/: maximum number of tokens in the bucket (i.e., max burst size).
-- * /refillRate/: tokens added per second (can be fractional); handed to the
--   bucket's worker.
-- * /expiresIn/: TTL in seconds. This function only validates it against
--   'minTTL'; it is not otherwise used here.
--
-- ==== __Examples__
--
-- @
-- -- Allow 100 requests per minute with burst capacity of 10
-- let capacity = 10
--     refillRate = 100.0 \/ 60.0  -- ~1.67 tokens per second
--     ttl = 300                  -- 5 minutes TTL
--
-- result <- allowRequest cache \"api-throttle\" \"192.168.1.1\" \"user456\" capacity refillRate ttl
-- if result
--   then putStrLn \"Request allowed\"
--   else putStrLn \"Request denied - rate limit exceeded\"
-- @
--
-- @
-- -- High-frequency API with small bursts
-- allowed <- allowRequest cache \"fast-api\" \"zone-premium\" \"client789\" 5 10.0 120
-- @
--
-- /Thread Safety:/ Safe to call concurrently for the same or different keys.
-- The transaction that installs a new bucket also claims its worker lock
-- (when the capacity is positive). So for a given key exactly one caller
-- starts the worker, and concurrent callers are queued to that worker.
--
-- /Performance:/ The first request for a bucket pays for starting a worker
-- thread and waits on the worker's readiness signal. Later requests block
-- until the worker answers them.
allowRequest
  :: MonadIO m
  => Cache (InMemoryStore 'TokenBucket)
  -- ^ Token bucket cache instance
  -> Text
  -- ^ Throttle name (logical grouping identifier)
  -> Text
  -- ^ IP zone identifier
  -> Text
  -- ^ User key (unique client identifier)
  -> Int
  -- ^ Bucket capacity (maximum tokens; a non-positive value denies new buckets)
  -> Double
  -- ^ Refill rate in tokens per second (must be positive, can be fractional)
  -> Int
  -- ^ TTL in seconds (must be >= 'minTTL')
  -> m Bool
  -- ^ 'True' if request is allowed, 'False' if denied
allowRequest cache throttleName ipZone userKey capacity refillRate expiresIn
  | expiresIn < minTTL = pure False
  | otherwise = liftIO $ do
      now <- floor <$> getPOSIXTime
      let key = makeCacheKey throttleName (cacheAlgorithm cache) ipZone userKey
          TokenBucketStore tvBuckets = cacheStore cache
      freshEntry <- createTokenBucketEntry (TokenBucketState (capacity - 1) now)
      -- 'install' puts 'freshEntry' under the key. When a worker will be
      -- started for it (capacity > 0), the worker lock is claimed in the SAME
      -- transaction. Otherwise a concurrent caller could observe the still-empty
      -- lock, also treat the key as new, replace the entry, start a second
      -- worker and be granted a second "first" token.
      -- 'onExisting' replaces an entry whose worker lock is empty (no running
      -- worker) and otherwise leaves the existing entry untouched.
      let install = do
            if capacity > 0
              then putTMVar (tbeWorkerLock freshEntry) ()
              else pure ()
            pure ((True, freshEntry), F.Set freshEntry)
          onExisting existing = do
            lockEmpty <- isEmptyTMVar (tbeWorkerLock existing)
            if lockEmpty
              then install
              else pure ((False, existing), F.Leave)
      (wasNew, entry) <- atomically $ do
        buckets <- readTVar tvBuckets
        StmMap.focus (F.Focus install onExisting) key buckets
      if wasNew
        then
          if capacity > 0
            then do
              workerReadyVar <- atomically newEmptyTMVar
              startTokenBucketWorker (tbeState entry)
                                     (tbeQueue entry)
                                     capacity
                                     refillRate
                                     workerReadyVar
              -- Do not answer until the worker has signalled readiness.
              atomically $ takeTMVar workerReadyVar
              pure True
            else pure False
        else do
          -- Hand the request to the existing worker and wait for its verdict.
          replyVar <- newEmptyMVar
          atomically $ writeTQueue (tbeQueue entry) replyVar
          takeMVar replyVar