{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies #-}

{- |
Module      : Langchain.Memory.Core
Description : Memory management for LangChain Haskell
Copyright   : (c) 2025 Tushar Adhatrao
License     : MIT
Maintainer  : Tushar Adhatrao <tusharadhatrao@gmail.com>
Stability   : experimental

Implementation of LangChain's memory management patterns, providing:

- Chat history tracking with size limits
- Message addition/trimming strategies
- Integration with Runnable workflows

Example usage:

@
-- Create memory with 5-message window
memory = WindowBufferMemory 5 (initialChatMessage "You are an assistant")

-- Add a user message; the updated memory comes back wrapped in Either
Right newMemory <- addUserMessage memory "Hello, world!"

-- Retrieve current messages
history <- messages newMemory
-- Right (Message System "You are an assistant" ... :| [Message User "Hello, world!" ...])
@
-}
module Langchain.Memory.Core
  ( BaseMemory (..)
  , WindowBufferMemory (..)
  , trimChatMessage
  , addAndTrim
  , initialChatMessage
  ) where

import qualified Data.List.NonEmpty as NE
import Data.Text (Text)
import Langchain.LLM.Core (ChatMessage, Message (..), Role (..), defaultMessageData)
import Langchain.Runnable.Core

{- | Common interface for chat-history (memory) backends.

Every operation runs in 'IO' and reports failure as a 'Left' carrying a
'String'. Operations that change the history return the updated memory
value.

No method has a default implementation, so an instance should define all
five:

@
instance BaseMemory MyMemory where
  messages = ...
  addUserMessage = ...
  addAiMessage = ...
  addMessage = ...
  clear = ...
@
-}
class BaseMemory m where
  -- | Fetch the chat history currently held by the memory.
  messages
    :: m
    -> IO (Either String ChatMessage)

  -- | Record a message authored by the user.
  addUserMessage
    :: m
    -> Text
    -> IO (Either String m)

  -- | Record a message authored by the AI assistant.
  addAiMessage
    :: m
    -> Text
    -> IO (Either String m)

  -- | Record an arbitrary, already-constructed 'Message'.
  addMessage
    :: m
    -> Message
    -> IO (Either String m)

  -- | Return the memory to its reset state.
  clear
    :: m
    -> IO (Either String m)

{- | Chat memory backed by a sliding window of messages.

'maxWindowSize' bounds how many messages are kept when new ones are added
through the 'BaseMemory' instance. The constructor itself does not enforce
the bound.

Example:

>>> let mem = WindowBufferMemory 2 (NE.singleton (Message System "Sys" defaultMessageData))
>>> addMessage mem (Message User "Hello" defaultMessageData)
Right (WindowBufferMemory {maxWindowSize = 2, ...})
-}
data WindowBufferMemory = WindowBufferMemory
  { maxWindowSize :: Int
  -- ^ Maximum number of messages to retain as new messages are added
  , windowBufferMessages :: ChatMessage
  -- ^ Buffered messages, oldest first
  }
  deriving (Eq, Show)

instance BaseMemory WindowBufferMemory where
  -- Return the buffered messages as-is; this never produces a 'Left'.
  --
  --  >>> messages (WindowBufferMemory 5 initialMessages)
  --  Right initialMessages
  messages WindowBufferMemory {..} = pure $ Right windowBufferMessages

  -- Append a message, then evict messages from the front (oldest first)
  -- until at most @max 1 maxWindowSize@ remain. The oldest message may be
  -- the system prompt, in which case it is evicted too.
  --
  --  >>> let mem = WindowBufferMemory 2 (NE.fromList [msg1])
  --  >>> Right mem' <- addMessage mem msg2   -- buffer: msg1 :| [msg2]
  --  >>> addMessage mem' msg3                -- buffer: msg2 :| [msg3]
  addMessage winBuffMem@WindowBufferMemory {..} msg =
    pure $ Right (winBuffMem {windowBufferMessages = trimmed})
    where
      -- The new message goes at the end, so the buffer stays oldest-first.
      appended = windowBufferMessages <> NE.singleton msg
      -- A NonEmpty buffer cannot hold zero messages, so a non-positive
      -- limit behaves like a limit of 1.
      keep = max 1 maxWindowSize
      -- Drop the whole excess, not just one message. This also brings a
      -- buffer that was constructed longer than the limit back within it.
      trimmed =
        case NE.nonEmpty (NE.drop (NE.length appended - keep) appended) of
          Just kept -> kept
          -- Unreachable because keep >= 1; present only to keep this total.
          Nothing -> NE.singleton msg

  -- Wrap the text as a 'User' message and delegate to 'addMessage'.
  addUserMessage winBuffMem uMsg =
    addMessage winBuffMem (Message User uMsg defaultMessageData)

  -- Wrap the text as an 'Assistant' message and delegate to 'addMessage'.
  addAiMessage winBuffMem uMsg =
    addMessage winBuffMem (Message Assistant uMsg defaultMessageData)

  -- Replace the whole buffer with a single, hard-coded system message.
  -- Any system prompt the memory was created with is NOT preserved.
  -- 'maxWindowSize' is kept unchanged.
  clear winBuffMem =
    pure $
      Right
        winBuffMem
          { windowBufferMessages =
              NE.singleton $ Message System "You are an AI model" defaultMessageData
          }

{- | Keep only the newest @n@ messages of a chat history (oldest first).

When the history already has @n@ or fewer messages, it is returned
unchanged. Because the result is 'NonEmpty', a non-positive @n@ keeps just
the newest message instead of failing.

Example:

>>> let msgs = NE.fromList [msg1, msg2, msg3]
>>> trimChatMessage 2 msgs
msg2 :| [msg3]
-}
trimChatMessage :: Int -> ChatMessage -> ChatMessage
trimChatMessage n msgs =
  case NE.nonEmpty (NE.drop excess msgs) of
    Just kept -> kept
    -- Only reachable when n <= 0, which would otherwise drop everything;
    -- keep the newest message so the result stays non-empty.
    Nothing -> NE.singleton (NE.last msgs)
  where
    -- Number of leading (oldest) messages to discard; NE.drop treats a
    -- negative count as "drop nothing".
    excess = NE.length msgs - n

{- | Append a message to the end of a chat history, then trim the result to
the newest @n@ messages using 'trimChatMessage'. See 'trimChatMessage' for
how a non-positive @n@ is treated.

Example:

>>> let msgs = NE.fromList [msg1]
>>> addAndTrim 2 msg2 msgs
msg1 :| [msg2]
-}
addAndTrim :: Int -> Message -> ChatMessage -> ChatMessage
addAndTrim n msg msgs = trimChatMessage n extended
  where
    -- 'pure' builds a one-element NonEmpty, so this appends msg at the end.
    extended = msgs <> pure msg

{- | Build a chat history that consists of a single system prompt.

Example: @initialChatMessage "You are Qwen"@ yields a one-element history
holding @Message System "You are Qwen" defaultMessageData@.
-}
initialChatMessage :: Text -> ChatMessage
initialChatMessage systemPrompt = pure systemMessage
  where
    systemMessage = Message System systemPrompt defaultMessageData

instance Runnable WindowBufferMemory where
  type RunnableInput WindowBufferMemory = Text
  type RunnableOutput WindowBufferMemory = WindowBufferMemory

  -- Running the memory as a Runnable records the input text as a user
  -- message (via 'addUserMessage') and yields the updated memory.
  --
  --  >>> invoke memory "Hello"
  --  Right (WindowBufferMemory { ... })
  invoke = addUserMessage

{- $examples
Test case patterns:
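1. Message trimming (the oldest message is evicted once the window is full)
   >>> let mem = WindowBufferMemory 2 (NE.fromList [msg1, msg2])
   >>> addMessage mem msg3
   Right (WindowBufferMemory {maxWindowSize = 2, windowBufferMessages = msg2 :| [msg3]})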

2. Initial state
   >>> messages (WindowBufferMemory 5 initialMessages)
   Right initialMessages

3. Runnable integration
   >>> invoke (WindowBufferMemory 5 initialMessages) "Hello"
   Right (WindowBufferMemory { ... })
-}