module Llama.ChatTemplate (
ChatMessage (..)
, toCLlamaChatMessage
, chatApplyTemplate
, chatGetBuiltinTemplates
) where
import Data.Functor
import Foreign
import Foreign.C.String
import Llama.Internal.Foreign
import Llama.Internal.Types
-- | A single turn of a conversation, as consumed by the chat-template
-- functions in this module.
data ChatMessage = ChatMessage
  { -- | Role label attached to the message (passed through verbatim).
    chatRole :: String
  , -- | Text body of the message.
    chatContent :: String
  }
  deriving (Show, Eq)
-- | Marshal a 'ChatMessage' into its C-side representation.
--
-- The role and the content are each copied into a freshly allocated C string
-- via 'newCString' (role first, then content). Nothing in this function
-- releases that memory: the caller owns both pointers stored in the returned
-- 'LlamaChatMessage' and is responsible for passing them to 'free' once the
-- C side no longer needs them.
toCLlamaChatMessage :: ChatMessage -> IO LlamaChatMessage
toCLlamaChatMessage msg =
  LlamaChatMessage
    <$> newCString (chatRole msg)
    <*> newCString (chatContent msg)
-- | Render a list of chat messages through a chat template using
-- @llama_chat_apply_template@.
--
-- Arguments, in order:
--
-- * the template text; 'Nothing' passes a null pointer to the C function,
--   leaving the choice of template to the library;
-- * the conversation, in order;
-- * whether the C function should append the assistant prefix
--   (forwarded as its @add_ass@ flag);
-- * a hint for the initial output buffer size in bytes. The buffer actually
--   used is never smaller than 64 bytes per message.
--
-- Returns @Right@ with the formatted prompt, or @Left@ with a fixed error
-- message when the C function reports a negative result.
--
-- All C-side allocations (template string, role/content strings, message
-- array, output buffers) are scoped with @with*@/@alloca*@ combinators, so
-- they stay valid for the whole duration of the foreign calls and are
-- released afterwards, including when an exception is thrown.
chatApplyTemplate ::
  Maybe String ->
  [ChatMessage] ->
  Bool ->
  Int ->
  IO (Either String String)
chatApplyTemplate mTemplate messages addAssist bufferSize =
  maybeWith withCString mTemplate $ \templateCString ->
    withCMessages messages $ \cMsgs ->
      -- The array pointer must not escape this continuation: 'withArray'
      -- frees the array as soon as the continuation returns.
      withArray cMsgs $ \cMessages -> do
        let render :: CString -> Int -> IO CInt
            render outPtr outLen =
              c_llama_chat_apply_template
                templateCString
                cMessages
                (fromIntegral nMessages)
                (fromBool addAssist)
                outPtr
                (fromIntegral outLen)
        allocaBytes bufSize $ \bufPtr -> do
          result <- render bufPtr bufSize
          let needed = fromIntegral result :: Int
          if result < 0
            then return $ Left "Failed to apply chat template"
            else
              if needed <= bufSize
                then
                  -- The result is the byte length of the formatted prompt.
                  -- Decode exactly that many bytes instead of relying on a
                  -- NUL terminator, which is absent when the prompt fills
                  -- the buffer completely.
                  peekCStringLen (bufPtr, needed) <&> Right
                else do
                  -- First buffer was too small: retry with room for the
                  -- whole prompt plus one spare byte for a terminator.
                  let newSize = needed + 1
                  allocaBytes newSize $ \newBufPtr -> do
                    result' <- render newBufPtr newSize
                    let written = fromIntegral result' :: Int
                    if result' < 0 || written > newSize
                      then return $ Left "Failed to apply chat template after resize"
                      else peekCStringLen (newBufPtr, written) <&> Right
  where
    nMessages :: Int
    nMessages = length messages

    -- Initial output buffer size: the caller's hint, but at least 64 bytes
    -- per message.
    bufSize :: Int
    bufSize = max bufferSize (nMessages * 64)

    -- Marshal every message into a 'LlamaChatMessage' whose C strings live
    -- only for the duration of the continuation; order is preserved.
    withCMessages :: [ChatMessage] -> ([LlamaChatMessage] -> IO b) -> IO b
    withCMessages [] k = k []
    withCMessages (m : ms) k =
      withCString (chatRole m) $ \rolePtr ->
        withCString (chatContent m) $ \contentPtr ->
          withCMessages ms (k . (LlamaChatMessage rolePtr contentPtr :))
-- | List the names of the chat templates reported by
-- @llama_chat_builtin_templates@.
--
-- The C function is called twice: first with a null output pointer and a
-- length of zero to learn how many entries there are, then with a
-- temporary array of that many slots to receive the name pointers.
--
-- Returns an empty list when the first call reports no templates or when
-- the second call returns a negative value.
--
-- The temporary pointer array is allocated with 'allocaArray', so it is
-- released on every path (including the error branch and exceptions).
-- The name strings themselves are only read, never freed, here.
chatGetBuiltinTemplates :: IO [String]
chatGetBuiltinTemplates = do
  numTemplates <- c_llama_chat_builtin_templates nullPtr 0
  let maxTemplates = fromIntegral numTemplates :: Int
  if numTemplates <= 0
    then return []
    else allocaArray maxTemplates $ \arrPtr -> do
      result <- c_llama_chat_builtin_templates arrPtr (fromIntegral maxTemplates)
      if result < 0
        then return []
        else do
          -- Never read more slots than were allocated, nor more than the
          -- second call reports, so no uninitialised pointer is decoded.
          let filled = min maxTemplates (fromIntegral result)
          cstrs <- peekArray filled arrPtr
          mapM peekCString cstrs