{-# LANGUAGE TypeFamilies #-}

{- |
Module      : Langchain.Runnable.Core
Description : Core Interface of Runnable. Necessary for LangChain Expression Language (LCEL)
Copyright   : (c) 2025 Tushar Adhatrao
License     : MIT
Maintainer  : Tushar Adhatrao <tusharadhatrao@gmail.com>

This module defines the 'Runnable' typeclass, which is the fundamental abstraction in the
Haskell implementation of LangChain Expression Language (LCEL). A 'Runnable' represents any
component that can process an input and produce an output, potentially with side effects.

The 'Runnable' abstraction enables composition of various LLM-related components into
processing pipelines, including:

* Language Models
* Prompt Templates
* Document Retrievers
* Text Splitters
* Embedders
* Vector Stores
* Output Parsers

By implementing the 'Runnable' typeclass, components can be combined using the combinators
provided in "Langchain.Runnable.Chain".
-}
module Langchain.Runnable.Core
  ( Runnable (..)
  ) where

import Control.Monad.IO.Class (MonadIO, liftIO)
import Langchain.Error (LangchainResult)

{- | The core 'Runnable' typeclass represents anything that can "run" with an input and produce an output.

This typeclass is the foundation of the LangChain Expression Language (LCEL) in Haskell,
allowing different components to be composed into processing pipelines.

To implement a 'Runnable', you must:

1. Define the input and output types using the associated type families
   'RunnableInput' and 'RunnableOutput'
2. Implement the 'invoke' method

Every other method has a default implementation built on 'invoke' (the
'MonadIO' variants are built on 'invoke', 'batch' and 'stream'), so
overriding 'batch' or 'stream' is optional and only worthwhile for
component-specific optimizations.

Example implementation:

@
data TextSplitter = TextSplitter { chunkSize :: Int, overlap :: Int }

instance Runnable TextSplitter where
  type RunnableInput TextSplitter = String
  type RunnableOutput TextSplitter = [String]

  invoke splitter text = do
    -- Implementation of text splitting logic
    let chunks = splitTextIntoChunks (chunkSize splitter) (overlap splitter) text
    return $ Right chunks
@
-}
class Runnable r where
  {- | The type of input the runnable accepts.

  For example, an LLM might accept 'String' or 'PromptValue' as input.
  -}
  type RunnableInput r

  {- | The type of output the runnable produces.

  For example, an LLM might produce 'String' or 'LLMResult' as output.
  -}
  type RunnableOutput r

  {- | Core method to invoke (run) this component with a single input.

  This is the only method without a default, so every 'Runnable' instance
  must implement it. It processes a single input and returns either an
  error ('Left') or the output ('Right').

  Example usage:

  @
  let llm = OpenAI { temperature = 0.7, model = "gpt-3.5-turbo" }
  result <- invoke llm "Explain monads in simple terms."
  case result of
    Left err -> putStrLn $ "Error: " ++ show err
    Right response -> putStrLn response
  @
  -}
  invoke :: r -> RunnableInput r -> IO (LangchainResult (RunnableOutput r))

  -- | 'invoke' lifted into any 'MonadIO' monad, for callers not working in plain 'IO'.
  invokeM :: MonadIO m => r -> RunnableInput r -> m (LangchainResult (RunnableOutput r))
  invokeM runnable input = liftIO $ invoke runnable input

  {- | Run the component over a list of inputs.

  The default implementation calls 'invoke' on each input sequentially, in
  list order. Every input is invoked even if an earlier call fails. The
  result is 'Right' with all outputs (in input order) when every call
  succeeds; otherwise it is the first 'Left' in list order.
  -}
  batch :: r -> [RunnableInput r] -> IO (LangchainResult [RunnableOutput r])
  -- mapM runs every invocation; sequence then collapses [Either e a] into
  -- Either e [a], keeping the first Left.
  batch runnable inputs = sequence <$> mapM (invoke runnable) inputs

  -- | 'batch' lifted into any 'MonadIO' monad.
  batchM :: MonadIO m => r -> [RunnableInput r] -> m (LangchainResult [RunnableOutput r])
  batchM runnable inputs = liftIO $ batch runnable inputs

  {- | Run the component on a single input, handing output to a callback.

  The default implementation does not stream incrementally. It calls
  'invoke' and, on success, passes the complete output to the callback
  exactly once before returning @Right ()@. On failure the callback is
  never called and the error is returned unchanged.
  -}
  stream :: r -> RunnableInput r -> (RunnableOutput r -> IO ()) -> IO (LangchainResult ())
  stream runnable input callback = do
    result <- invoke runnable input
    case result of
      Left err -> return $ Left err
      Right output -> do
        callback output
        return $ Right ()

  -- | 'stream' lifted into any 'MonadIO' monad. The callback itself still runs in 'IO'.
  streamM ::
    MonadIO m => r -> RunnableInput r -> (RunnableOutput r -> IO ()) -> m (LangchainResult ())
  streamM runnable input callback = liftIO $ stream runnable input callback