{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
-- | A retriever that asks an LLM to rephrase the user's query into several
-- alternative queries, runs each of them against an underlying retriever and
-- merges the de-duplicated results.
--
-- 'MultiQueryRetrieverConfig' is exported together with its fields so callers
-- of 'newMultiQueryRetrieverWithConfig' can build or adjust a configuration
-- (e.g. @defaultMultiQueryRetrieverConfig { numQueries = 5 }@).
module Langchain.Retriever.MultiQueryRetriever
  ( MultiQueryRetriever (..)
  , MultiQueryRetrieverConfig (..)
  , QueryGenerationPrompt (..)
  , newMultiQueryRetriever
  , defaultQueryGenerationPrompt
  , newMultiQueryRetrieverWithConfig
  , defaultMultiQueryRetrieverConfig
  , generateQueries
  ) where
import Data.Either (rights)
import Data.List (intercalate, nub)
import qualified Data.Map.Strict as HM
import Data.Text (Text)
import qualified Data.Text as T

import Langchain.DocumentLoader.Core (Document)
import Langchain.LLM.Core (LLM (..))
import Langchain.OutputParser.Core (NumberSeparatedList (..), OutputParser (..))
import Langchain.PromptTemplate (PromptTemplate (..), renderPrompt)
import Langchain.Retriever.Core (Retriever (..))
import qualified Langchain.Runnable.Core as Run
-- | Prompt used to ask the LLM for alternative search queries.
--
-- The wrapped 'PromptTemplate' is rendered by 'generateQueries' with two
-- variables: @query@ (the user's original query) and @num_queries@.
newtype QueryGenerationPrompt = QueryGenerationPrompt PromptTemplate
  deriving (Eq, Show)
-- | Built-in query-generation prompt.
--
-- It references the template variables @{query}@ and @{num_queries}@ and asks
-- for a numbered list, which is what 'generateQueries' parses the response as.
-- The lines are joined with 'T.unlines', so each one ends with a newline.
defaultQueryGenerationPrompt :: QueryGenerationPrompt
defaultQueryGenerationPrompt =
  QueryGenerationPrompt PromptTemplate {templateString = T.unlines promptLines}
  where
    promptLines :: [Text]
    promptLines =
      [ "You are an AI language model assistant that helps users by generating multiple search queries based on their initial query."
      , "These queries should help retrieve relevant documents or information from a vector database."
      , ""
      , "Original query: {query}"
      , ""
      , "Please generate {num_queries} different versions of this query that will help the user find the most relevant information."
      , "The queries should be different but related to the original query."
      , "Return these queries in the following format: 1. query 1 \n 2. query 2 \n 3. query 3"
      , "Only return queries and nothing else"
      ]
-- | Settings controlling how a 'MultiQueryRetriever' generates and runs its
-- queries.
data MultiQueryRetrieverConfig = MultiQueryRetrieverConfig
  { numQueries :: Int
  -- ^ Number of alternative queries requested from the LLM; substituted into
  -- the prompt as @{num_queries}@. The LLM is not forced to honour it.
  , queryGenerationPrompt :: QueryGenerationPrompt
  -- ^ Prompt used to generate the alternative queries.
  , includeMergeDocs :: Bool
  -- ^ NOTE(review): this flag is not read anywhere in this module; results
  -- are always merged and de-duplicated. Confirm the intended semantics.
  , includeOriginalQuery :: Bool
  -- ^ When 'True', the original query is run against the base retriever in
  -- addition to the generated ones.
  }
-- | Default settings: request 3 alternative queries using
-- 'defaultQueryGenerationPrompt', also run the original query, and keep
-- 'includeMergeDocs' set to 'True'.
defaultMultiQueryRetrieverConfig :: MultiQueryRetrieverConfig
defaultMultiQueryRetrieverConfig =
  MultiQueryRetrieverConfig
    { queryGenerationPrompt = defaultQueryGenerationPrompt
    , numQueries = 3
    , includeOriginalQuery = True
    , includeMergeDocs = True
    }
-- | Retriever that expands one query into several LLM-generated variants,
-- runs each variant against 'retriever' and merges the results.
--
-- The @(Retriever a, LLM m)@ requirements live on the constructors functions
-- and instances below rather than on this declaration. A datatype context
-- (@data (C a) => T a@) is deprecated, requires the @DatatypeContexts@
-- extension, and only forces the constraint onto every construction and
-- pattern match without making the instance available to the code.
data MultiQueryRetriever a m = MultiQueryRetriever
  { retriever :: a
  -- ^ Underlying retriever that every generated query is run against.
  , llm :: m
  -- ^ Model used to generate the alternative queries.
  , config :: MultiQueryRetrieverConfig
  -- ^ Query-generation settings.
  }
-- | Build a 'MultiQueryRetriever' from a base retriever and an LLM, using
-- 'defaultMultiQueryRetrieverConfig'.
newMultiQueryRetriever :: (Retriever a, LLM m) => a -> m -> MultiQueryRetriever a m
newMultiQueryRetriever baseRetriever model =
  MultiQueryRetriever baseRetriever model defaultMultiQueryRetrieverConfig
-- | Build a 'MultiQueryRetriever' from a base retriever, an LLM and an
-- explicit 'MultiQueryRetrieverConfig'.
newMultiQueryRetrieverWithConfig ::
  (Retriever a, LLM m) =>
  a ->
  m ->
  MultiQueryRetrieverConfig ->
  MultiQueryRetriever a m
newMultiQueryRetrieverWithConfig baseRetriever model cfg =
  MultiQueryRetriever
    { config = cfg
    , llm = model
    , retriever = baseRetriever
    }
-- | Ask the LLM for alternative phrasings of a query.
--
-- The prompt template is rendered with @{query}@ set to the original query
-- and @{num_queries}@ set to @n@. The LLM response is parsed as a
-- 'NumberSeparatedList'. The generated queries are cleaned before they are
-- returned:
--
--   * surrounding whitespace is stripped;
--   * empty entries are dropped;
--   * duplicates are removed, keeping the first occurrence.
--
-- When @includeOriginal@ is 'True', the original query comes first. De-duplication
-- runs after it is prepended, so an LLM echo of the original is not returned
-- (and later retrieved) a second time.
--
-- Errors from 'renderPrompt' and from the LLM are returned unchanged. Parse
-- failures are prefixed with @"Failed to parse LLM response: "@.
generateQueries ::
  LLM m =>
  m ->
  -- ^ model used to generate the alternative queries
  QueryGenerationPrompt ->
  -- ^ template referencing @{query}@ and @{num_queries}@
  Text ->
  -- ^ the user's original query
  Int ->
  -- ^ number of alternative queries to request
  Bool ->
  -- ^ whether to put the original query first in the result
  IO (Either String [Text])
generateQueries model (QueryGenerationPrompt promptTemplate) query n includeOriginal =
  case renderPrompt promptTemplate vars of
    Left err -> return (Left err)
    Right prompt -> do
      result <- generate model prompt Nothing
      return $ case result of
        Left err -> Left err
        Right response ->
          case parse response of
            Left err -> Left ("Failed to parse LLM response: " ++ err)
            Right (NumberSeparatedList generated) -> Right (finalize generated)
  where
    vars :: HM.Map Text Text
    vars = HM.fromList [("query", query), ("num_queries", T.pack (show n))]

    -- Clean the generated queries, then de-duplicate *after* prepending the
    -- original so an echoed copy of it is dropped.
    finalize :: [Text] -> [Text]
    finalize generated =
      let cleaned = filter (not . T.null) (map T.strip generated)
       in nub (if includeOriginal then query : cleaned else cleaned)
-- | Flatten the per-query result lists into a single list.
--
-- Repeated documents (by 'Eq') are dropped. The first occurrence of each
-- document is kept, so the order of the input lists is preserved.
combineDocuments :: [[Document]] -> [Document]
combineDocuments = nub . concat
-- | Retrieval works in three steps:
--
--   1. Generate query variants with 'generateQueries'.
--   2. Run every variant against the base retriever, in sequence.
--   3. Merge the successful results with 'combineDocuments'.
--
-- A variant whose retrieval fails does not abort the others. If no variant
-- succeeds, the result is 'Left', and the message lists the individual
-- failures.
instance (Retriever a, LLM m) => Retriever (MultiQueryRetriever a m) where
  _get_relevant_documents r query = do
    queriesResult <-
      generateQueries
        (llm r)
        (queryGenerationPrompt cfg)
        query
        (numQueries cfg)
        (includeOriginalQuery cfg)
    case queriesResult of
      Left err -> return (Left ("Error generating queries: " ++ err))
      Right queries -> do
        results <- mapM (_get_relevant_documents (retriever r)) queries
        -- Partial failures are tolerated (best effort). Their messages are
        -- kept only to explain a total failure.
        let failures = [err | Left err <- results]
            successes = rights results
        -- NOTE(review): 'includeMergeDocs' is not consulted here; results are
        -- always merged via 'combineDocuments'. Confirm intended semantics.
        return $
          if null successes
            then Left (noResults failures)
            else Right (combineDocuments successes)
    where
      cfg = config r
      -- With no generated queries there are no errors to report, so the
      -- plain message is used.
      noResults [] = "No valid results from any query"
      noResults errs = "No valid results from any query: " ++ intercalate "; " errs
-- | Running a 'MultiQueryRetriever' as a 'Run.Runnable' takes a 'Text' query
-- and produces the documents from its 'Retriever' instance.
instance (Retriever a, LLM m) => Run.Runnable (MultiQueryRetriever a m) where
  type RunnableInput (MultiQueryRetriever a m) = Text
  type RunnableOutput (MultiQueryRetriever a m) = [Document]

  -- Delegate directly to the Retriever instance.
  invoke = _get_relevant_documents