{- |
Module      : Llama.Adapter
Description : Adapter interface for llama-cpp
Copyright   : (c) 2025 Tushar Adhatrao
License     : MIT
Maintainer  : Tushar Adhatrao <tusharadhatrao@gmail.com>
-}
module Llama.Adapter (
    initAdapterLora
   , setAdapterLora
    , rmAdapterLora
    , clearAdapterLora
    , applyAdapterCVec
) where

import Foreign
import Foreign.C.String
import Llama.Internal.Foreign
import Llama.Internal.Types
import Foreign.C (CFloat)

-- | Load a LoRA adapter from a file on disk.
--
-- Wraps the C function @llama_adapter_lora_init@. The model 'ForeignPtr' and
-- the C string for @path@ are kept alive only for the duration of the C call.
--
-- Returns 'Left' when the C side hands back a null adapter pointer. Otherwise
-- it returns 'Right' with the raw pointer wrapped in a 'ForeignPtr'. That
-- pointer's finalizer is 'p_llama_adapter_lora_free', so the adapter is
-- released once the returned 'AdapterLora' becomes unreachable.
--
-- NOTE(review): nothing here keeps the 'Model' alive as long as the adapter.
-- If llama.cpp ties adapter lifetime to its model, callers must keep the
-- model reachable while the adapter is in use. Confirm against llama.h.
initAdapterLora :: Model -> FilePath -> IO (Either String AdapterLora)
initAdapterLora (Model modelFPtr) path =
  withForeignPtr modelFPtr $ \modelPtr ->
    withCString path $ \cPath -> do
      CLlamaAdapterLora adapterPtr <-
        c_llama_adapter_lora_init (CLlamaModel modelPtr) cPath
      if adapterPtr == nullPtr
        then pure $ Left "Failed to initialize LoRA adapter"
        else
          -- The finalizer is the function pointer p_llama_adapter_lora_free
          -- (presumably a FunPtr to llama_adapter_lora_free).
          Right . AdapterLora
            <$> newForeignPtr p_llama_adapter_lora_free adapterPtr

-- | Attach a loaded LoRA adapter to a context with the given scale.
--
-- Wraps the C function @llama_set_adapter_lora@. The scale is converted to
-- 'CFloat' with 'realToFrac'. Both the context and the adapter 'ForeignPtr's
-- are kept alive for the duration of the call.
--
-- Any non-zero status from the C call is reported as 'Left'. llama.cpp uses
-- 0 for success; previously only @-1@ was treated as failure.
--
-- NOTE(review): the C context presumably keeps a reference to the adapter
-- after this call returns. The Haskell side does not tie the adapter's
-- lifetime to the context, so callers should keep the 'AdapterLora'
-- reachable while it is applied. Confirm.
setAdapterLora :: Context -> AdapterLora -> Float -> IO (Either String ())
setAdapterLora (Context ctxFPtr) (AdapterLora adapterFPtr) scale = do
  status <-
    withForeignPtr ctxFPtr $ \ctxPtr ->
      withForeignPtr adapterFPtr $ \adapterPtr ->
        c_llama_set_adapter_lora
          (CLlamaContext ctxPtr)
          (CLlamaAdapterLora adapterPtr)
          (realToFrac scale)
  pure $
    if status == 0
      then Right ()
      else Left "Failed to set LoRA adapter"

-- | Detach a previously attached LoRA adapter from a context.
--
-- Wraps the C function @llama_rm_adapter_lora@. Both 'ForeignPtr's are kept
-- alive for the duration of the call.
--
-- Any non-zero status from the C call is reported as 'Left'. llama.h
-- documents @-1@ when the adapter is not present in the context; other
-- non-zero codes are treated as failure as well.
rmAdapterLora :: Context -> AdapterLora -> IO (Either String ())
rmAdapterLora (Context ctxFPtr) (AdapterLora adapterFPtr) = do
  status <-
    withForeignPtr ctxFPtr $ \ctxPtr ->
      withForeignPtr adapterFPtr $ \adapterPtr ->
        c_llama_rm_adapter_lora
          (CLlamaContext ctxPtr)
          (CLlamaAdapterLora adapterPtr)
  pure $
    if status == 0
      then Right ()
      else Left "Failed to remove LoRA adapter"

-- | Detach every LoRA adapter currently attached to the context.
--
-- Wraps the C function @llama_clear_adapter_lora@. The context 'ForeignPtr'
-- is kept alive for the duration of the call.
clearAdapterLora :: Context -> IO ()
clearAdapterLora (Context ctxFPtr) =
  withForeignPtr ctxFPtr (c_llama_clear_adapter_lora . CLlamaContext)

-- | Apply a context vector (cvec) to a context.
--
-- Wraps the C function @llama_apply_adapter_cvec@. llama.cpp calls these
-- "control vectors".
--
-- * @mValues@: 'Nothing' passes a null data pointer with length 0. llama.h
--   documents this as clearing the loaded vector. @'Just' vs@ passes the
--   values as a 'CFloat' array together with its element count.
-- * @n_embd@, @il_start@, @il_end@: forwarded unchanged as C ints. Per
--   llama.h, these are the per-layer size and the inclusive layer range
--   (verify against the llama.cpp version in use).
--
-- The float buffer is allocated only for the duration of the C call and
-- freed afterwards. The previous version @malloc@ed it and never freed it,
-- leaking memory on every call. llama.cpp copies the data into its own
-- tensors, so a temporary buffer is sufficient.
--
-- Any non-zero status from the C call is reported as 'Left'. llama.cpp uses
-- 0 for success.
applyAdapterCVec :: Context -> Maybe [Float] -> Int -> Int -> Int -> IO (Either String ())
applyAdapterCVec (Context ctxFPtr) mValues n_embd il_start il_end = do
  status <-
    withForeignPtr ctxFPtr $ \ctxPtr ->
      withValues $ \valuesPtr count ->
        c_llama_apply_adapter_cvec
          (CLlamaContext ctxPtr)
          valuesPtr
          (fromIntegral count)
          (fromIntegral n_embd)
          (fromIntegral il_start)
          (fromIntegral il_end)
  pure $
    if status == 0
      then Right ()
      else Left "Failed to apply context vector"
  where
    -- Run the continuation with a (pointer, element count) pair that is only
    -- valid inside it; no manual free is needed.
    withValues :: (Ptr CFloat -> Int -> IO r) -> IO r
    withValues k = case mValues of
      Nothing -> k nullPtr 0
      Just vs -> withArrayLen (map realToFrac vs) $ \count ptr -> k ptr count

-- | Copy a list of 'Float's into a newly @malloc@ed C array of 'CFloat'.
--
-- Each element is converted with 'realToFrac'. The caller owns the returned
-- pointer and must release it with 'free'. For call-scoped buffers, prefer
-- 'withArrayLen', which frees automatically.
floatArrayToPtr :: [Float] -> IO (Ptr CFloat)
floatArrayToPtr = newArray . map realToFrac