{- |
Module      : Llama.Context
Description : High level context interface for llama-cpp
Copyright   : (c) 2025 Tushar Adhatrao
License     : MIT
Maintainer  : Tushar Adhatrao <tusharadhatrao@gmail.com>
-}
module Llama.Context (
    supportsRpc
, supportsGpuOffload
, supportsMLock
, supportsMMap
, getMaxDevices
, getTimeUs
, getContextSize
, getBatchSize
, getUnbatchedSize
, getMaxSeqCount
, getPoolingType
, detachThreadPool
, defaultContextParams
) where

import Llama.Internal.Foreign
import Llama.Internal.Types
import Llama.Internal.Types.Params
import Foreign

-- | Check if the backend supports remote procedure calls (RPC).
--
-- Runs @llama_supports_rpc@ and converts the C boolean it yields into a
-- Haskell 'Bool' with 'toBool' (any non-zero value is 'True').
supportsRpc :: IO Bool
supportsRpc = toBool <$> llama_supports_rpc

-- | Check if the backend supports GPU offloading.
--
-- Runs @llama_supports_gpu_offload@ and converts the C boolean it yields into
-- a Haskell 'Bool' with 'toBool' (any non-zero value is 'True').
supportsGpuOffload :: IO Bool
supportsGpuOffload = toBool <$> llama_supports_gpu_offload

-- | Check if the backend supports locking model memory into RAM (no swapping).
--
-- Runs @llama_supports_mlock@ and converts the C boolean it yields into a
-- Haskell 'Bool' with 'toBool' (any non-zero value is 'True').
supportsMLock :: IO Bool
supportsMLock = toBool <$> llama_supports_mlock

-- | Check if the backend supports memory mapping models.
--
-- Runs @llama_supports_mmap@ and converts the C boolean it yields into a
-- Haskell 'Bool' with 'toBool' (any non-zero value is 'True').
supportsMMap :: IO Bool
supportsMMap = toBool <$> llama_supports_mmap

-- | Get maximum number of devices supported by the backend (e.g., GPUs).
--
-- The @size_t@ produced by @llama_max_devices@ is converted to 'Int' with
-- 'fromIntegral'.
getMaxDevices :: IO Int
getMaxDevices = fromIntegral <$> llama_max_devices

-- | Get current time in microseconds since some unspecified epoch.
--
-- The 64-bit result of @llama_time_us@ is converted to 'Int' with
-- 'fromIntegral'; on a platform where 'Int' is narrower than 64 bits the
-- value would wrap.
getTimeUs :: IO Int
getTimeUs = fromIntegral <$> llama_time_us

-- | Get the maximum context size (@n_ctx@) of the model in the given context.
--
-- The context's foreign pointer is held with 'withForeignPtr' while
-- @llama_n_ctx@ runs, and the C @unsigned int@ it returns is converted to
-- 'Int' with 'fromIntegral'.
getContextSize :: Context -> IO Int
getContextSize (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    fromIntegral <$> llama_n_ctx (CLlamaContext ctxPtr)

-- | Get the batch size (@n_batch@) used by the context.
--
-- The context's foreign pointer is held with 'withForeignPtr' while
-- @llama_n_batch@ runs, and the C @unsigned int@ it returns is converted to
-- 'Int' with 'fromIntegral'.
getBatchSize :: Context -> IO Int
getBatchSize (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    fromIntegral <$> llama_n_batch (CLlamaContext ctxPtr)

-- | Get the unbatched size (@n_ubatch@) of the context.
--
-- The context's foreign pointer is held with 'withForeignPtr' while
-- @llama_n_ubatch@ runs, and the C @unsigned int@ it returns is converted to
-- 'Int' with 'fromIntegral'.
getUnbatchedSize :: Context -> IO Int
getUnbatchedSize (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    fromIntegral <$> llama_n_ubatch (CLlamaContext ctxPtr)

-- | Get the maximum number of sequences supported.
--
-- The context's foreign pointer is held with 'withForeignPtr' while
-- @llama_n_seq_max@ runs, and the C @unsigned int@ it returns is converted to
-- 'Int' with 'fromIntegral'.
getMaxSeqCount :: Context -> IO Int
getMaxSeqCount (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    fromIntegral <$> llama_n_seq_max (CLlamaContext ctxPtr)

-- | Get the pooling type used by the context.
--
-- A temporary @int@ cell is allocated with 'alloca' and handed to
-- @c_llama_pooling_type_into@, then read back with 'peek'.
--
-- Returns 'Nothing' when the raw value is @-1@; any other value is passed to
-- 'fromLlamaRopePoolingType', whose own result (itself a 'Maybe') is returned
-- unchanged.
getPoolingType :: Context -> IO (Maybe LlamaPoolingType)
getPoolingType (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    alloca $ \outPtr -> do
      c_llama_pooling_type_into (CLlamaContext ctxPtr) outPtr
      val <- peek outPtr
      case val of
        -- -1 is short-circuited here rather than delegated to the converter.
        -1 -> pure Nothing
        -- NOTE(review): despite the "Rope" in its name, this helper is used
        -- here as the CInt -> Maybe LlamaPoolingType conversion.
        _  -> pure (fromLlamaRopePoolingType val)

-- | Detach the internal threadpool from the context.
--
-- The context's foreign pointer is held with 'withForeignPtr' for the
-- duration of the @c_llama_detach_threadpool@ call.
detachThreadPool :: Context -> IO ()
detachThreadPool (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    c_llama_detach_threadpool (CLlamaContext ctxPtr)

-- | Obtain a 'LlamaContextParams' value initialised with defaults.
--
-- Temporary storage is allocated with 'alloca' and passed (wrapped as
-- 'CLlamaContextParams') to @c_llama_context_default_params_into@; the
-- structure is then read back with 'peek'. The temporary storage does not
-- outlive this action, so the returned value is an independent Haskell copy.
defaultContextParams :: IO LlamaContextParams
defaultContextParams =
  alloca $ \ptr -> do
    c_llama_context_default_params_into (CLlamaContextParams ptr)
    peek ptr