{- |
Module      : Llama.Tokenize
Description : High-level tokenization interface for llama.cpp
Copyright   : (c) 2025 Tushar Adhatrao
License     : MIT
Maintainer  : Tushar Adhatrao <tusharadhatrao@gmail.com>
-}
module Llama.Tokenize (
    tokenize
  ,  tokenToPiece
  ,  detokenize
) where

import Llama.Internal.Foreign
import Llama.Internal.Types
import Foreign
import Foreign.C.String

-- | Tokenize a string into tokens.
--
-- The text is marshalled with 'withCStringLen', so the length handed to
-- @llama_tokenize@ is the size of the encoded C string in bytes rather than
-- the number of Haskell 'Char's (the two differ for any non-ASCII text).
--
-- @llama_tokenize@ reports an undersized output buffer by returning the
-- negated number of tokens it needs; in that case the call is repeated once
-- with a buffer of exactly that size. If the call still fails, the result is
-- an empty token list paired with the (negative) status code.
tokenize
  :: Vocab
  -> String                 -- ^ Text to tokenize.
  -> Bool                   -- ^ Passed to @llama_tokenize@ as @add_special@.
  -> Bool                   -- ^ Passed to @llama_tokenize@ as @parse_special@.
  -> IO ([LlamaToken], Int) -- ^ The tokens and their count.
tokenize (Vocab vocab) text addSpecial parseSpecial =
  withForeignPtr vocab $ \vocabPtr ->
    withCStringLen text $ \(cText, byteLen) -> do
      -- Run llama_tokenize once into a fresh buffer of @capacity@ tokens.
      -- A negative count means the buffer was too small (or overflow).
      let attempt :: Int -> IO ([LlamaToken], Int)
          attempt capacity =
            allocaArray capacity $ \tokensPtr -> do
              count <- fromIntegral <$>
                c_llama_tokenize
                  (CLlamaVocab vocabPtr)
                  cText
                  (fromIntegral byteLen)
                  tokensPtr
                  (fromIntegral capacity)
                  (fromBool addSpecial)
                  (fromBool parseSpecial)
              if count < 0
                then pure ([], count)
                else do
                  tokens <- peekArray count tokensPtr
                  pure (tokens, count)
      -- One slot per input byte plus two for BOS/EOS; the retry below
      -- covers any tokenizer that needs more.
      firstTry@(_, firstCount) <- attempt (byteLen + 2)
      -- INT32_MIN signals overflow rather than a required size, so it is
      -- not retried.
      if firstCount < 0 && firstCount /= fromIntegral (minBound :: Int32)
        then attempt (negate firstCount)
        else pure firstTry

-- | Convert a token to a piece of text.
--
-- @llama_token_to_piece@ does not NUL-terminate its output; it returns the
-- number of bytes written. The buffer is therefore read with
-- 'peekCStringLen' using that count, never by scanning for a terminator.
-- A negative return means the 256-byte buffer was too small (its magnitude
-- is the size needed), so the call is repeated once with a buffer of that
-- size. If the conversion still fails, the empty string is returned.
--
-- @lstrip@ is always passed as 0.
--
-- NOTE(review): a single piece may be an incomplete multi-byte character;
-- how 'peekCStringLen' decodes such bytes depends on the foreign encoding.
tokenToPiece
  :: Vocab
  -> LlamaToken -- ^ Token to render.
  -> Bool       -- ^ Passed to @llama_token_to_piece@ as @special@.
  -> IO String
tokenToPiece (Vocab vocab) token_ special =
  withForeignPtr vocab $ \vocabPtr -> do
    -- One conversion into a fresh buffer of @capacity@ bytes: either the
    -- decoded piece, or the buffer size the library says it needs.
    let attempt :: Int -> IO (Either Int String)
        attempt capacity =
          allocaArray capacity $ \bufPtr -> do
            written <- fromIntegral <$>
              c_llama_token_to_piece
                (CLlamaVocab vocabPtr)
                token_
                bufPtr
                (fromIntegral capacity)
                0
                (fromBool special)
            if written < 0
              then pure (Left (negate written))
              else Right <$> peekCStringLen (bufPtr, written)
    firstTry <- attempt 256
    result <- case firstTry of
      Left needed -> attempt needed
      Right piece -> pure (Right piece)
    pure (either (const "") id result)

-- | Detokenize tokens into a string.
--
-- @llama_detokenize@ returns the number of bytes written; the text is read
-- back with 'peekCStringLen' using that count, not by relying on a NUL
-- terminator. A negative return means the output buffer (initially 256
-- bytes) was too small, and its magnitude is the size required, so the call
-- is repeated once with a buffer of that size. If detokenization still
-- fails, the empty string is returned.
detokenize
  :: Vocab
  -> [LlamaToken] -- ^ Tokens to convert back to text.
  -> Bool         -- ^ Passed to @llama_detokenize@ as @remove_special@.
  -> Bool         -- ^ Passed to @llama_detokenize@ as @unparse_special@.
  -> IO String
detokenize (Vocab vocab) tokens removeSpecial unparseSpecial =
  withForeignPtr vocab $ \vocabPtr ->
    -- withArrayLen marshals the list and yields its length in one pass.
    withArrayLen tokens $ \tokenCount tokensPtr -> do
      -- One call into a fresh buffer of @capacity@ bytes: either the
      -- decoded text, or the buffer size the library says it needs.
      let attempt :: Int -> IO (Either Int String)
          attempt capacity =
            allocaArray capacity $ \textPtr -> do
              written <- fromIntegral <$>
                c_llama_detokenize
                  (CLlamaVocab vocabPtr)
                  tokensPtr
                  (fromIntegral tokenCount)
                  textPtr
                  (fromIntegral capacity)
                  (fromBool removeSpecial)
                  (fromBool unparseSpecial)
              if written < 0
                then pure (Left (negate written))
                else Right <$> peekCStringLen (textPtr, written)
      firstTry <- attempt 256
      result <- case firstTry of
        Left needed -> attempt needed
        Right text -> pure (Right text)
      pure (either (const "") id result)