{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies #-}
module Langchain.Memory.Core
( BaseMemory (..)
, WindowBufferMemory (..)
, trimChatMessage
, addAndTrim
, initialChatMessage
) where
import qualified Data.List.NonEmpty as NE
import Data.Text (Text)
import Langchain.LLM.Core (ChatMessage, Message (..), Role (..), defaultMessageData)
import Langchain.Runnable.Core
-- | Interface for conversation-memory backends.
--
-- Every operation runs in 'IO' and returns an 'Either' whose 'Left' side is a
-- 'String'. Operations that modify the memory return the updated value
-- instead of mutating their argument.
class BaseMemory m where
  -- | Retrieve the conversation currently held by the memory.
  messages :: m -> IO (Either String ChatMessage)

  -- | Record a piece of text on behalf of the user.
  addUserMessage :: m -> Text -> IO (Either String m)

  -- | Record a piece of text on behalf of the AI.
  addAiMessage :: m -> Text -> IO (Either String m)

  -- | Record an already constructed 'Message'.
  addMessage :: m -> Message -> IO (Either String m)

  -- | Produce a cleared version of the memory.
  clear :: m -> IO (Either String m)
-- | Conversation memory backed by a bounded buffer of messages.
data WindowBufferMemory = WindowBufferMemory
  { -- | Size threshold of the buffer: once the stored conversation has at
    -- least this many messages, adding another one drops the oldest.
    maxWindowSize :: Int
  , -- | The messages currently stored, oldest first.
    windowBufferMessages :: ChatMessage
  }
  deriving (Show, Eq)
-- | Sliding-window behaviour for 'WindowBufferMemory'. None of these
-- operations ever produces a 'Left'.
instance BaseMemory WindowBufferMemory where
  -- Hand back the stored conversation untouched.
  messages (WindowBufferMemory _ history) = pure (Right history)

  -- Append the message; when the buffer has already reached the window
  -- size, the oldest entry is dropped to make room for it.
  addMessage memory@(WindowBufferMemory limit history) newMsg =
    pure (Right (memory {windowBufferMessages = updated}))
    where
      updated
        -- 'NE.fromList' is safe here: the list always ends with 'newMsg'.
        | NE.length history >= limit = NE.fromList (NE.tail history ++ [newMsg])
        | otherwise = history <> NE.singleton newMsg

  -- Wrap the text as a 'User' message and store it.
  addUserMessage memory content =
    addMessage memory (Message User content defaultMessageData)

  -- Wrap the text as an 'Assistant' message and store it.
  addAiMessage memory content =
    addMessage memory (Message Assistant content defaultMessageData)

  -- Reset the conversation to a single default system prompt; the window
  -- size is kept.
  clear memory =
    pure (Right (memory {windowBufferMessages = initialChatMessage "You are an AI model"}))
-- | Keep only the last @n@ messages of a conversation.
--
-- A conversation holding @n@ or fewer messages is returned unchanged.
-- Because a 'ChatMessage' can never be empty, a non-positive @n@ keeps the
-- single most recent message instead of failing (the previous implementation
-- called 'NE.fromList' on an empty list in that case and crashed).
trimChatMessage :: Int -> ChatMessage -> ChatMessage
trimChatMessage n msgs =
  case drop excess (NE.toList msgs) of
    (m : rest) -> m NE.:| rest
    -- Only reachable when n <= 0: everything was dropped, so fall back to
    -- the newest message to keep the result non-empty.
    [] -> NE.singleton (NE.last msgs)
  where
    -- Number of leading (oldest) messages that do not fit in the window.
    excess = max 0 (NE.length msgs - n)
-- | Append a message to the end of a conversation, then cut the result down
-- with 'trimChatMessage' using the given size @n@.
addAndTrim :: Int -> Message -> ChatMessage -> ChatMessage
addAndTrim n msg msgs = trimChatMessage n extended
  where
    -- The conversation with the new message placed last.
    extended = msgs <> NE.singleton msg
-- | Start a conversation consisting of exactly one 'System' message that
-- carries the given prompt and the default message data.
initialChatMessage :: Text -> ChatMessage
initialChatMessage systemPrompt = systemMessage NE.:| []
  where
    systemMessage = Message System systemPrompt defaultMessageData
-- | Running a 'WindowBufferMemory' stores the input text as a user message
-- and yields the updated memory.
instance Runnable WindowBufferMemory where
  type RunnableInput WindowBufferMemory = Text
  type RunnableOutput WindowBufferMemory = WindowBufferMemory

  -- Delegates directly to 'addUserMessage'.
  invoke = addUserMessage