module Llama.Decode
( batchInit
, batchGetOne
, freeBatch
, encodeBatch
, decodeBatch
, setThreadCount
, getThreadCount
, getBatchThreadCount
, setEmbeddingsEnabled
, areEmbeddingsEnabled
, setCausalAttention
, setThreadCounts
, setWarmupMode
, synchronizeContext
) where
import Foreign
import Llama.Internal.Foreign
import Llama.Internal.Types
-- | Allocate a batch structure and initialise it through
-- @c_llama_batch_init_into@.
--
-- The three 'Int' arguments are converted with 'fromIntegral' and passed to
-- the C function in the order given, followed by the destination pointer.
--
-- The structure is allocated with 'malloc' rather than 'alloca'. 'alloca'
-- releases its buffer as soon as the continuation returns, so wrapping that
-- pointer in a 'Batch' and returning it would hand the caller a dangling
-- pointer.
--
-- NOTE(review): the structure allocated here is never released by this
-- function. Confirm whether 'freeBatch' releases the structure itself or
-- only the buffers it refers to, and add an explicit 'free' if it does not.
batchInit :: Int -> Int -> Int -> IO Batch
batchInit nTokens embd_ nSeqMax = do
  ptr <- malloc
  c_llama_batch_init_into
    (fromIntegral nTokens)
    (fromIntegral embd_)
    (fromIntegral nSeqMax)
    ptr
  return $ Batch ptr
-- | Build a batch from a list of tokens through
-- @c_llama_batch_get_one_into@.
--
-- The token list is marshalled into a C array, and the array pointer, the
-- list length and the destination structure pointer are passed to the C
-- function.
--
-- Both buffers are heap-allocated ('newArray' and 'malloc') instead of being
-- scoped with 'withArray' and 'alloca'. The scoped versions release their
-- memory when the continuation returns, which is before the caller ever sees
-- the resulting 'Batch', so the returned pointer would dangle.
--
-- NOTE(review): presumably the batch keeps referring to the token array
-- rather than copying it (as llama.cpp's @llama_batch_get_one@ does), which
-- is why the array must outlive this call; confirm against the C wrapper.
-- Neither allocation is released automatically.
batchGetOne :: [LlamaToken] -> IO Batch
batchGetOne tokens = do
  tokensPtr <- newArray tokens
  ptr <- malloc
  c_llama_batch_get_one_into tokensPtr (fromIntegral (length tokens)) ptr
  return $ Batch ptr
-- | Pass a batch pointer to @c_llama_batch_free_wrap@.
freeBatch :: Ptr LlamaBatch -> IO ()
freeBatch batchPtr = c_llama_batch_free_wrap batchPtr
-- | Run @c_llama_encode@ on a batch within the given context.
--
-- Returns @Left "Encoding failed"@ when the C call yields a negative status
-- and @Right ()@ for any other status.
encodeBatch :: Context -> Batch -> IO (Either String ())
encodeBatch (Context ctxFPtr) (Batch batchPtr) =
  toOutcome <$> withForeignPtr ctxFPtr runEncode
  where
    -- The raw context pointer is only valid inside 'withForeignPtr'.
    runEncode rawCtx = c_llama_encode (CLlamaContext rawCtx) batchPtr

    toOutcome status
      | status < 0 = Left "Encoding failed"
      | otherwise = Right ()
-- | Run @c_llama_decode_wrap@ on a batch within the given context.
--
-- Result mapping:
--
-- * status @0@: @Right ()@
-- * negative status: @Left "Decoding failed"@
-- * positive status: @Left@ with a message that includes the status
--
-- Positive statuses used to be reported as success. llama.h documents them
-- as cases in which the batch was not processed (for example no KV slot
-- could be found, or the computation was aborted), so reporting @Right ()@
-- let callers continue as though decoding had happened.
decodeBatch :: Context -> Batch -> IO (Either String ())
decodeBatch (Context ctxFPtr) (Batch batchPtr) = do
  result <- withForeignPtr ctxFPtr $ \ctxPtr ->
    c_llama_decode_wrap (CLlamaContext ctxPtr) batchPtr
  return $ case compare result 0 of
    LT -> Left "Decoding failed"
    EQ -> Right ()
    GT -> Left ("Decoding did not complete (status " ++ show result ++ ")")
-- | Set both thread counts of a context to the same value by delegating to
-- 'setThreadCounts' with the count passed twice.
setThreadCount :: Context -> Int -> IO ()
setThreadCount context threads = applyBoth threads threads
  where
    applyBoth = setThreadCounts context
-- | Forward two thread counts to @c_llama_set_n_threads@.
--
-- Both 'Int' values are converted with 'fromIntegral' and passed in the
-- order received: first @nThreads@, then @nBatchThreads@.
setThreadCounts :: Context -> Int -> Int -> IO ()
setThreadCounts (Context ctxFPtr) nThreads nBatchThreads =
  withForeignPtr ctxFPtr apply
  where
    apply rawCtx =
      c_llama_set_n_threads
        (CLlamaContext rawCtx)
        (fromIntegral nThreads)
        (fromIntegral nBatchThreads)
-- | Query @c_llama_n_threads@ for the context and convert the result to
-- 'Int'.
getThreadCount :: Context -> IO Int
getThreadCount (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \rawCtx -> do
    count <- c_llama_n_threads (CLlamaContext rawCtx)
    return (fromIntegral count)
-- | Query @c_llama_n_threads_batch@ for the context and convert the result
-- to 'Int'.
getBatchThreadCount :: Context -> IO Int
getBatchThreadCount (Context ctxFPtr) =
  fmap fromIntegral (withForeignPtr ctxFPtr query)
  where
    query rawCtx = c_llama_n_threads_batch (CLlamaContext rawCtx)
-- | Pass a boolean flag, converted with 'fromBool', to
-- @c_llama_set_embeddings@ for the given context.
setEmbeddingsEnabled :: Context -> Bool -> IO ()
setEmbeddingsEnabled (Context ctxFPtr) enabled =
  withForeignPtr ctxFPtr (\rawCtx -> c_llama_set_embeddings (CLlamaContext rawCtx) flag)
  where
    flag = fromBool enabled
-- | Report whether @c_llama_get_embeddings@ yields a buffer for the context.
--
-- Returns 'True' when the pointer returned by the C call is non-null and
-- 'False' when it is null.
--
-- The pointer used to be dereferenced unconditionally and its first float
-- interpreted as a boolean. That crashed on a null pointer and reported
-- 'False' whenever the first embedding value happened to be @0.0@.
--
-- NOTE(review): this assumes a null pointer is how the C side signals that
-- no embeddings are available; confirm against llama.h.
areEmbeddingsEnabled :: Context -> IO Bool
areEmbeddingsEnabled (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \ctxPtr -> do
    resPtr <- c_llama_get_embeddings (CLlamaContext ctxPtr)
    return (resPtr /= nullPtr)
-- | Pass a boolean flag, converted with 'fromBool', to
-- @c_llama_set_causal_attn@ for the given context.
setCausalAttention :: Context -> Bool -> IO ()
setCausalAttention (Context ctxFPtr) causal =
  withForeignPtr ctxFPtr $ \rawCtx -> do
    let handle = CLlamaContext rawCtx
    c_llama_set_causal_attn handle (fromBool causal)
-- | Pass a boolean flag, converted with 'fromBool', to @c_llama_set_warmup@
-- for the given context.
setWarmupMode :: Context -> Bool -> IO ()
setWarmupMode (Context ctxFPtr) warm = withForeignPtr ctxFPtr setFlag
  where
    setFlag rawCtx = c_llama_set_warmup (CLlamaContext rawCtx) (fromBool warm)
-- | Call @c_llama_synchronize@ on the given context.
synchronizeContext :: Context -> IO ()
synchronizeContext (Context ctxFPtr) =
  withForeignPtr ctxFPtr (c_llama_synchronize . CLlamaContext)