{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

{- |
Module      : Langchain.Runnable.Utils
Description : Utility wrappers for Runnable components in LangChain
Copyright   : (c) 2025 Tushar Adhatrao
License     : MIT
Maintainer  : Tushar Adhatrao <tusharadhatrao@gmail.com>

This module provides various utility wrappers for 'Runnable' components that enhance
their behavior with common patterns like:

* Configuration management
* Result caching
* Automatic retries
* Timeout handling

These utilities follow the decorator pattern, wrapping existing 'Runnable' instances
with additional functionality while preserving the original input/output types.

Note: This module is experimental and the API may change in future versions.
-}
module Langchain.Runnable.Utils
  ( -- * Configuration Management
    WithConfig (..)

    -- * Caching
  , Cached (..)
  , cached

    -- * Resilience Patterns
  , Retry (..)
  , WithTimeout (..)
  ) where

import Control.Concurrent
import Control.Exception
  ( SomeAsyncException (..)
  , SomeException
  , displayException
  , fromException
  , tryJust
  )
import qualified System.Timeout as Timeout

import Data.Map.Strict as Map

import Langchain.Runnable.Core

{- | Pair a 'Runnable' with an arbitrary piece of configuration data.

The wrapper keeps a @config@ value next to the wrapped 'Runnable' so the two
travel together. The configuration can be replaced with record-update syntax
without touching the underlying 'Runnable'.

Note: 'invoke' on a 'WithConfig' never reads 'runnableConfig'. It forwards the
input unchanged to 'configuredRunnable', so any code that needs the
configuration must read the field itself.

Example:

@
data LLMConfig = LLMConfig
  { temperature :: Float
  , maxTokens :: Int
  }

let
  baseModel = OpenAI defaultOpenAIConfig
  configuredModel = WithConfig
    { configuredRunnable = baseModel
    , runnableConfig = LLMConfig 0.7 100
    }

-- Later, swap in a new configuration without changing the model
let updatedModel = configuredModel { runnableConfig = LLMConfig 0.9 150 }

-- Use the wrapper as a regular Runnable
result <- invoke updatedModel "Explain monads in Haskell"
@
-}
data WithConfig config r
  = (Runnable r) =>
  WithConfig
  { configuredRunnable :: r
  -- ^ The wrapped 'Runnable' instance
  , runnableConfig :: config
  -- ^ Configuration data carried alongside the 'Runnable'
  }

-- | Transparent 'Runnable' instance. The input and output types are those of
-- the wrapped 'Runnable', and 'invoke' delegates directly to it; the attached
-- configuration is ignored here.
instance (Runnable r) => Runnable (WithConfig config r) where
  type RunnableInput (WithConfig config r) = RunnableInput r
  type RunnableOutput (WithConfig config r) = RunnableOutput r

  invoke wrapper = invoke (configuredRunnable wrapper)

{- | Memoise the successful results of a 'Runnable'.

Results are kept in a 'Map.Map' keyed by input and guarded by an 'MVar', so a
single 'Cached' value can be shared between threads. When an input has already
produced a successful result, that result is returned without invoking the
wrapped 'Runnable' again. Failed invocations ('Left') are not stored.

The wrapped 'Runnable' runs outside the cache lock. Two threads that miss on
the same input at the same time may therefore both invoke it.

Note: the cache lives only in memory and is lost when the program exits. For
persistent caching, consider a custom wrapper backed by a database.

Lookups go through 'Map.Map', so the 'RunnableInput' type must have an 'Ord'
instance. Use 'cached' to build a value with an empty cache.
-}
data Cached r
  = (Runnable r, Ord (RunnableInput r)) =>
  Cached
  { cachedRunnable :: r
  -- ^ The wrapped 'Runnable' instance
  , cacheMap :: MVar (Map.Map (RunnableInput r) (RunnableOutput r))
  -- ^ Shared map from input to its last successful output
  }

{- | Wrap a 'Runnable' in a 'Cached' whose cache starts out empty.

A fresh 'MVar' holding 'Map.empty' is allocated on every call, so each
'Cached' value produced here has its own independent cache.

Example:

@
main = do
  -- Create a cached LLM to avoid redundant API calls
  let expensiveModel = OpenAI { model = "gpt-4", temperature = 0.7 }
  cachedModel <- cached expensiveModel

  -- After the first successful call, identical inputs are served from the cache
  result1 <- invoke cachedModel "What is functional programming?"
  result2 <- invoke cachedModel "What is functional programming?"
  result3 <- invoke cachedModel "What is functional programming?"

  -- A new input is computed by the wrapped model
  result4 <- invoke cachedModel "What is Haskell?"
@
-}
cached :: (Runnable r, Ord (RunnableInput r)) => r -> IO (Cached r)
cached inner = Cached inner <$> newMVar Map.empty

-- | 'Runnable' instance that consults the cache before delegating.
--
-- * On a hit, the stored output is returned as 'Right'.
-- * On a miss, the wrapped 'Runnable' is invoked. A 'Right' result is stored
--   and returned. A 'Left' is returned without being stored, so a later call
--   with the same input invokes the wrapped 'Runnable' again.
--
-- The wrapped 'Runnable' runs without holding the cache lock. Concurrent
-- misses on the same input may each invoke it, and the last write wins.
instance (Runnable r, Ord (RunnableInput r)) => Runnable (Cached r) where
  type RunnableInput (Cached r) = RunnableInput r
  type RunnableOutput (Cached r) = RunnableOutput r

  invoke (Cached r cacheRef) input = do
    cache <- readMVar cacheRef
    case Map.lookup input cache of
      -- Cache hit: return the stored result.
      Just output -> return (Right output)
      -- Cache miss: compute, and remember the result only if it succeeded.
      Nothing -> do
        result <- invoke r input
        case result of
          Left err -> return (Left err)
          Right output -> do
            -- Force the updated map before putting it back, so the MVar never
            -- holds an unevaluated chain of pending 'Map.insert' calls.
            modifyMVar_ cacheRef $ \c -> return $! Map.insert input output c
            return (Right output)

{- | Wrap a 'Runnable' so that failed invocations are retried.

Whenever the wrapped 'Runnable' returns a 'Left', the wrapper sleeps for
'retryDelay' microseconds and invokes it again with the same input, up to
'maxRetries' additional times. The first 'Right' is returned as soon as it
appears. If every attempt fails, the 'Left' from the last attempt is returned,
and no delay follows that final attempt.

Only 'Left' results trigger a retry. An exception thrown by the wrapped
'Runnable' is not caught here and reaches the caller immediately.

This is useful for network operations or external API calls whose failures are
reported as 'Left' and tend to be transient.

Example:

@
-- Create an LLM with automatic retry for network failures
let
  baseModel = OpenAI defaultConfig
  resilientModel = Retry
    { retryRunnable = baseModel
    , maxRetries = 3
    , retryDelay = 1000000  -- 1 second delay between retries
    }

-- If the API call fails, it will be retried up to 3 more times
result <- invoke resilientModel "Generate a story about a Haskell programmer"
@
-}
data Retry r
  = (Runnable r) =>
  Retry
  { retryRunnable :: r
  -- ^ The wrapped 'Runnable' instance
  , maxRetries :: Int
  -- ^ Maximum number of retries after the first attempt. At most
  -- @maxRetries + 1@ invocations happen in total; a value of 0 or less means
  -- a single attempt.
  , retryDelay :: Int
  -- ^ Delay before each retry, in microseconds
  }

-- | 'Runnable' instance that re-invokes the wrapped 'Runnable' on 'Left'.
instance (Runnable r) => Runnable (Retry r) where
  type RunnableInput (Retry r) = RunnableInput r
  type RunnableOutput (Retry r) = RunnableOutput r

  invoke (Retry inner retryLimit pause) input = attempt 0
    where
      -- 'used' counts the retries already spent; the first call passes 0.
      attempt used = do
        outcome <- invoke inner input
        case outcome of
          Left _
            | used < retryLimit -> threadDelay pause >> attempt (used + 1)
          -- Success, or a failure with no retries left: hand it back as-is.
          _ -> return outcome

{- | Bound the running time of any 'Runnable'.

The wrapped 'Runnable' is run under 'Timeout.timeout'. If it has not finished
within 'timeoutMicroseconds', it is interrupted and
@Left \"Operation timed out\"@ is returned. This is useful for limiting the
execution time of potentially long-running operations.

Example:

@
-- Create an LLM with a 30-second timeout
let
  baseModel = OpenAI defaultConfig
  timeboxedModel = WithTimeout
    { timeoutRunnable = baseModel
    , timeoutMicroseconds = 30000000  -- 30 seconds
    }

-- If the API call takes longer than 30 seconds, it will be cancelled
result <- invoke timeboxedModel "Generate a detailed analysis of Haskell's type system"
@

Notes:

* A non-positive 'timeoutMicroseconds' makes every call time out immediately,
  without running the wrapped 'Runnable'.
* A synchronous exception thrown by the wrapped 'Runnable' is returned as a
  'Left' carrying the exception's 'displayException' text, instead of being
  misreported as a timeout. Asynchronous exceptions aimed at the caller (for
  example 'killThread') are re-thrown.
* Interruption works by throwing an asynchronous exception into the running
  computation. An operation blocked somewhere that cannot be interrupted (such
  as certain foreign calls) may therefore overrun the limit.
-}
data WithTimeout r
  = (Runnable r) =>
  WithTimeout
  { timeoutRunnable :: r
  -- ^ The wrapped 'Runnable' instance
  , timeoutMicroseconds :: Int
  -- ^ Timeout duration in microseconds; non-positive values time out at once
  }

-- | 'Runnable' instance that enforces 'timeoutMicroseconds'.
--
-- The computation runs in the calling thread under 'Timeout.timeout'. The
-- previous implementation forked a worker thread, which had two problems:
-- an exception in the worker looked like a timeout, and the forked threads
-- leaked if the caller was interrupted while waiting.
instance (Runnable r) => Runnable (WithTimeout r) where
  type RunnableInput (WithTimeout r) = RunnableInput r
  type RunnableOutput (WithTimeout r) = RunnableOutput r

  invoke (WithTimeout r micros) input = do
    -- 'Timeout.timeout' treats a negative duration as "no limit". Clamp to 0
    -- so a non-positive setting keeps meaning "no time allowed".
    outcome <-
      tryJust synchronousOnly $
        Timeout.timeout (max 0 micros) (invoke r input)
    return $ case outcome of
      Left ex -> Left (displayException ex)
      Right Nothing -> Left "Operation timed out"
      Right (Just result) -> result
    where
      -- Convert only synchronous exceptions into a 'Left'. Asynchronous ones
      -- (for example 'killThread' aimed at the caller) keep propagating.
      synchronousOnly :: SomeException -> Maybe SomeException
      synchronousOnly ex = case fromException ex of
        Just (SomeAsyncException _) -> Nothing
        Nothing -> Just ex