{- |
Module      : Llama.ChatTemplate
Description : Chat template related functions for llama-cpp
Copyright   : (c) 2025 Tushar Adhatrao
License     : MIT
Maintainer  : Tushar Adhatrao <tusharadhatrao@gmail.com>
-}
module Llama.ChatTemplate (
    ChatMessage (..)
    , toCLlamaChatMessage
   , chatApplyTemplate
   , chatGetBuiltinTemplates
  ) where

import Data.Functor
import Foreign
import Foreign.C.String
import Llama.Internal.Foreign
import Llama.Internal.Types

-- | A single message of a chat conversation, as plain Haskell strings.
data ChatMessage = ChatMessage
  { chatRole :: String
  -- ^ Role label attached to the message.
  , chatContent :: String
  -- ^ Text body of the message.
  }
  deriving (Show, Eq)

-- | Marshal a 'ChatMessage' into its C-side representation.
--
-- The role and the content are each copied into a freshly allocated C string
-- with 'newCString'. Nothing in this function releases those two buffers, so
-- the caller owns them and must 'free' both once the resulting
-- 'LlamaChatMessage' is no longer used.
toCLlamaChatMessage :: ChatMessage -> IO LlamaChatMessage
toCLlamaChatMessage msg = do
  rolePtr <- newCString (chatRole msg)
  contentPtr <- newCString (chatContent msg)
  return $ LlamaChatMessage rolePtr contentPtr

-- | Apply a chat template to format a conversation.
--
-- Every C-side value (the template string, each message's role and content
-- strings, the message array and the output buffer) is allocated with a
-- scoped @with*@ / @alloca*@ combinator, so all of them stay valid for the
-- whole duration of the foreign calls and are released when this function
-- returns, including when an exception is thrown.
--
-- The output buffer starts at @max bufferSize (number of messages * 64)@
-- bytes. If the foreign call reports a length that does not fit, the call is
-- repeated once with a buffer of exactly the required size.
chatApplyTemplate ::
  -- | Optional custom template (uses built-in if Nothing)
  Maybe String ->
  -- | List of chat messages
  [ChatMessage] ->
  -- | Add assistant token at end?
  Bool ->
  -- | Buffer size (suggested: 4096)
  Int ->
  -- | Returns formatted string or error message
  IO (Either String String)
chatApplyTemplate mTemplate messages addAssist bufferSize =
  -- 'maybeWith' hands a null pointer to the continuation for 'Nothing'.
  maybeWith withCString mTemplate $ \templatePtr ->
    withMany withChatMessage messages $ \cMessages ->
      withArrayLen cMessages $ \nMessages msgArray -> do
        let bufSize = max bufferSize (nMessages * 64) -- Heuristic fallback
            -- One invocation of the C function writing into the given buffer.
            render buf len =
              c_llama_chat_apply_template
                templatePtr
                msgArray
                (fromIntegral nMessages)
                (fromBool addAssist)
                buf
                (fromIntegral len)
        allocaBytes bufSize $ \bufPtr -> do
          result <- render bufPtr bufSize
          if result < 0
            then return $ Left "Failed to apply chat template"
            else do
              let actualSize = fromIntegral result :: Int
              if actualSize >= bufSize
                then do
                  -- A reported length equal to the buffer size already counts
                  -- as "does not fit", so one extra byte is needed on top of
                  -- the reported length (room for the terminating NUL).
                  let grownSize = actualSize + 1
                  allocaBytes grownSize $ \newBufPtr -> do
                    result' <- render newBufPtr grownSize
                    if result' < 0
                      then return $ Left "Failed to apply chat template after resize"
                      else Right <$> peekCString newBufPtr
                else Right <$> peekCString bufPtr
  where
    -- Marshal one message with temporary C strings that only live for the
    -- duration of the continuation.
    withChatMessage :: ChatMessage -> (LlamaChatMessage -> IO a) -> IO a
    withChatMessage msg action =
      withCString (chatRole msg) $ \rolePtr ->
        withCString (chatContent msg) $ \contentPtr ->
          action (LlamaChatMessage rolePtr contentPtr)

-- | Get list of available built-in chat templates.
--
-- The C function is called twice: first with a null output array to learn how
-- many entries there are, then with a temporary array of that capacity to
-- receive the name pointers. Returns an empty list when either call reports a
-- non-positive / negative count.
chatGetBuiltinTemplates :: IO [String]
chatGetBuiltinTemplates = do
  -- First get number of templates
  numTemplates <- c_llama_chat_builtin_templates nullPtr 0
  if numTemplates <= 0
    then return []
    else do
      let capacity = fromIntegral numTemplates :: Int
      -- 'allocaArray' releases the pointer array on every exit path,
      -- including the error branch and exceptions raised while decoding.
      allocaArray capacity $ \arrPtr -> do
        result <- c_llama_chat_builtin_templates arrPtr (fromIntegral capacity)
        if result < 0
          then return []
          else do
            -- Never read more slots than the second call reported as filled.
            let filled = min capacity (fromIntegral result)
            cstrs <- peekArray filled arrPtr
            mapM peekCString cstrs