{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeApplications #-}

{-|
Module      : Keter.RateLimiter.IPZones
Description : Management of caches specific to IP zones for rate limiting
Copyright   : (c) 2025 Oleksandr Zhabenko
License     : MIT
Maintainer  : oleksandr.zhabenko@yahoo.com
Stability   : stable
Portability : portable

This module provides zone-based isolation for rate-limiting caches. It enables
each IP zone to maintain its own independent cache for every supported
rate-limiting algorithm (fixed window, sliding window, token bucket, leaky
bucket, and an auxiliary TinyLRU cache). This lets multi-tenant systems
rate-limit different clients or groups in isolation.

The primary structure here is 'ZoneSpecificCaches', which holds one cache per
supported algorithm for a single zone. Utility functions create and reset
these caches, and map a socket address to a zone identifier.

-}

module Keter.RateLimiter.IPZones
  ( -- * IP Zone Identification
    IPZoneIdentifier
  , defaultIPZone
    -- * Zone-specific Caches
  , ZoneSpecificCaches(..)
  , createZoneCaches
  , newZoneSpecificCaches
    -- * Cache Management
  , resetSingleZoneCaches
  , resetZoneCache
    -- * Address to Zone Resolution
  , sockAddrToIPZone 
  ) where

import Data.Text (Text)
import qualified Data.Text as T
import Keter.RateLimiter.Cache
  ( Cache(..)
  , InMemoryStore(..)
  , newCache
  , createInMemoryStore
  , cacheReset
  , Algorithm(..)
  , startCustomPurgeLeakyBucket
  )
import Network.Socket (SockAddr(..))
import Data.IP (fromHostAddress)
import Numeric (showHex)
import Data.Bits
import Control.Concurrent.STM (newTVarIO, atomically, readTVar) 
import qualified StmContainers.Map as StmMap

--------------------------------------------------------------------------------

-- | Identifier of an IP zone.
--
-- A zone is a logical namespace (grouping key) under which a set of
-- rate-limiter caches is kept in isolation from other zones. Typical values
-- are @\"default\"@, @\"zone-a\"@, or a CIDR-style label such as
-- @\"192.168.1.0\/24\"@.
type IPZoneIdentifier = Text

-- | The zone identifier to use when no more specific zone applies.
--
-- Its value is the text @\"default\"@. 'sockAddrToIPZone' also returns it
-- for address families it does not recognise.
defaultIPZone :: IPZoneIdentifier
defaultIPZone = T.pack "default"

-- | The complete set of rate-limiting caches belonging to one IP zone.
--
-- There is one field per supported 'Algorithm'. Each field wraps its own
-- 'InMemoryStore', indexed by that algorithm at the type level, so the state
-- of one algorithm is never mixed with another's. Each call to
-- 'createZoneCaches' builds a fresh, independent set.
--
-- All fields are strict: they are evaluated to weak head normal form when
-- the record is constructed.
data ZoneSpecificCaches = ZoneSpecificCaches
  { zscCounterCache     :: !(Cache (InMemoryStore 'FixedWindow))
    -- ^ Cache holding Fixed Window counters.
  , zscTimestampCache   :: !(Cache (InMemoryStore 'SlidingWindow))
    -- ^ Cache holding the timestamp lists used by Sliding Window.
  , zscTokenBucketCache :: !(Cache (InMemoryStore 'TokenBucket))
    -- ^ Cache holding Token Bucket state.
  , zscLeakyBucketCache :: !(Cache (InMemoryStore 'LeakyBucket))
    -- ^ Cache holding Leaky Bucket state.
  , zscTinyLRUCache     :: !(Cache (InMemoryStore 'TinyLRU))
    -- ^ Auxiliary TinyLRU cache. This field is always populated (it is not
    -- optional): 'createZoneCaches' creates a store for it alongside the
    -- others.
  }

-- | Build a fresh set of caches for a single IP zone.
--
-- Every algorithm gets its own, independent in-memory store. The
-- 'LeakyBucket' store is not obtained from 'createInMemoryStore'. Instead it
-- is assembled by hand:
--
-- * a new STM map is wrapped in a 'TVar';
-- * a background purge thread is started over that map with
--   'startCustomPurgeLeakyBucket'.
--
-- The purge thread is started with an interval of 60 and a TTL of 7200.
-- Per the original comments, it runs every 60 seconds and removes entries
-- older than 2 hours. Its 'ThreadId' is discarded, so the thread cannot be
-- stopped through the value returned here.
--
-- NOTE(review): the purge thread receives the map that is inside the 'TVar'
-- at creation time. If resetting the leaky-bucket cache swaps in a new map
-- (rather than clearing the existing one in place), the purge thread would
-- keep working on the old map. Confirm against "Keter.RateLimiter.Cache".
--
-- ==== __Examples__
--
-- @
-- zoneCaches <- createZoneCaches
-- cacheReset (zscTokenBucketCache zoneCaches)
-- @
createZoneCaches :: IO ZoneSpecificCaches
createZoneCaches = do
  fixedStore   <- createInMemoryStore @'FixedWindow
  slidingStore <- createInMemoryStore @'SlidingWindow
  tokenStore   <- createInMemoryStore @'TokenBucket

  -- Leaky bucket: hand-built store plus a purge thread over its map.
  leakyVar <- atomically StmMap.new >>= newTVarIO
  let leakyStore :: InMemoryStore 'LeakyBucket
      leakyStore = LeakyBucketStore leakyVar

      purgeEverySecs, entryTtlSecs :: Integer
      purgeEverySecs = 60    -- purge interval (every 60 seconds)
      entryTtlSecs   = 7200  -- TTL (2 hours)
  leakyMap <- atomically (readTVar leakyVar)
  _purgeThread <- startCustomPurgeLeakyBucket leakyMap purgeEverySecs entryTtlSecs

  lruStore <- createInMemoryStore @'TinyLRU

  pure ZoneSpecificCaches
    { zscCounterCache     = newCache FixedWindow   fixedStore
    , zscTimestampCache   = newCache SlidingWindow slidingStore
    , zscTokenBucketCache = newCache TokenBucket   tokenStore
    , zscLeakyBucketCache = newCache LeakyBucket   leakyStore
    , zscTinyLRUCache     = newCache TinyLRU       lruStore
    }

-- | Synonym for 'createZoneCaches'.
--
-- Provided for callers that prefer @new*@-style constructor names, for
-- example in builder- or factory-style code. It builds exactly the same
-- caches as 'createZoneCaches'.
newZoneSpecificCaches :: IO ZoneSpecificCaches
newZoneSpecificCaches = createZoneCaches

-- | Reset every cache held by a zone.
--
-- Calls 'cacheReset' on each of the five caches in this order: Fixed
-- Window, Sliding Window, Token Bucket, Leaky Bucket, TinyLRU. Exactly what
-- a reset clears (counters, timestamps, bucket state, …) is defined by
-- 'cacheReset' for each store type.
--
-- ==== __Examples__
--
-- @
-- resetSingleZoneCaches zoneCaches
-- @
resetSingleZoneCaches :: ZoneSpecificCaches -> IO ()
resetSingleZoneCaches zsc =
  sequence_
    [ cacheReset (zscCounterCache zsc)
    , cacheReset (zscTimestampCache zsc)
    , cacheReset (zscTokenBucketCache zsc)
    , cacheReset (zscLeakyBucketCache zsc)
    , cacheReset (zscTinyLRUCache zsc)
    ]

-- | Reset only the cache that belongs to one algorithm within a zone.
--
-- Use this when a single kind of rate limiter must be cleared while the
-- other caches of the zone keep their state.
--
-- ==== __Examples__
--
-- @
-- resetZoneCache zoneCaches TokenBucket
-- @
resetZoneCache :: ZoneSpecificCaches -> Algorithm -> IO ()
resetZoneCache zsc FixedWindow   = cacheReset (zscCounterCache zsc)
resetZoneCache zsc SlidingWindow = cacheReset (zscTimestampCache zsc)
resetZoneCache zsc TokenBucket   = cacheReset (zscTokenBucketCache zsc)
resetZoneCache zsc LeakyBucket   = cacheReset (zscLeakyBucketCache zsc)
resetZoneCache zsc TinyLRU       = cacheReset (zscTinyLRUCache zsc)

-- | Convert a socket address into an IP zone identifier.
--
-- * IPv4: the address is converted with 'fromHostAddress' (which accounts
--   for the network byte order of 'HostAddress') and rendered with 'show',
--   giving dotted-quad text.
-- * IPv6: the address is rendered as eight colon-separated groups of
--   exactly four lower-case hex digits, with no @::@ zero compression, so
--   every address has a single textual form. The port, flow info and scope
--   ID are ignored.
-- * Any other address family (e.g. 'SockAddrUnix') falls back to
--   'defaultIPZone'.
--
-- ==== __Examples__
--
-- Build the IPv4 address with 'tupleToHostAddress'. A raw literal such as
-- @0x7f000001@ is interpreted in network byte order and would render as
-- @1.0.0.127@ on little-endian hosts.
--
-- @
-- zone <- sockAddrToIPZone (SockAddrInet 80 (tupleToHostAddress (127, 0, 0, 1)))
-- print zone   -- \"127.0.0.1\"
--
-- zone6 <- sockAddrToIPZone (SockAddrInet6 80 0 (0, 0, 0, 1) 0)
-- print zone6  -- \"0000:0000:0000:0000:0000:0000:0000:0001\"
-- @
sockAddrToIPZone :: SockAddr -> IO IPZoneIdentifier
sockAddrToIPZone (SockAddrInet _ hostAddr) =
  pure . T.pack . show $ fromHostAddress hostAddr
sockAddrToIPZone (SockAddrInet6 _ _ (w1, w2, w3, w4) _) =
  pure . T.intercalate ":" . map (T.pack . hex4) $
    concatMap splitWord [w1, w2, w3, w4]
  where
    -- Split a 32-bit word into its high and low 16-bit groups, in that order.
    splitWord w = [w `shiftR` 16, w .&. 0xFFFF]
    -- Render a 16-bit group as exactly four hex digits. 'replicate' with a
    -- non-positive count yields no padding, so no length check is needed.
    hex4 n = let digits = showHex n "" in replicate (4 - length digits) '0' ++ digits
sockAddrToIPZone _ = pure defaultIPZone