{- |
Module      : Llama.KVCache
Description : High level KVCache interface for llama-cpp
Copyright   : (c) 2025 Tushar Adhatrao
License     : MIT
Maintainer  : Tushar Adhatrao <tusharadhatrao@gmail.com>
-}
module Llama.KVCache (
    kvCacheViewInit
, kvSelfSeqAdd
, kvSelfSeqDiv
, kvSelfSeqPosMax
, kvSelfDefrag
, kvSelfCanShift
, kvSelfUpdate
, kvSelfSeqKeep
, kvSelfSeqCopy
, kvSelfSeqRemove
, kvSelfClear
, kvSelfUsedCells
, kvSelfNumTokens
, kvCacheViewUpdate
) where

import Llama.Internal.Types
import Foreign
import Llama.Internal.Foreign

{-
TODO: there is no free function for @struct llama_kv_cache *@, and nothing
uses @struct llama_kv_cache@ yet, so the wrapper below is left disabled.
-- | Get the KV cache associated with this context.
getKVCache :: Context -> IO KVCache
getKVCache (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \ctxPtr -> do
    ptr <- c_llama_get_kv_self (CLlamaContext ctxPtr)
    -- We assume finalization is handled elsewhere or use a no-op finalizer
    fp <- newForeignPtr_ ptr -- assumes Ptr () is actually Ptr CLlamaKVCache
    return $ KVCache fp
-}

-- | Convenience wrapper that allocates a 'LlamaKvCacheView' and initializes it.
--
-- The view is written by @c_llama_kv_cache_view_init_into@ into scratch memory
-- obtained from 'alloca' and then read back with 'peek', so the caller receives
-- a marshalled Haskell value; the scratch buffer itself does not outlive this
-- call.
--
-- The @Int@ argument is the @n_seq_max@ value handed to the C initializer. It
-- is converted with 'fromIntegral', so values outside the C @int@ range wrap.
kvCacheViewInit :: Context -> Int -> IO LlamaKvCacheView
kvCacheViewInit (Context ctxFPtr) nSeqMax =
  -- Keep the context alive for as long as the raw pointer is in use.
  withForeignPtr ctxFPtr $ \ctxPtr ->
    alloca $ \viewPtr -> do
      c_llama_kv_cache_view_init_into
        (CLlamaContext ctxPtr)
        (fromIntegral nSeqMax)
        viewPtr
      peek viewPtr

-- | Shift positions in a sequence by a delta.
--
-- Borrows the raw context pointer for the duration of the call and forwards
-- every argument unchanged to @c_llama_kv_self_seq_add@.
kvSelfSeqAdd ::
  Context ->
  LlamaSeqId -> -- seq_id
  LlamaPos -> -- p0
  LlamaPos -> -- p1
  LlamaPos -> -- delta
  IO ()
kvSelfSeqAdd (Context ctxFPtr) seqId p0 p1 delta =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    c_llama_kv_self_seq_add (CLlamaContext ctxPtr) seqId p0 p1 delta

-- | Divide positions in a sequence by a factor (used for attention windowing).
--
-- Forwards to @c_llama_kv_self_seq_div@. The divisor is taken as a Haskell
-- 'Int' and converted with 'fromIntegral' before the call, so a value outside
-- the C @int@ range wraps rather than raising an error.
kvSelfSeqDiv ::
  Context ->
  LlamaSeqId -> -- seq_id
  LlamaPos -> -- p0
  LlamaPos -> -- p1
  Int -> -- d
  IO ()
kvSelfSeqDiv (Context ctxFPtr) seqId p0 p1 d =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    c_llama_kv_self_seq_div (CLlamaContext ctxPtr) seqId p0 p1 (fromIntegral d)

-- | Get the maximum position stored in the KV cache for a given sequence.
--
-- Returns whatever @c_llama_kv_self_seq_pos_max@ reports, without
-- interpretation.
kvSelfSeqPosMax ::
  Context ->
  LlamaSeqId -> -- seq_id
  IO LlamaPos
kvSelfSeqPosMax (Context ctxFPtr) seqId =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    c_llama_kv_self_seq_pos_max (CLlamaContext ctxPtr) seqId

-- | Defragment the KV cache to optimize memory usage.
--
-- Thin wrapper over @c_llama_kv_self_defrag@; the context is kept alive for
-- the duration of the call.
kvSelfDefrag ::
  Context ->
  IO ()
kvSelfDefrag (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    c_llama_kv_self_defrag (CLlamaContext ctxPtr)

-- | Check whether the KV cache can be shifted.
--
-- Returns the C boolean reported by @c_llama_kv_self_can_shift@, converted
-- with 'toBool'.
kvSelfCanShift ::
  Context ->
  IO Bool
kvSelfCanShift (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    -- Unlike the other bindings in this module, this import takes the raw
    -- context pointer directly rather than the 'CLlamaContext' wrapper, so no
    -- wrapping (or pointer cast) is needed.
    toBool <$> c_llama_kv_self_can_shift ctxPtr

-- | Update the KV cache's internal state (e.g., after manual modifications).
--
-- Thin wrapper over @c_llama_kv_self_update@; the context is kept alive for
-- the duration of the call.
kvSelfUpdate ::
  Context ->
  IO ()
kvSelfUpdate (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    c_llama_kv_self_update (CLlamaContext ctxPtr)

-- | Keep the given sequence in the KV cache.
--
-- Forwards the sequence id unchanged to @c_llama_kv_self_seq_keep@.
--
-- NOTE(review): upstream @llama.h@ describes the underlying call as removing
-- all tokens that do /not/ belong to the given sequence, which is stronger
-- than "protecting a sequence from cleanup" — confirm against the llama.cpp
-- version this package links to before relying on either reading.
kvSelfSeqKeep ::
  Context ->
  LlamaSeqId -> -- seq_id
  IO ()
kvSelfSeqKeep (Context ctxFPtr) seqId =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    c_llama_kv_self_seq_keep (CLlamaContext ctxPtr) seqId

-- | Copy a range of positions from one sequence to another.
--
-- Arguments are forwarded to @c_llama_kv_self_seq_cp@ in the order
-- source id, destination id, @p0@, @p1@.
kvSelfSeqCopy ::
  Context ->
  LlamaSeqId -> -- seq_id_src
  LlamaSeqId -> -- seq_id_dst
  LlamaPos -> -- p0
  LlamaPos -> -- p1
  IO ()
kvSelfSeqCopy (Context ctxFPtr) srcId dstId p0 p1 =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    c_llama_kv_self_seq_cp (CLlamaContext ctxPtr) srcId dstId p0 p1

-- | Remove a range of positions from a sequence.
--
-- Returns the C boolean produced by @c_llama_kv_self_seq_rm@, converted with
-- 'toBool'.
kvSelfSeqRemove ::
  Context ->
  LlamaSeqId -> -- seq_id
  LlamaPos -> -- p0
  LlamaPos -> -- p1
  IO Bool
kvSelfSeqRemove (Context ctxFPtr) seqId p0 p1 =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    toBool <$> c_llama_kv_self_seq_rm (CLlamaContext ctxPtr) seqId p0 p1

-- | Clear all sequences from the KV cache.
--
-- Thin wrapper over @c_llama_kv_self_clear@; the context is kept alive for
-- the duration of the call.
kvSelfClear ::
  Context ->
  IO ()
kvSelfClear (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    c_llama_kv_self_clear (CLlamaContext ctxPtr)

-- | Get number of used cells in the KV cache.
--
-- The C @int@ returned by @c_llama_kv_self_used_cells@ is widened to 'Int'
-- with 'fromIntegral'.
kvSelfUsedCells ::
  Context ->
  IO Int
kvSelfUsedCells (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    fromIntegral <$> c_llama_kv_self_used_cells (CLlamaContext ctxPtr)

-- | Get total number of tokens currently stored in the KV cache.
--
-- The C @int@ returned by @c_llama_kv_self_n_tokens@ is widened to 'Int'
-- with 'fromIntegral'.
kvSelfNumTokens ::
  Context ->
  IO Int
kvSelfNumTokens (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    fromIntegral <$> c_llama_kv_self_n_tokens (CLlamaContext ctxPtr)

-- | Update the KV cache view to reflect current state.
--
-- The view pointer is supplied by the caller and passed through unchanged to
-- @c_llama_kv_cache_view_update@; this function neither allocates nor frees
-- it.
kvCacheViewUpdate ::
  Context ->
  Ptr LlamaKvCacheView -> -- view
  IO ()
kvCacheViewUpdate (Context ctxFPtr) view =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    c_llama_kv_cache_view_update (CLlamaContext ctxPtr) view