{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Langchain.Runnable.Utils
(
WithConfig (..)
, Cached (..)
, cached
, Retry (..)
, WithTimeout (..)
) where
import Control.Concurrent
import Data.Map.Strict as Map
import Langchain.Error (llmError)
import Langchain.Runnable.Core
import System.Timeout (timeout)
-- | Pairs a runnable with an arbitrary configuration value.
--
-- The 'Runnable' instance for this wrapper forwards 'invoke' to
-- 'configuredRunnable' and does not inspect 'runnableConfig'.
data WithConfig config r
  = (Runnable r) =>
  WithConfig
  { configuredRunnable :: r
  -- ^ The runnable being wrapped.
  , runnableConfig :: config
  -- ^ Configuration value carried alongside the runnable.
  }
-- | Invoking a 'WithConfig' delegates directly to the wrapped runnable.
-- Input and output types are those of the wrapped runnable; the stored
-- configuration is ignored by 'invoke'.
instance (Runnable r) => Runnable (WithConfig config r) where
  type RunnableInput (WithConfig config r) = RunnableInput r
  type RunnableOutput (WithConfig config r) = RunnableOutput r

  invoke (WithConfig r1 _) = invoke r1
-- | A runnable paired with a mutable cache of its outputs, keyed by input.
--
-- Use 'cached' to build one; it allocates the (initially empty) cache.
data Cached r
  = (Runnable r, Ord (RunnableInput r)) =>
  Cached
  { cachedRunnable :: r
  -- ^ The runnable whose results are cached.
  , cacheMap :: MVar (Map.Map (RunnableInput r) (RunnableOutput r))
  -- ^ Shared map from inputs to previously computed outputs.
  }
-- | Wrap a runnable in a 'Cached' backed by a freshly allocated, empty cache.
cached :: (Runnable r, Ord (RunnableInput r)) => r -> IO (Cached r)
cached r = Cached r <$> newMVar Map.empty
-- | Invoking a 'Cached' first looks the input up in the cache.
--
-- * On a hit, the stored output is returned without calling the wrapped
--   runnable.
-- * On a miss, the wrapped runnable is invoked. A 'Right' result is inserted
--   into the cache before being returned; a 'Left' result is returned as is
--   and is not cached.
--
-- The lookup and the later insert are separate 'MVar' operations, so two
-- concurrent calls with the same uncached input may both invoke the wrapped
-- runnable.
instance (Runnable r, Ord (RunnableInput r)) => Runnable (Cached r) where
  type RunnableInput (Cached r) = RunnableInput r
  type RunnableOutput (Cached r) = RunnableOutput r

  invoke (Cached r cacheRef) input = do
    cache <- readMVar cacheRef
    case Map.lookup input cache of
      Just output -> return $ Right output
      Nothing -> do
        result <- invoke r input
        case result of
          Left _ -> return ()
          Right output ->
            -- Force the updated map before storing it so the MVar never
            -- accumulates a chain of unevaluated inserts.
            modifyMVar_ cacheRef $ \c -> return $! Map.insert input output c
        return result
-- | A runnable that is re-invoked when it fails.
data Retry r
  = (Runnable r) =>
  Retry
  { retryRunnable :: r
  -- ^ The runnable to invoke.
  , maxRetries :: Int
  -- ^ Maximum number of re-attempts after the first failed invocation.
  , retryDelay :: Int
  -- ^ Pause between attempts, passed to 'threadDelay' (microseconds).
  }
-- | Invoking a 'Retry' calls the wrapped runnable and, while it returns
-- 'Left', waits for the configured delay and tries again, up to the
-- configured number of retries. The first 'Right' result is returned
-- immediately; once the retries are exhausted the last 'Left' is returned.
instance (Runnable r) => Runnable (Retry r) where
  type RunnableInput (Retry r) = RunnableInput r
  type RunnableOutput (Retry r) = RunnableOutput r

  invoke (Retry r maxRetries_ delay) input = retryWithCount 0
    where
      -- @count@ is the number of retries already performed.
      retryWithCount count = do
        result <- invoke r input
        case result of
          Left _
            | count < maxRetries_ -> do
                threadDelay delay
                retryWithCount (count + 1)
          -- Success, or a failure with no retries left: hand it back as is.
          _ -> return result
-- | A runnable whose invocation is bounded by a time limit.
data WithTimeout r
  = (Runnable r) =>
  WithTimeout
  { timeoutRunnable :: r
  -- ^ The runnable to invoke.
  , timeoutMicroseconds :: Int
  -- ^ Time limit for a single invocation, in microseconds.
  }
-- | Invoking a 'WithTimeout' runs the wrapped runnable under a time limit.
--
-- If the wrapped runnable finishes in time its result is returned unchanged.
-- Otherwise it is interrupted and a 'Left' carrying an
-- @"Operation timed out"@ error is returned.
--
-- The limit is enforced with 'System.Timeout.timeout', which cleans up after
-- itself even when the caller is interrupted. An exception raised by the
-- wrapped runnable propagates to the caller instead of being reported as a
-- timeout. A negative limit is treated as zero, i.e. an immediate timeout.
instance (Runnable r) => Runnable (WithTimeout r) where
  type RunnableInput (WithTimeout r) = RunnableInput r
  type RunnableOutput (WithTimeout r) = RunnableOutput r

  invoke (WithTimeout r micros) input = do
    -- 'timeout' interprets a negative duration as "no limit"; clamp so that
    -- a non-positive limit always means "time out immediately".
    outcome <- timeout (max 0 micros) (invoke r input)
    return $ case outcome of
      Just result -> result
      Nothing -> Left (llmError "Operation timed out" Nothing Nothing)