{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

{- |
Module      : Langchain.Retriever.MultiQueryRetriever
Description : Multi-query retrieval implementation for LangChain Haskell
Copyright   : (c) 2025 Tushar Adhatrao
License     : MIT
Maintainer  : Tushar Adhatrao <tusharadhatrao@gmail.com>
Stability   : experimental

Advanced retriever implementation that generates multiple queries from a single
input to improve document retrieval. Integrates with LLMs for query expansion
and vector stores for document retrieval.

Example usage:

@
-- Create components
ollamaLLM = Ollama "llama3" []
vs = VectorStoreRetriever (createVectorStore ...)

-- Create retriever with default config
mqRetriever = newMultiQueryRetriever vs ollamaLLM

-- Retrieve documents
docs <- _get_relevant_documents mqRetriever "Haskell features"
-- Returns combined results from multiple generated queries
@
-}
module Langchain.Retriever.MultiQueryRetriever
  ( -- * Types
    MultiQueryRetriever (..)
  , MultiQueryRetrieverConfig (..)
  , QueryGenerationPrompt (..)

    -- * Construction and defaults
  , newMultiQueryRetriever
  , newMultiQueryRetrieverWithConfig
  , defaultMultiQueryRetrieverConfig
  , defaultQueryGenerationPrompt

    -- * Query generation
  , generateQueries
  ) where

import Langchain.DocumentLoader.Core (Document)
import Langchain.LLM.Core (LLM (..))
import Langchain.OutputParser.Core (NumberSeparatedList (..), OutputParser (..))
import Langchain.PromptTemplate (PromptTemplate (..), renderPrompt)
import Langchain.Retriever.Core (Retriever (..))
import qualified Langchain.Runnable.Core as Run

import Data.Either (rights)
import Data.List (nub)
import qualified Data.Map.Strict as HM
import Data.Text (Text)
import qualified Data.Text as T

{- | Prompt template used to ask the LLM for alternative phrasings of a query.

'generateQueries' renders the wrapped 'PromptTemplate' with two variables:

* @{query}@ - the original user query
* @{num_queries}@ - the requested number of variants, rendered with 'show'

Example prompt structure:

@
"You are an AI assistant... Original query: {query}... Generate {num_queries} versions..."
@
-}
newtype QueryGenerationPrompt = QueryGenerationPrompt PromptTemplate
  deriving (Show, Eq)

{- | Default query generation prompt.

Asks the model for @{num_queries}@ related-but-different rewrites of @{query}@
and requests a numbered-list answer containing only the queries.

NOTE(review): the format example embedded in the prompt always shows three
items, independent of the value substituted for @{num_queries}@.
-}
defaultQueryGenerationPrompt :: QueryGenerationPrompt
defaultQueryGenerationPrompt =
  QueryGenerationPrompt
    PromptTemplate
      { templateString = T.unlines promptLines
      }
  where
    -- One entry per prompt line; 'T.unlines' terminates each with a newline.
    promptLines =
      [ "You are an AI language model assistant that helps users by generating multiple search queries based on their initial query."
      , "These queries should help retrieve relevant documents or information from a vector database."
      , ""
      , "Original query: {query}"
      , ""
      , "Please generate {num_queries} different versions of this query that will help the user find the most relevant information."
      , "The queries should be different but related to the original query."
      , "Return these queries in the following format: 1. query 1 \n 2. query 2 \n 3. query 3"
      , "Only return queries and nothing else"
      ]

{- | Settings that control how a 'MultiQueryRetriever' expands a query.
-}
data MultiQueryRetrieverConfig = MultiQueryRetrieverConfig
  { numQueries :: Int
  -- ^ Number of query variants requested from the LLM (substituted into the
  -- prompt as @{num_queries}@)
  , queryGenerationPrompt :: QueryGenerationPrompt
  -- ^ Prompt template used for query generation
  , includeMergeDocs :: Bool
  -- ^ Whether to include merged documents.
  -- NOTE(review): this field is not read anywhere in this module.
  , includeOriginalQuery :: Bool
  -- ^ Whether the original query is also sent to the base retriever
  }

{- | Default configuration:

* 3 generated queries
* 'defaultQueryGenerationPrompt' as the prompt
* merged documents enabled
* original query included in the retrieval
-}
defaultMultiQueryRetrieverConfig :: MultiQueryRetrieverConfig
defaultMultiQueryRetrieverConfig =
  MultiQueryRetrieverConfig
    { numQueries = 3
    , queryGenerationPrompt = defaultQueryGenerationPrompt
    , includeMergeDocs = True
    , includeOriginalQuery = True
    }

{- | Retriever that fans a single query out into several LLM-generated
variants, runs each variant through a base retriever, and merges the results.

The type parameters are left unconstrained on purpose: datatype contexts are
deprecated in GHC, and the @(Retriever a, LLM m)@ requirements are already
stated on the constructors functions and instances that need them.

Example:

@
mqRetriever = MultiQueryRetriever
  { retriever = vectorStoreRetriever
  , llm = ollamaLLM
  , config = defaultMultiQueryRetrieverConfig
  }
@
-}
data MultiQueryRetriever a m = MultiQueryRetriever
  { retriever :: a
  -- ^ The base retriever each generated query is sent to
  , llm :: m
  -- ^ The language model used to generate query variants
  , config :: MultiQueryRetrieverConfig
  -- ^ Query-expansion settings
  }

{- | Build a 'MultiQueryRetriever' from a base retriever and an LLM, using
'defaultMultiQueryRetrieverConfig'.

@
mqRetriever = newMultiQueryRetriever vsRetriever ollamaLLM
@
-}
newMultiQueryRetriever :: (Retriever a, LLM m) => a -> m -> MultiQueryRetriever a m
newMultiQueryRetriever r l =
  newMultiQueryRetrieverWithConfig r l defaultMultiQueryRetrieverConfig

{- | Build a 'MultiQueryRetriever' from a base retriever, an LLM, and an
explicit configuration.

@
customCfg   = defaultMultiQueryRetrieverConfig { numQueries = 5 }
mqRetriever = newMultiQueryRetrieverWithConfig vsRetriever ollamaLLM customCfg
@
-}
newMultiQueryRetrieverWithConfig ::
  (Retriever a, LLM m) =>
  a ->
  m ->
  MultiQueryRetrieverConfig ->
  MultiQueryRetriever a m
newMultiQueryRetrieverWithConfig r l c =
  MultiQueryRetriever
    { retriever = r
    , llm = l
    , config = c
    }

{- | Ask the LLM for alternative phrasings of a query.

Arguments, in order:

* the model used for generation
* the prompt template; rendered with @{query}@ and @{num_queries}@
* the original query
* the number of variants to request (only passed to the prompt; the size of
  the returned list is whatever the model's answer parses to)
* whether the original query is placed at the head of the result

The model's answer is parsed as a 'NumberSeparatedList'. Empty entries are
dropped and duplicates are removed, keeping the first occurrence. When the
original query is included it always comes first, and a generated variant
that merely repeats it is not listed a second time.

Returns @Left@ when the prompt cannot be rendered, when the model call fails,
or when the answer cannot be parsed (the latter prefixed with
@Failed to parse LLM response: @).

@
generateQueries ollamaLLM prompt "Haskell" 3 True
-- Right ["Haskell", "Haskell features", "Haskell applications"]
@
-}
generateQueries ::
  LLM m => m -> QueryGenerationPrompt -> Text -> Int -> Bool -> IO (Either String [Text])
generateQueries model (QueryGenerationPrompt promptTemplate) query n includeOriginal =
  case renderPrompt promptTemplate vars of
    Left err -> pure (Left err)
    Right prompt -> do
      result <- generate model prompt Nothing
      -- A failed generation short-circuits; a successful one is parsed.
      pure (result >>= parseVariants)
  where
    -- Variables substituted into the prompt template.
    vars :: HM.Map Text Text
    vars = HM.fromList [("query", query), ("num_queries", T.pack (show n))]

    -- Turn the raw model answer into the final list of queries.
    parseVariants :: Text -> Either String [Text]
    parseVariants response =
      case parse response :: Either String NumberSeparatedList of
        Left err -> Left ("Failed to parse LLM response: " ++ err)
        Right (NumberSeparatedList variants) -> Right (assemble variants)

    -- Drop empty entries and de-duplicate. The original query is prepended
    -- *before* de-duplication so that a variant equal to it is not retrieved
    -- twice.
    assemble :: [Text] -> [Text]
    assemble variants
      | includeOriginal = nub (query : nonEmpty)
      | otherwise = nub nonEmpty
      where
        nonEmpty = filter (not . T.null) variants

{- | Merge the per-query result lists into one list.

The lists are concatenated in order and duplicates (by 'Eq') are removed,
keeping the first occurrence of each document.

This is a simplified approach: 'nub' is quadratic in the number of documents,
and no ranking is applied. A production system would want a more sophisticated
way to identify and rank duplicate documents.
-}
combineDocuments :: [[Document]] -> [Document]
combineDocuments = nub . concat

{- | Retriever instance implementation

1. Generates query variants with 'generateQueries', using the settings in 'config'
2. Runs the base retriever once per query
3. Combines and de-duplicates the successful results

Failure behaviour:

* If query generation fails, the result is @Left@ with the message prefixed by
  @Error generating queries: @.
* Individual retrieval failures are tolerated as long as at least one query
  succeeds.
* If no query succeeds, the result is @Left@ starting with
  @No valid results from any query@, followed by the individual retrieval
  errors (if any) so the cause is not lost.

@
_get_relevant_documents mqRetriever "Haskell"
-- Right [Document "Haskell is...", Document "Functional programming...", ...]
@
-}
instance (Retriever a, LLM m) => Retriever (MultiQueryRetriever a m) where
  _get_relevant_documents r query = do
    queriesResult <-
      generateQueries
        (llm r)
        (queryGenerationPrompt cfg)
        query
        (numQueries cfg)
        (includeOriginalQuery cfg)
    case queriesResult of
      Left err -> pure (Left ("Error generating queries: " ++ err))
      Right queries -> do
        -- One base-retriever call per query, in order.
        results <- mapM (_get_relevant_documents (retriever r)) queries
        pure $ case rights results of
          [] -> Left (noResultsMessage [e | Left e <- results])
          docLists -> Right (combineDocuments docLists)
    where
      cfg = config r

      -- Error text for the "nothing succeeded" case; appends the collected
      -- retrieval errors, separated by "; ", when there are any.
      noResultsMessage :: [String] -> String
      noResultsMessage [] = "No valid results from any query"
      noResultsMessage (e : es) =
        "No valid results from any query: " ++ e ++ concatMap ("; " ++) es

{-
 ghci> :set -XOverloadedStrings
 ghci> let ollamaEmbed = OllamaEmbeddings "nomic-embed-text:latest" Nothing Nothing
 ghci> let vs = emptyInMemoryVectorStore ollamaEmbed
 ghci> import Data.Map (empty)
 ghci> import Data.Either
 ghci> newVs <- addDocuments vs [Document "Tushar is 25 years old." empty]
 ghci> let newVs_ = fromRight vs newVs
 ghci> let vRet = VectorStoreRetriever newVs_
 ghci> let ollamLLM = Ollama "llama3.2" []
 ghci> let mqRet = newMultiQueryRetriever vRet ollamLLM
 ghci> documents <- _get_relevant_documents mqRet "How old is Tushar?"
 ghci> documents
    Right [Document {pageContent = "Tushar is 25 years old.", metadata = fromList []}]
 -}

{- | Runnable interface implementation.

Takes a 'Text' query as input and yields a list of 'Document's; 'Run.invoke'
simply forwards to '_get_relevant_documents'.

@
invoke mqRetriever "AI applications"
-- Right [Document "Machine learning...", ...]
@
-}
instance (Retriever a, LLM m) => Run.Runnable (MultiQueryRetriever a m) where
  type RunnableInput (MultiQueryRetriever a m) = Text
  type RunnableOutput (MultiQueryRetriever a m) = [Document]

  invoke = _get_relevant_documents

{- $examples
Test case patterns:
1. Query generation
   >>> generateQueries ollamaLLM prompt "Test" 2 False
   Right ["Test case", "Test example"]

2. Full retrieval flow
   >>> _get_relevant_documents mqRetriever "Haskell"
   Right [Document "Functional...", Document "Type system..."]

3. Configuration variants
   >>> let cfg = defaultMultiQueryRetrieverConfig { numQueries = 5 }
   >>> newMultiQueryRetrieverWithConfig vsRetriever ollamaLLM cfg
   MultiQueryRetriever {numQueries = 5, ...}
-}