module Multitasking.RateLimit
  ( -- ** Throttling
    RateLimit,
    throttle,
    maxConcurrentTasks,
    tokenBucket,
  )
where

import Control.Concurrent.STM
import Control.Exception (finally, mask)
import Control.Monad (forever, when)
import Control.Monad.IO.Class
import Multitasking.Communication
import Multitasking.Core
import Multitasking.MonadSTM
import Multitasking.Waiting

-- | Specifies a rate limit which can be applied to actions.
--
-- A 'RateLimit' wraps a function that runs an 'IO' action of any result
-- type, possibly waiting before the action starts. Apply it with 'throttle';
-- build one with 'maxConcurrentTasks' or 'tokenBucket'.
newtype RateLimit = RateLimit (forall a. IO a -> IO a)

-- | Delay the given action according to the 'RateLimit'.
--
-- The action is handed to the wrapped limiter function, which decides how
-- long to wait before running it and returns the action's result unchanged.
throttle :: RateLimit -> IO a -> IO a
throttle (RateLimit limit) = limit

-- | Limits concurrency to at most N tasks. If N tasks are already running,
-- the next one waits until one of them finishes.
--
-- Each throttled action claims one slot from a shared counter (created with
-- the value N) before it starts, and gives the slot back when it finishes,
-- whether it returns normally or throws.
--
-- NOTE(review): with N <= 0 the counter is never positive, so every
-- throttled action would wait forever — confirm callers never pass that.
maxConcurrentTasks :: (MonadSTM m) => Int -> m RateLimit
maxConcurrentTasks concurrency = do
  counter <- newCounter concurrency
  -- 'acquireSlot' blocks until a slot is free and claims it in the same
  -- transaction; 'releaseSlot' hands the slot back.
  let acquireSlot :: IO ()
      acquireSlot = atomically $ do
        value <- getCounter counter
        check (value > 0)
        decrementCounter counter
      releaseSlot :: IO ()
      releaseSlot = incrementCounter counter
  pure $ RateLimit $ \action ->
    -- Async exceptions are masked between claiming the slot and installing
    -- the 'finally' handler. Without this, a thread killed right after
    -- 'acquireSlot' commits would leak its slot forever.
    -- The blocking wait inside 'atomically' is interruptible, so a waiting
    -- thread can still be cancelled. 'restore' runs the action itself with
    -- the caller's original masking state.
    mask $ \restore -> do
      acquireSlot
      restore action `finally` releaseSlot

-- | Limits concurrency to a specific rate with the token bucket strategy.
--
-- The bucket starts full with @burst@ tokens, where @burst@ is the given
-- burst size but never less than 1. Each throttled action waits until at
-- least one token is available and takes it before starting.
--
-- A loop started on the given 'Coordinator' adds one token every @recharge@
-- interval, up to the @burst@ capacity.
-- NOTE(review): the refill loop's lifetime is tied to 'start' on the
-- coordinator — presumably it is stopped when the coordinator finishes;
-- confirm.
tokenBucket ::
  (MonadIO m) =>
  Coordinator ->
  -- | Specify the rate as a 'Duration':  @fromSeconds (1 / X)@ for X tasks per second
  Duration ->
  -- | Allowed burst tasks: X burst tasks means that X tasks can start without waiting
  Int ->
  m RateLimit
tokenBucket coordinator recharge burst' = liftIO $ do
  -- The bucket must hold at least one token, otherwise no task could ever
  -- start and the refill loop below (which only adds tokens while
  -- @value < burst@) would never make progress. Larger burst values are
  -- kept as given.
  let burst :: Int
      burst = max burst' 1
  counter <- newCounter burst
  _ <- start coordinator $ forever $ do
    waitDuration recharge
    atomically $ do
      value <- getCounter counter
      -- Refill one token, but never beyond the bucket's capacity.
      when (value < burst) (incrementCounter counter)
  pure $ RateLimit $ \action -> do
    -- Wait for a token and take it in a single transaction, then run the
    -- action. Tokens are not returned; only the refill loop adds them.
    atomically $ do
      value <- getCounter counter
      check (value > 0)
      decrementCounter counter
    action