{- |
Module      : Llama.Model
Description : High level Model interface for llama-cpp
Copyright   : (c) 2025 Tushar Adhatrao
License     : MIT
Maintainer  : Tushar Adhatrao <tusharadhatrao@gmail.com>
-}
module Llama.Model (
    defaultModelParams
, loadModelFromFile
, initContextFromModel
, getModelVocab
, getContextModel
, getVocabType
, getModelRoPEFreqScale
, getModelNumKVHeads
, getModelNumHeads
, getModelNumLayers
, getModelEmbeddingDim
, getModelTrainingContextSize
, getModelSize
, getModelChatTemplate
, getModelHasEncoder
, getModelNumParams
, getModelHasDecoder
, getModelDecoderStartToken
, getModelIsRecurrent
, quantizeModel
, quantizeModelDefault
, defaultQuantizeParams
, getModelMetaCount
, getModelMetaValue
, getModelMetaKeyByIndex
, getModelMetaValueByIndex
, getModelDescription
, loadModelFromSplits
, getModelRopeType
) where

import Data.Functor
import Foreign
import Foreign.C.String
import qualified Foreign.Concurrent as FC
import Llama.Internal.Foreign
import Llama.Internal.Types
import Llama.Internal.Types.Params

-- | Fetch the library's default model parameters.
--
-- A temporary 'LlamaModelParams' slot is allocated, handed to
-- @c_llama_model_default_params@ to be filled in, and then read back.
defaultModelParams :: IO LlamaModelParams
defaultModelParams =
  alloca $ \slot ->
    c_llama_model_default_params (CLlamaModelParams slot) *> peek slot

-- | Load a model from the file at the given path with the given parameters.
--
-- Returns @Left "Failed to load model"@ when the C loader hands back a null
-- model. Otherwise the model is wrapped in a 'ForeignPtr' whose finalizer is
-- @p_llama_model_free@.
loadModelFromFile :: FilePath -> LlamaModelParams -> IO (Either String Model)
loadModelFromFile path params =
  withCString path $ \cPath ->
    withStorable params $ \paramsPtr -> do
      loaded@(CLlamaModel rawModel) <-
        c_llama_model_load_from_file_wrap cPath (CLlamaModelParams paramsPtr)
      if loaded /= CLlamaModel nullPtr
        then Right . Model <$> newForeignPtr p_llama_model_free rawModel
        else pure (Left "Failed to load model")

-- | Create an inference context for a model with the given context parameters.
--
-- Returns @Left "Failed to initialize context"@ when the C call yields a null
-- context. Otherwise the context is wrapped in a 'ForeignPtr' whose finalizer
-- is @p_llama_free@.
--
-- NOTE(review): the returned 'Context' holds no reference to the 'Model', so
-- the model is presumably expected to be kept alive by the caller for as long
-- as the context is used -- confirm.
initContextFromModel :: Model -> LlamaContextParams -> IO (Either String Context)
initContextFromModel (Model modelFPtr) params =
  withForeignPtr modelFPtr $ \rawModel ->
    withStorable params $ \paramsPtr -> do
      created@(CLlamaContext rawCtx) <-
        c_llama_init_from_model_wrap
          (CLlamaModel rawModel)
          (CLlamaContextParams paramsPtr)
      if created /= CLlamaContext nullPtr
        then Right . Context <$> newForeignPtr p_llama_free rawCtx
        else pure (Left "Failed to initialize context")

-- | Obtain the vocabulary handle of a model.
--
-- Returns @Left "Failed to get vocabulary"@ when the C call yields a null
-- vocabulary. The resulting 'Vocab' carries no finalizer: the vocabulary is
-- assumed to be owned by the model and therefore must not be freed separately.
getModelVocab :: Model -> IO (Either String Vocab)
getModelVocab (Model modelFPtr) =
  withForeignPtr modelFPtr $ \rawModel -> do
    found@(CLlamaVocab rawVocab) <- c_llama_model_get_vocab (CLlamaModel rawModel)
    if found /= CLlamaVocab nullPtr
      then -- No finalizer on purpose: the vocab is assumed to belong to the model.
        Right . Vocab <$> newForeignPtr_ rawVocab
      else pure (Left "Failed to get vocabulary")

-- | Run an action with a pointer to a temporary copy of a 'Storable' value.
--
-- The value is written into freshly allocated scratch memory that is only
-- valid for the duration of the action.
withStorable :: Storable a => a -> (Ptr a -> IO b) -> IO b
withStorable value action =
  alloca (\slot -> poke slot value *> action slot)

-- | Get the model associated with a context.
--
-- The pointer returned by @c_llama_get_model@ is borrowed from the context;
-- it is not a fresh allocation. It is therefore wrapped /without/ a
-- finalizer: attaching @p_llama_model_free@ here would free the model a
-- second time once the 'Model' it was originally loaded into is finalized.
--
-- NOTE(review): because the result owns nothing, the caller is presumably
-- responsible for keeping the originally loaded 'Model' alive while this
-- handle is in use -- confirm against callers.
getContextModel :: Context -> IO Model
getContextModel (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \ctxPtr -> do
    CLlamaModel modelPtr <- c_llama_get_model (CLlamaContext ctxPtr)
    -- Borrowed pointer: no finalizer, to avoid a double free.
    Model <$> newForeignPtr_ modelPtr

-- | Get the vocabulary type.
--
-- The raw integer reported by @c_llama_vocab_type@ is decoded with
-- @fromLlamaRopeVocabType@, which produces the 'Maybe' result.
getVocabType :: Vocab -> IO (Maybe LlamaVocabType)
getVocabType (Vocab vocabFPtr) =
  withForeignPtr vocabFPtr $ \rawVocab -> do
    rawType <- c_llama_vocab_type (CLlamaVocab rawVocab)
    pure (fromLlamaRopeVocabType rawType)

-- | Get the RoPE frequency scaling factor used during training.
--
-- The C float is converted to a Haskell 'Float' with 'realToFrac'.
getModelRoPEFreqScale :: Model -> IO Float
getModelRoPEFreqScale (Model modelFPtr) =
  withForeignPtr modelFPtr $ \rawModel -> do
    scale <- c_llama_model_rope_freq_scale_train (CLlamaModel rawModel)
    pure (realToFrac scale)

-- | Get the number of key/value heads in the model, as a Haskell 'Int'.
getModelNumKVHeads :: Model -> IO Int
getModelNumKVHeads (Model modelFPtr) =
  withForeignPtr modelFPtr $ \rawModel -> do
    count <- c_llama_model_n_head_kv (CLlamaModel rawModel)
    pure (fromCInt count)

-- | Get the number of attention heads in the model, as a Haskell 'Int'.
getModelNumHeads :: Model -> IO Int
getModelNumHeads (Model modelFPtr) =
  withForeignPtr modelFPtr $ \rawModel -> do
    count <- c_llama_model_n_head (CLlamaModel rawModel)
    pure (fromCInt count)

-- | Get the number of transformer layers in the model, as a Haskell 'Int'.
getModelNumLayers :: Model -> IO Int
getModelNumLayers (Model modelFPtr) =
  withForeignPtr modelFPtr $ \rawModel -> do
    count <- c_llama_model_n_layer (CLlamaModel rawModel)
    pure (fromCInt count)

-- | Get the embedding dimension of the model, as a Haskell 'Int'.
getModelEmbeddingDim :: Model -> IO Int
getModelEmbeddingDim (Model modelFPtr) =
  withForeignPtr modelFPtr $ \rawModel -> do
    dim <- c_llama_model_n_embd (CLlamaModel rawModel)
    pure (fromCInt dim)

-- | Get the context size the model was trained with, as a Haskell 'Int'.
getModelTrainingContextSize :: Model -> IO Int
getModelTrainingContextSize (Model modelFPtr) =
  withForeignPtr modelFPtr $ \rawModel -> do
    size <- c_llama_model_n_ctx_train (CLlamaModel rawModel)
    pure (fromCInt size)

-- | Widen a 32-bit integer coming from the C side to a Haskell 'Int'.
fromCInt :: Int32 -> Int
fromCInt n = fromIntegral n

-- | Get the size of the model in bytes.
--
-- The unsigned value reported by C is converted to 'Int64' with
-- 'fromIntegral'.
getModelSize :: Model -> IO Int64
getModelSize (Model modelFPtr) =
  withForeignPtr modelFPtr $ \rawModel -> do
    bytes <- c_llama_model_size (CLlamaModel rawModel)
    pure (fromIntegral bytes)

-- | Check whether the model has an encoder.
--
-- Any non-zero C boolean is reported as 'True'.
getModelHasEncoder :: Model -> IO Bool
getModelHasEncoder (Model modelFPtr) =
  withForeignPtr modelFPtr $ \rawModel -> do
    flag <- c_llama_model_has_encoder (CLlamaModel rawModel)
    pure (flag /= 0)

-- | Get the chat template from a model, optionally by name.
--
-- With 'Nothing' a null name is passed to the C call; with @Just name@ the
-- name is marshalled as a temporary C string that is released as soon as the
-- lookup finishes (previously it was allocated with 'newCString' and leaked).
--
-- Returns @Left "Failed to get chat template"@ when the C call yields a null
-- pointer, otherwise the template text.
getModelChatTemplate :: Model -> Maybe String -> IO (Either String String)
getModelChatTemplate (Model modelFPtr) mName =
  withForeignPtr modelFPtr $ \modelPtr ->
    withOptionalCString mName $ \cName -> do
      template <- c_llama_model_chat_template (CLlamaModel modelPtr) cName
      if template == nullPtr
        then return $ Left "Failed to get chat template"
        else Right <$> peekCString template
  where
    -- Bracketed marshalling of an optional string: 'nullPtr' for 'Nothing',
    -- a scoped C string for 'Just'.
    withOptionalCString :: Maybe String -> (CString -> IO a) -> IO a
    withOptionalCString = maybe ($ nullPtr) withCString

-- | Get the number of parameters in the model.
--
-- A count of zero is reported as
-- @Left "Failed to get number of parameters"@; any other count is returned
-- as an 'Int64'.
getModelNumParams :: Model -> IO (Either String Int64)
getModelNumParams (Model modelFPtr) =
  withForeignPtr modelFPtr $ \rawModel -> do
    count <- c_llama_model_n_params (CLlamaModel rawModel)
    pure $
      if count /= 0
        then Right (fromIntegral count)
        else Left "Failed to get number of parameters"

-- | Check whether the model has a decoder.
--
-- Any non-zero C boolean is reported as 'True'.
getModelHasDecoder :: Model -> IO Bool
getModelHasDecoder (Model modelFPtr) =
  withForeignPtr modelFPtr $ \rawModel -> do
    flag <- c_llama_model_has_decoder (CLlamaModel rawModel)
    pure (flag /= 0)

-- | Get the decoder start token of the model.
--
-- A token value of @-1@ is reported as
-- @Left "Failed to get decoder start token"@; any other value is returned
-- unchanged.
getModelDecoderStartToken :: Model -> IO (Either String LlamaToken)
getModelDecoderStartToken (Model modelFPtr) =
  withForeignPtr modelFPtr $ \rawModel -> do
    startToken <- c_llama_model_decoder_start_token (CLlamaModel rawModel)
    pure $
      if startToken /= -1
        then Right startToken
        else Left "Failed to get decoder start token"

-- | Check whether the model is recurrent.
--
-- Any non-zero C boolean is reported as 'True'.
getModelIsRecurrent :: Model -> IO Bool
getModelIsRecurrent (Model modelFPtr) =
  withForeignPtr modelFPtr $ \rawModel -> do
    flag <- c_llama_model_is_recurrent (CLlamaModel rawModel)
    pure (flag /= 0)

-- | Quantize a model from one file into another using the given parameters.
--
-- The input path, output path and parameters are marshalled to temporary C
-- values for the duration of the call.
--
-- @llama_model_quantize@ reports success with a status of @0@ (see
-- @llama.h@), so a zero status yields @Right 0@ and any non-zero status
-- yields @Left "Failed to quantize model"@. The previous implementation had
-- this check inverted and reported every successful run as a failure.
quantizeModel ::
  FilePath ->
  FilePath ->
  LlamaModelQuantizeParams ->
  IO (Either String Word32)
quantizeModel inpPath outPath params =
  withCString inpPath $ \cInpPath ->
    withCString outPath $ \cOutPath ->
      withStorable params $ \paramsPtr -> do
        status <-
          c_llama_model_quantize
            cInpPath
            cOutPath
            (CLlamaModelQuantizeParams paramsPtr)
        -- 0 is the C API's success code; everything else is an error.
        return $
          if status == 0
            then Right status
            else Left "Failed to quantize model"

-- | Quantize a model from one file into another using the parameters
-- produced by 'defaultQuantizeParams'.
quantizeModelDefault :: FilePath -> FilePath -> IO (Either String Word32)
quantizeModelDefault inpPath outPath =
  defaultQuantizeParams >>= quantizeModel inpPath outPath

-- | Get the default quantization parameters.
--
-- Reads the parameter struct through the pointer returned by
-- @c_llama_model_quantize_default_params@.
--
-- NOTE(review): the returned pointer is not released here; confirm who owns
-- that memory on the C side.
defaultQuantizeParams :: IO LlamaModelQuantizeParams
defaultQuantizeParams =
  c_llama_model_quantize_default_params
    >>= \(CLlamaModelQuantizeParams rawParams) -> peek rawParams

-- | Get a metadata value as a string from a model, looked up by key.
--
-- A negative result from the C call is reported as
-- @Left "Failed to get metadata value"@.
--
-- The lookup first tries a 256-byte buffer. Per @llama.h@ the C call returns
-- the length of the full value (snprintf-style), so when that length does
-- not fit, the lookup is repeated once with a buffer of exactly the required
-- size instead of silently returning a truncated value.
getModelMetaValue :: Model -> String -> IO (Either String String)
getModelMetaValue (Model modelFPtr) key =
  withForeignPtr modelFPtr $ \modelPtr ->
    withCString key $ \cKey ->
      readValue (CLlamaModel modelPtr) cKey initialSize
  where
    initialSize :: Int
    initialSize = 256

    -- Attempt the lookup with a buffer of @size@ bytes, growing it when the
    -- reported length (plus the terminating NUL) does not fit.
    readValue model cKey size =
      allocaArray size $ \bufPtr -> do
        result <- c_llama_model_meta_val_str model cKey bufPtr (fromIntegral size)
        let needed = fromIntegral result :: Int
        if result < 0
          then return $ Left "Failed to get metadata value"
          else
            if needed >= size
              then readValue model cKey (needed + 1)
              else Right <$> peekCString bufPtr

-- | Get the number of metadata entries in a model, as a Haskell 'Int'.
getModelMetaCount :: Model -> IO Int
getModelMetaCount (Model modelFPtr) =
  withForeignPtr modelFPtr $ \rawModel -> do
    count <- c_llama_model_meta_count (CLlamaModel rawModel)
    pure (fromIntegral count)

-- | Get a metadata key by index from a model.
--
-- A negative result from the C call is reported as
-- @Left "Failed to get metadata key"@.
--
-- The lookup first tries a 256-byte buffer. Per @llama.h@ the C call returns
-- the length of the full key (snprintf-style), so when that length does not
-- fit, the lookup is repeated once with a buffer of exactly the required
-- size instead of silently returning a truncated key.
getModelMetaKeyByIndex :: Model -> Int -> IO (Either String String)
getModelMetaKeyByIndex (Model modelFPtr) index =
  withForeignPtr modelFPtr $ \modelPtr ->
    readKey (CLlamaModel modelPtr) initialSize
  where
    initialSize :: Int
    initialSize = 256

    -- Attempt the lookup with a buffer of @size@ bytes, growing it when the
    -- reported length (plus the terminating NUL) does not fit.
    readKey model size =
      allocaArray size $ \bufPtr -> do
        result <-
          c_llama_model_meta_key_by_index
            model
            (fromIntegral index)
            bufPtr
            (fromIntegral size)
        let needed = fromIntegral result :: Int
        if result < 0
          then return $ Left "Failed to get metadata key"
          else
            if needed >= size
              then readKey model (needed + 1)
              else Right <$> peekCString bufPtr

-- | Get a metadata value by index from a model.
--
-- A negative result from the C call is reported as
-- @Left "Failed to get metadata value"@.
--
-- The lookup first tries a 256-byte buffer. Per @llama.h@ the C call returns
-- the length of the full value (snprintf-style), so when that length does
-- not fit, the lookup is repeated once with a buffer of exactly the required
-- size instead of silently returning a truncated value.
getModelMetaValueByIndex :: Model -> Int -> IO (Either String String)
getModelMetaValueByIndex (Model modelFPtr) index =
  withForeignPtr modelFPtr $ \modelPtr ->
    readValue (CLlamaModel modelPtr) initialSize
  where
    initialSize :: Int
    initialSize = 256

    -- Attempt the lookup with a buffer of @size@ bytes, growing it when the
    -- reported length (plus the terminating NUL) does not fit.
    readValue model size =
      allocaArray size $ \bufPtr -> do
        result <-
          c_llama_model_meta_val_str_by_index
            model
            (fromIntegral index)
            bufPtr
            (fromIntegral size)
        let needed = fromIntegral result :: Int
        if result < 0
          then return $ Left "Failed to get metadata value"
          else
            if needed >= size
              then readValue model (needed + 1)
              else Right <$> peekCString bufPtr

-- | Get a textual description of the model.
--
-- The description is written by @c_llama_model_desc@ into a 256-byte scratch
-- buffer and then decoded; the C call's return value is ignored.
getModelDescription :: Model -> IO String
getModelDescription (Model modelFPtr) =
  withForeignPtr modelFPtr $ \rawModel ->
    allocaArray descBufSize $ \buf ->
      c_llama_model_desc (CLlamaModel rawModel) buf (fromIntegral descBufSize)
        *> peekCString buf
  where
    descBufSize :: Int
    descBufSize = 256

-- | Load a model from multiple split files using the given parameters.
--
-- The paths are marshalled into a C array with 'newArrayOfPtrs' and passed to
-- the C loader together with their count. Returns
-- @Left "Failed to load model"@ when the loader hands back a null model.
-- Otherwise the model is wrapped in a 'ForeignPtr' whose finalizer is
-- @p_llama_model_free@.
loadModelFromSplits :: [FilePath] -> LlamaModelParams -> IO (Either String Model)
loadModelFromSplits paths params =
  withStorable params $ \paramsPtr -> do
    pathArray <- newArrayOfPtrs paths
    withForeignPtr pathArray $ \cPaths -> do
      loaded@(CLlamaModel rawModel) <-
        c_llama_model_load_from_splits
          cPaths
          (fromIntegral (length paths))
          (CLlamaModelParams paramsPtr)
      if loaded /= CLlamaModel nullPtr
        then Right . Model <$> newForeignPtr p_llama_model_free rawModel
        else pure (Left "Failed to load model")

-- | Get the RoPE type of a model.
--
-- The C side writes the raw type into a temporary out-parameter, which is
-- then read back and decoded with @fromLlamaRopeTypeScaling@.
getModelRopeType :: Model -> IO (Maybe LlamaRopeTypeScaling)
getModelRopeType (Model modelFPtr) =
  withForeignPtr modelFPtr $ \rawModel ->
    alloca $ \out ->
      fromLlamaRopeTypeScaling
        <$> (c_llama_model_rope_type_into (CLlamaModel rawModel) out *> peek out)

-- | Marshal a list of paths into a C array of NUL-terminated strings.
--
-- Every string is copied with 'newCString' so that it stays valid after this
-- function returns. The previous implementation stored pointers obtained
-- inside 'withCString', which are released as soon as that callback
-- finishes, leaving the array full of dangling pointers.
--
-- The strings and the array holding them are released by the finalizer
-- attached to the returned 'ForeignPtr', so the array must only be accessed
-- through 'withForeignPtr'.
newArrayOfPtrs :: [FilePath] -> IO (ForeignPtr CString)
newArrayOfPtrs xs = do
  cstrs <- mapM newCString xs
  arr <- newArray cstrs
  -- Free each string first, then the pointer array itself.
  FC.newForeignPtr arr (mapM_ free cstrs >> free arr)