module Llama.Adapter (
initAdapterLora
, setAdapterLora
, rmAdapterLora
, clearAdapterLora
, applyAdapterCVec
) where
import Foreign
import Foreign.C.String
import Llama.Internal.Foreign
import Llama.Internal.Types
import Foreign.C (CFloat)
-- | Load a LoRA adapter from @path@ for the given model.
--
-- Calls @llama_adapter_lora_init@ with the model pointer and the path,
-- which is marshalled with 'withCString' (current locale encoding).
-- A null result is reported as 'Left'.
-- Otherwise the raw pointer is wrapped in a 'ForeignPtr' whose finalizer is
-- @p_llama_adapter_lora_free@.
--
-- NOTE(review): the adapter's finalizer does not keep the model alive, and
-- GHC does not order finalizers. If the bound llama.cpp version also frees
-- adapters when their model is freed, an adapter that outlives its model
-- could be freed twice. Confirm against the library version in use.
initAdapterLora :: Model -> FilePath -> IO (Either String AdapterLora)
initAdapterLora (Model modelFPtr) path =
  withForeignPtr modelFPtr $ \modelPtr ->
    withCString path $ \cPath -> do
      CLlamaAdapterLora adapterPtr <-
        c_llama_adapter_lora_init (CLlamaModel modelPtr) cPath
      if adapterPtr == nullPtr
        then return $ Left "Failed to initialize LoRA adapter"
        else
          Right . AdapterLora
            <$> newForeignPtr p_llama_adapter_lora_free adapterPtr
-- | Attach a loaded LoRA adapter to a context with the given scale.
--
-- Forwards to @llama_set_adapter_lora@, converting the scale to 'CFloat'.
-- Returns 'Left' when the C call returns @-1@, and 'Right' otherwise.
--
-- NOTE(review): both 'ForeignPtr's are only kept alive for the duration of
-- this call. If the context retains the adapter pointer afterwards (as
-- llama.cpp appears to; confirm), the caller must keep the 'AdapterLora'
-- reachable while it is attached. Otherwise its finalizer may free the
-- adapter while the context still refers to it.
setAdapterLora :: Context -> AdapterLora -> Float -> IO (Either String ())
setAdapterLora (Context ctxFPtr) (AdapterLora adapterFPtr) scale = do
  result <-
    withForeignPtr ctxFPtr $ \ctxPtr ->
      withForeignPtr adapterFPtr $ \adapterPtr ->
        c_llama_set_adapter_lora
          (CLlamaContext ctxPtr)
          (CLlamaAdapterLora adapterPtr)
          (realToFrac scale)
  return $
    if result == -1
      then Left "Failed to set LoRA adapter"
      else Right ()
-- | Detach a LoRA adapter from a context.
--
-- Forwards to @llama_rm_adapter_lora@. Returns 'Left' when the C call
-- returns @-1@, which presumably means the adapter was not attached to this
-- context (confirm against the bound llama.cpp version). Returns 'Right'
-- otherwise. The adapter itself is not freed here.
rmAdapterLora :: Context -> AdapterLora -> IO (Either String ())
rmAdapterLora (Context ctxFPtr) (AdapterLora adapterFPtr) = do
  result <-
    withForeignPtr ctxFPtr $ \ctxPtr ->
      withForeignPtr adapterFPtr $ \adapterPtr ->
        c_llama_rm_adapter_lora
          (CLlamaContext ctxPtr)
          (CLlamaAdapterLora adapterPtr)
  return $
    if result == -1
      then Left "Failed to remove LoRA adapter"
      else Right ()
-- | Call @llama_clear_adapter_lora@ on the context.
--
-- Presumably detaches every LoRA adapter from the context (confirm against
-- the bound llama.cpp version). No adapter is freed here, and no status is
-- returned.
clearAdapterLora :: Context -> IO ()
clearAdapterLora (Context ctxFPtr) =
  withForeignPtr ctxFPtr (c_llama_clear_adapter_lora . CLlamaContext)
-- | Apply a control vector to a context via @llama_apply_adapter_cvec@.
--
-- * 'Nothing' passes a NULL buffer with length 0. In llama.cpp this clears
--   the currently loaded control vector (confirm for the bound version).
-- * @'Just' vs@ passes the values as a temporary C @float@ array together
--   with its element count.
--
-- @n_embd@, @il_start@ and @il_end@ are forwarded unchanged as C integers.
-- They are presumably the embedding width and the first and last layer to
-- apply to.
--
-- Returns 'Left' when the C call returns @-1@, and 'Right' otherwise.
--
-- The value buffer is allocated with 'withArrayLen' and released as soon as
-- the call returns. The previous implementation malloc-ed it and never freed
-- it. NOTE(review): this assumes the C side copies the data before
-- returning, as llama.cpp's control-vector apply does. Confirm if the
-- library changes.
applyAdapterCVec ::
  Context -> Maybe [Float] -> Int -> Int -> Int -> IO (Either String ())
applyAdapterCVec (Context ctxFPtr) mValues n_embd il_start il_end = do
  result <- withForeignPtr ctxFPtr $ \ctxPtr -> do
    let apply ptr len =
          c_llama_apply_adapter_cvec
            (CLlamaContext ctxPtr)
            ptr
            (fromIntegral (len :: Int))
            (fromIntegral n_embd)
            (fromIntegral il_start)
            (fromIntegral il_end)
    case mValues of
      Nothing -> apply nullPtr 0
      Just vs ->
        withArrayLen (map realToFrac vs) $ \len ptr -> apply ptr len
  return $
    if result == -1
      then Left "Failed to apply context vector"
      else Right ()
-- | Copy a list of 'Float's into a freshly malloc-ed C array of 'CFloat'.
--
-- The caller owns the returned buffer and must release it with 'free'.
-- Prefer 'withArrayLen' when the buffer is only needed for a single call.
floatArrayToPtr :: [Float] -> IO (Ptr CFloat)
floatArrayToPtr xs = newArray (map realToFrac xs)