{- |
Module      : Llama.Vocab
Description : High level Vocab interface for llama-cpp
Copyright   : (c) 2025 Tushar Adhatrao
License     : MIT
Maintainer  : Tushar Adhatrao <tusharadhatrao@gmail.com>
-}
module Llama.Vocab (
    getVocabSize
, getVocabTokenCount
, getVocabTokenText
, getVocabTokenScore
, getVocabTokenAttr
, isVocabTokenEog
, isVocabTokenControl
, getVocabBosToken
, getVocabEosToken
, getVocabEotToken
, getVocabSepToken
, getVocabNlToken
, getVocabPadToken
, getVocabAddBOSToken
, getVocabAddEOSToken
, getVocabFIMPrefixToken
, getVocabFIMSuffixToken
, getVocabFIMMiddleToken
, getVocabFIMPADToken
, getVocabFIMSeparatorToken
) where

import Llama.Internal.Foreign
import Llama.Internal.Types
import Foreign
import Foreign.C.String

-- | Get the number of vocab entries.
--
-- Calls @c_llama_n_vocab@ on the wrapped vocab pointer and converts the
-- C integer result to 'Int'. The foreign pointer is kept alive for the
-- duration of the call via 'withForeignPtr'.
--
-- NOTE(review): upstream llama.cpp appears to treat @llama_n_vocab@ as a
-- deprecated alias of @llama_vocab_n_tokens@ (see 'getVocabTokenCount')
-- — confirm before relying on any difference between the two.
getVocabSize :: Vocab -> IO Int
getVocabSize (Vocab vocabFPtr) =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    fromIntegral <$> c_llama_n_vocab (CLlamaVocab vocabPtr)

-- | Get the number of tokens in the vocab.
--
-- Calls @c_llama_vocab_n_tokens@ on the wrapped vocab pointer and converts
-- the C integer result to 'Int'. The foreign pointer is kept alive for the
-- duration of the call via 'withForeignPtr'.
getVocabTokenCount :: Vocab -> IO Int
getVocabTokenCount (Vocab vocabFPtr) =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    fromIntegral <$> c_llama_vocab_n_tokens (CLlamaVocab vocabPtr)

-- | Get the text for a token.
--
-- Calls @c_llama_vocab_get_text@ for the given token and copies the
-- returned C string into a Haskell 'String' with 'peekCString'.
--
-- The copy is made inside 'withForeignPtr', so the vocab is still alive
-- while the C string is being read.
--
-- NOTE(review): assumes the C side returns a non-null, NUL-terminated
-- pointer for the given token — TODO confirm behaviour for out-of-range
-- token ids. 'peekCString' decodes using the current locale encoding.
getVocabTokenText :: Vocab -> LlamaToken -> IO String
getVocabTokenText (Vocab vocabFPtr) token_ =
  withForeignPtr vocabFPtr $ \vocabPtr -> do
    cText <- c_llama_vocab_get_text (CLlamaVocab vocabPtr) token_
    peekCString cText

-- | Get the score for a token.
--
-- Calls @c_llama_vocab_get_score@ for the given token and converts the
-- C float result to 'Float' with 'realToFrac'. The foreign pointer is kept
-- alive for the duration of the call via 'withForeignPtr'.
getVocabTokenScore :: Vocab -> LlamaToken -> IO Float
getVocabTokenScore (Vocab vocabFPtr) token_ =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    realToFrac <$> c_llama_vocab_get_score (CLlamaVocab vocabPtr) token_

-- | Get the attribute for a token.
--
-- Calls @c_llama_vocab_get_attr@ for the given token and returns the raw
-- C integer result converted to 'Int'.
--
-- NOTE(review): the value is presumably a @llama_token_attr@ flag set from
-- llama.h; this wrapper does not decode it — confirm against the C header.
getVocabTokenAttr :: Vocab -> LlamaToken -> IO Int
getVocabTokenAttr (Vocab vocabFPtr) token_ =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    fromIntegral <$> c_llama_vocab_get_attr (CLlamaVocab vocabPtr) token_

-- | Check whether a token is flagged as EOG by the vocab.
--
-- Calls @c_llama_vocab_is_eog@ for the given token and maps the C boolean
-- result to 'Bool' (any non-zero value is 'True').
--
-- NOTE(review): upstream llama.h describes EOG as "end-of-generation"
-- (e.g. EOS, EOT), not "end-of-grammar" — confirm.
isVocabTokenEog :: Vocab -> LlamaToken -> IO Bool
isVocabTokenEog (Vocab vocabFPtr) token_ =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    (/= 0) <$> c_llama_vocab_is_eog (CLlamaVocab vocabPtr) token_

-- | Check if a token is a control token.
--
-- Calls @c_llama_vocab_is_control@ for the given token and maps the C
-- boolean result to 'Bool' (any non-zero value is 'True'). The foreign
-- pointer is kept alive for the duration of the call via 'withForeignPtr'.
isVocabTokenControl :: Vocab -> LlamaToken -> IO Bool
isVocabTokenControl (Vocab vocabFPtr) token_ =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    (/= 0) <$> c_llama_vocab_is_control (CLlamaVocab vocabPtr) token_

-- | Get the beginning-of-sentence (BOS) token.
--
-- Returns the token id reported by @c_llama_vocab_bos@ unchanged.
--
-- NOTE(review): the C side may return a sentinel id when the vocab defines
-- no BOS token; this wrapper does not check for that — confirm.
getVocabBosToken :: Vocab -> IO LlamaToken
getVocabBosToken (Vocab vocabFPtr) =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    c_llama_vocab_bos (CLlamaVocab vocabPtr)

-- | Get the end-of-sentence (EOS) token.
--
-- Returns the token id reported by @c_llama_vocab_eos@ unchanged.
--
-- NOTE(review): the C side may return a sentinel id when the vocab defines
-- no EOS token; this wrapper does not check for that — confirm.
getVocabEosToken :: Vocab -> IO LlamaToken
getVocabEosToken (Vocab vocabFPtr) =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    c_llama_vocab_eos (CLlamaVocab vocabPtr)

-- | Get the end-of-turn (EOT) token.
--
-- Returns the token id reported by @c_llama_vocab_eot@ unchanged.
--
-- NOTE(review): the C side may return a sentinel id when the vocab defines
-- no EOT token; this wrapper does not check for that — confirm.
getVocabEotToken :: Vocab -> IO LlamaToken
getVocabEotToken (Vocab vocabFPtr) =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    c_llama_vocab_eot (CLlamaVocab vocabPtr)

-- | Get the sentence separator (SEP) token.
--
-- Returns the token id reported by @c_llama_vocab_sep@ unchanged.
--
-- NOTE(review): the C side may return a sentinel id when the vocab defines
-- no SEP token; this wrapper does not check for that — confirm.
getVocabSepToken :: Vocab -> IO LlamaToken
getVocabSepToken (Vocab vocabFPtr) =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    c_llama_vocab_sep (CLlamaVocab vocabPtr)

-- | Get the newline (NL) token.
--
-- Returns the token id reported by @c_llama_vocab_nl@ unchanged.
--
-- NOTE(review): the C side may return a sentinel id when the vocab defines
-- no newline token; this wrapper does not check for that — confirm.
getVocabNlToken :: Vocab -> IO LlamaToken
getVocabNlToken (Vocab vocabFPtr) =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    c_llama_vocab_nl (CLlamaVocab vocabPtr)

-- | Get the padding (PAD) token.
--
-- Returns the token id reported by @c_llama_vocab_pad@ unchanged.
--
-- NOTE(review): the C side may return a sentinel id when the vocab defines
-- no padding token; this wrapper does not check for that — confirm.
getVocabPadToken :: Vocab -> IO LlamaToken
getVocabPadToken (Vocab vocabFPtr) =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    c_llama_vocab_pad (CLlamaVocab vocabPtr)

-- | Get whether a BOS token should be added automatically.
--
-- Calls @c_llama_vocab_get_add_bos@ on the wrapped vocab pointer and maps
-- the C boolean result to 'Bool' (any non-zero value is 'True').
getVocabAddBOSToken :: Vocab -> IO Bool
getVocabAddBOSToken (Vocab vocabFPtr) =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    (/= 0) <$> c_llama_vocab_get_add_bos (CLlamaVocab vocabPtr)

-- | Get whether an EOS token should be added automatically.
--
-- Calls @c_llama_vocab_get_add_eos@ on the wrapped vocab pointer and maps
-- the C boolean result to 'Bool' (any non-zero value is 'True').
getVocabAddEOSToken :: Vocab -> IO Bool
getVocabAddEOSToken (Vocab vocabFPtr) =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    (/= 0) <$> c_llama_vocab_get_add_eos (CLlamaVocab vocabPtr)

-- | Get the FIM prefix token.
--
-- Returns the token id reported by @c_llama_vocab_fim_pre@ unchanged.
--
-- NOTE(review): FIM presumably stands for fill-in-the-middle, and the C
-- side may return a sentinel id when the vocab defines no such token —
-- confirm against llama.h.
getVocabFIMPrefixToken :: Vocab -> IO LlamaToken
getVocabFIMPrefixToken (Vocab vocabFPtr) =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    c_llama_vocab_fim_pre (CLlamaVocab vocabPtr)

-- | Get the FIM suffix token.
--
-- Returns the token id reported by @c_llama_vocab_fim_suf@ unchanged.
--
-- NOTE(review): FIM presumably stands for fill-in-the-middle, and the C
-- side may return a sentinel id when the vocab defines no such token —
-- confirm against llama.h.
getVocabFIMSuffixToken :: Vocab -> IO LlamaToken
getVocabFIMSuffixToken (Vocab vocabFPtr) =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    c_llama_vocab_fim_suf (CLlamaVocab vocabPtr)

-- | Get the FIM middle token.
--
-- Returns the token id reported by @c_llama_vocab_fim_mid@ unchanged.
--
-- NOTE(review): FIM presumably stands for fill-in-the-middle, and the C
-- side may return a sentinel id when the vocab defines no such token —
-- confirm against llama.h.
getVocabFIMMiddleToken :: Vocab -> IO LlamaToken
getVocabFIMMiddleToken (Vocab vocabFPtr) =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    c_llama_vocab_fim_mid (CLlamaVocab vocabPtr)

-- | Get the FIM pad token.
--
-- Returns the token id reported by @c_llama_vocab_fim_pad@ unchanged.
--
-- NOTE(review): FIM presumably stands for fill-in-the-middle, and the C
-- side may return a sentinel id when the vocab defines no such token —
-- confirm against llama.h.
getVocabFIMPADToken :: Vocab -> IO LlamaToken
getVocabFIMPADToken (Vocab vocabFPtr) =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    c_llama_vocab_fim_pad (CLlamaVocab vocabPtr)

-- | Get the FIM separator token.
--
-- Returns the token id reported by @c_llama_vocab_fim_sep@ unchanged.
--
-- NOTE(review): FIM presumably stands for fill-in-the-middle, and the C
-- side may return a sentinel id when the vocab defines no such token —
-- confirm against llama.h.
getVocabFIMSeparatorToken :: Vocab -> IO LlamaToken
getVocabFIMSeparatorToken (Vocab vocabFPtr) =
  withForeignPtr vocabFPtr $ \vocabPtr ->
    c_llama_vocab_fim_sep (CLlamaVocab vocabPtr)