{- |
Module      : Llama.Decode
Description : High level Decode interface for llama-cpp
Copyright   : (c) 2025 Tushar Adhatrao
License     : MIT
Maintainer  : Tushar Adhatrao <tusharadhatrao@gmail.com>
-}
module Llama.Decode
  ( batchInit
  , batchGetOne
  , freeBatch
  , encodeBatch
  , decodeBatch
  , setThreadCount
  , getThreadCount
  , getBatchThreadCount
  , setEmbeddingsEnabled
  , areEmbeddingsEnabled
  , setCausalAttention
  , setThreadCounts
  , setWarmupMode
  , synchronizeContext
  ) where

import Foreign
import Llama.Internal.Foreign
import Llama.Internal.Types

-- | Allocate and initialise a new batch.
--
-- The three arguments are converted with 'fromIntegral' and forwarded, in
-- order, to @c_llama_batch_init_into@: @nTokens@, @embd_@ and @nSeqMax@
-- (judging by the names: token capacity, embedding size and maximum number of
-- sequences).
--
-- The batch struct is allocated with 'malloc' rather than 'alloca': memory
-- obtained from 'alloca' is released as soon as its callback returns, which
-- would leave the pointer stored in the returned 'Batch' dangling.
--
-- Ownership of the allocation passes to the caller. Use 'freeBatch' on the
-- wrapped pointer when done.
-- NOTE(review): whether the C wrapper behind 'freeBatch' also releases the
-- struct itself is not visible from this module; if it does not, the pointer
-- must additionally be released with 'free'.
batchInit :: Int -> Int -> Int -> IO Batch
batchInit nTokens embd_ nSeqMax = do
  -- Heap allocation so the struct outlives this call.
  batchPtr <- malloc
  c_llama_batch_init_into
    (fromIntegral nTokens)
    (fromIntegral embd_)
    (fromIntegral nSeqMax)
    batchPtr
  pure (Batch batchPtr)

-- | Create a batch from a list of tokens.
--
-- The tokens are copied into a freshly allocated C array ('newArray') and the
-- batch struct is allocated with 'malloc'; both are then handed to
-- @c_llama_batch_get_one_into@ together with the token count.
--
-- Scoped allocators ('alloca' / 'withArray') must not be used here: they free
-- their memory when the callback returns, so the returned 'Batch' would refer
-- to released memory.
--
-- NOTE(review): nothing in this module frees the token array or the struct
-- allocated here, so each call leaks both unless the caller releases them with
-- 'free'. The batch presumably keeps referring to the token array (upstream
-- @llama_batch_get_one@ does not copy its input), so the array must stay alive
-- for as long as the batch is in use; confirm against the C wrapper.
batchGetOne :: [LlamaToken] -> IO Batch
batchGetOne tokens = do
  -- Heap copy of the tokens; must outlive the returned batch.
  tokensPtr <- newArray tokens
  batchPtr <- malloc
  c_llama_batch_get_one_into tokensPtr (fromIntegral (length tokens)) batchPtr
  pure (Batch batchPtr)

-- | Release a batch previously set up with 'batchInit'.
--
-- Takes the raw pointer (the payload of a 'Batch') and passes it straight to
-- @c_llama_batch_free_wrap@.
freeBatch :: Ptr LlamaBatch -> IO ()
freeBatch rawBatch = c_llama_batch_free_wrap rawBatch

-- | Run the encoder over a batch using the given context.
--
-- Calls @c_llama_encode@ with the context and the batch pointer. A negative
-- status code is reported as @Left "Encoding failed"@; any other status is
-- reported as @Right ()@.
encodeBatch :: Context -> Batch -> IO (Either String ())
encodeBatch (Context fptr) (Batch rawBatch) = do
  status <- withForeignPtr fptr runEncode
  pure (toOutcome status)
  where
    runEncode rawCtx = c_llama_encode (CLlamaContext rawCtx) rawBatch
    toOutcome code
      | code < 0 = Left "Encoding failed"
      | otherwise = Right ()

-- | Decode a batch using the given context.
--
-- Calls @c_llama_decode_wrap@ with the context and the batch pointer and maps
-- its status code to a result:
--
-- * @0@ gives @Right ()@
-- * a negative code gives @Left "Decoding failed"@
-- * a positive code gives a @Left@ carrying the code
--
-- Previously only negative codes were treated as failures, so a positive code
-- was reported as success.
-- NOTE(review): llama.cpp documents positive return values of @llama_decode@
-- as non-fatal conditions in which the batch was /not/ processed (for example
-- no KV slot could be found). This assumes the C wrapper forwards that code
-- unchanged; confirm.
decodeBatch :: Context -> Batch -> IO (Either String ())
decodeBatch (Context ctxFPtr) (Batch batchPtr) = do
  status <- withForeignPtr ctxFPtr $ \ctxPtr ->
    c_llama_decode_wrap (CLlamaContext ctxPtr) batchPtr
  pure $ case compare status 0 of
    EQ -> Right ()
    LT -> Left "Decoding failed"
    GT -> Left ("Decoding did not complete (status " ++ show status ++ ")")

-- | Use the same thread count for both main and batch processing.
--
-- Shorthand for 'setThreadCounts' with the given number passed for both
-- arguments.
setThreadCount :: Context -> Int -> IO ()
setThreadCount context threads = setThreadCounts context threads threads

-- | Configure the main and batch thread counts independently.
--
-- Both counts are converted with 'fromIntegral' and passed to
-- @c_llama_set_n_threads@ (main count first, batch count second).
setThreadCounts :: Context -> Int -> Int -> IO ()
setThreadCounts (Context fptr) nThreads nBatchThreads =
  withForeignPtr fptr apply
  where
    mainCount = fromIntegral nThreads
    batchCount = fromIntegral nBatchThreads
    apply rawCtx =
      c_llama_set_n_threads (CLlamaContext rawCtx) mainCount batchCount

-- | Query the main thread count of the context.
--
-- Returns the value reported by @c_llama_n_threads@, converted to 'Int'.
getThreadCount :: Context -> IO Int
getThreadCount (Context fptr) =
  withForeignPtr fptr (fmap fromIntegral . c_llama_n_threads . CLlamaContext)

-- | Query the batch thread count of the context.
--
-- Returns the value reported by @c_llama_n_threads_batch@, converted to 'Int'.
getBatchThreadCount :: Context -> IO Int
getBatchThreadCount (Context fptr) =
  withForeignPtr fptr $ \rawCtx -> do
    count <- c_llama_n_threads_batch (CLlamaContext rawCtx)
    pure (fromIntegral count)

-- | Switch embeddings output on or off for the context.
--
-- The 'Bool' is converted with 'fromBool' and passed to
-- @c_llama_set_embeddings@.
setEmbeddingsEnabled :: Context -> Bool -> IO ()
setEmbeddingsEnabled (Context fptr) enabled =
  withForeignPtr fptr $ \rawCtx ->
    c_llama_set_embeddings (CLlamaContext rawCtx) cFlag
  where
    cFlag = fromBool enabled

-- | Check whether the context currently exposes an embeddings buffer.
--
-- Calls @c_llama_get_embeddings@ and reports 'True' when the returned pointer
-- is non-null.
--
-- The previous implementation dereferenced the returned pointer
-- unconditionally and interpreted the first float as a boolean. That crashes
-- when the pointer is null and otherwise depends on the numeric value of the
-- first embedding component rather than on whether embeddings are available.
--
-- NOTE(review): this treats "buffer present" as "embeddings enabled".
-- Upstream documents @llama_get_embeddings@ as returning NULL when no
-- embeddings are available; confirm that this matches the intended meaning.
areEmbeddingsEnabled :: Context -> IO Bool
areEmbeddingsEnabled (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \ctxPtr -> do
    resPtr <- c_llama_get_embeddings (CLlamaContext ctxPtr)
    -- Never dereference: the pointer may be null.
    pure (resPtr /= nullPtr)

-- | Turn causal attention on or off for the context.
--
-- The 'Bool' is converted with 'fromBool' and passed to
-- @c_llama_set_causal_attn@.
setCausalAttention :: Context -> Bool -> IO ()
setCausalAttention (Context fptr) causal =
  withForeignPtr fptr applyFlag
  where
    applyFlag rawCtx =
      c_llama_set_causal_attn (CLlamaContext rawCtx) (fromBool causal)

-- | Toggle warmup mode on the context.
--
-- The 'Bool' is converted with 'fromBool' and passed to @c_llama_set_warmup@.
setWarmupMode :: Context -> Bool -> IO ()
setWarmupMode (Context fptr) warm =
  withForeignPtr fptr (\rawCtx -> c_llama_set_warmup (CLlamaContext rawCtx) cFlag)
  where
    cFlag = fromBool warm

{-
type AbortCallback = Ptr () -> IO CInt
-- | Set an abort callback to cancel long-running ops from another thread.
setAbortCallback :: Context -> FunPtr AbortCallback -> Ptr () -> IO ()
setAbortCallback (Context ctxFPtr) callback cbData =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    c_llama_set_abort_callback (CLlamaContext ctxPtr) callback cbData
    -}

-- | Synchronize the context.
--
-- Invokes @c_llama_synchronize@ on the underlying context pointer; this is
-- intended to block until outstanding asynchronous work has finished.
synchronizeContext :: Context -> IO ()
synchronizeContext (Context fptr) =
  withForeignPtr fptr (c_llama_synchronize . CLlamaContext)