{- |
Module      : Llama.State
Description : High level State interface for llama-cpp
Copyright   : (c) 2025 Tushar Adhatrao
License     : MIT
Maintainer  : Tushar Adhatrao <tusharadhatrao@gmail.com>
-}
module Llama.State (
   getStateSize
, getStateData
, setStateData
, loadStateFromFile
, saveStateToFile
, getSequenceStateSize
, setSequenceStateData
, saveSequenceStateToFile
, loadSequenceStateFromFile
) where

import Llama.Internal.Foreign
import Llama.Internal.Types
import Foreign
import Foreign.C.String
import qualified Data.ByteString as BS
import Data.ByteString (ByteString)

-- | Get the size, in bytes, of the context's full state, as reported by
-- @llama_state_get_size@.
getStateSize :: Context -> IO Word64
getStateSize (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \rawCtx -> do
    byteCount <- c_llama_state_get_size (CLlamaContext rawCtx)
    pure (fromIntegral byteCount)

-- | Copy the context's full state into a fresh 'ByteString'.
--
-- A buffer of @llama_state_get_size@ bytes is allocated and filled by
-- @llama_state_get_data@; only the number of bytes that call reports as
-- copied is kept in the result.
getStateData :: Context -> IO ByteString
getStateData (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \ctxPtr -> do
    let ctx = CLlamaContext ctxPtr
    size <- c_llama_state_get_size ctx
    allocaBytes (fromIntegral size) $ \dstPtr -> do
      written <- c_llama_state_get_data ctx dstPtr size
      -- Never read past the buffer, even if the C side were to report
      -- more bytes than it was given room for.
      let len = fromIntegral (min written size)
      -- Copy straight out of the buffer instead of round-tripping through
      -- a [Word8], which would cost a boxed list cell per byte.
      BS.packCStringLen (castPtr dstPtr, len)

-- | Load the context's full state from a 'ByteString' via
-- @llama_state_set_data@.
--
-- The byte count returned by the C call is discarded.
setStateData :: Context -> ByteString -> IO ()
setStateData (Context ctxFPtr) bs =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    -- Hand the bytes to C directly rather than via 'BS.unpack' and
    -- 'withArray', which would first build a boxed list cell per byte of
    -- a potentially very large state blob.
    BS.useAsCStringLen bs $ \(srcPtr, srcLen) -> do
      _ <- c_llama_state_set_data
             (CLlamaContext ctxPtr)
             (castPtr srcPtr)
             (fromIntegral srcLen)
      pure ()

-- | Load a session file with @llama_state_load_file@ and return the tokens
-- stored in it.
--
-- The @tokens@ argument is used only for its length, which becomes the
-- capacity of the output token buffer. An empty list is returned when the
-- C call reports failure.
loadStateFromFile :: Context -> FilePath -> [LlamaToken] -> IO [LlamaToken]
loadStateFromFile (Context ctxFPtr) pathSession tokens =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    withCString pathSession $ \cPath ->
      allocaArray capacity $ \outTokens ->
        alloca $ \outCount -> do
          ok <- c_llama_state_load_file
                  (CLlamaContext ctxPtr)
                  cPath
                  outTokens
                  (fromIntegral capacity)
                  outCount
          if toBool ok
            then do
              n <- peek outCount
              peekArray (fromIntegral n) outTokens
            else pure []
  where
    capacity = length tokens

-- | Write the context's state together with the given tokens to a session
-- file via @llama_state_save_file@.
--
-- Returns 'True' when the C call reports success.
saveStateToFile :: Context -> FilePath -> [LlamaToken] -> IO Bool
saveStateToFile (Context ctxFPtr) pathSession tokens =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    withCString pathSession $ \cPath ->
      withArrayLen tokens $ \tokenCount tokensPtr -> do
        status <-
          c_llama_state_save_file
            (CLlamaContext ctxPtr)
            cPath
            tokensPtr
            (fromIntegral tokenCount)
        pure (toBool status)

-- | Get the size, in bytes, of the state belonging to a single sequence, as
-- reported by @llama_state_seq_get_size@.
getSequenceStateSize :: Context -> LlamaSeqId -> IO Word64
getSequenceStateSize (Context ctxFPtr) seqId = do
  size <- withForeignPtr ctxFPtr $ \ctxPtr ->
    c_llama_state_seq_get_size (CLlamaContext ctxPtr) seqId
  pure (fromIntegral size)

-- | Load the state of a single sequence from a 'ByteString' via
-- @llama_state_seq_set_data@.
--
-- Returns the count reported by the C call, widened to 'Word64'.
setSequenceStateData :: Context -> ByteString -> LlamaSeqId -> IO Word64
setSequenceStateData (Context ctxFPtr) bs seqId =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    -- Pass the ByteString's bytes directly; 'BS.unpack' + 'withArray'
    -- would materialise a boxed list cell per byte first.
    BS.useAsCStringLen bs $ \(srcPtr, srcLen) ->
      fromIntegral
        <$> c_llama_state_seq_set_data
              (CLlamaContext ctxPtr)
              (castPtr srcPtr)
              (fromIntegral srcLen)
              seqId

-- | Write the state of a single sequence, together with the given tokens,
-- to a file via @llama_state_seq_save_file@.
--
-- Returns the count reported by the C call, widened to 'Word64'.
saveSequenceStateToFile :: Context -> FilePath -> LlamaSeqId -> [LlamaToken] -> IO Word64
saveSequenceStateToFile (Context ctxFPtr) filepath seqId tokens =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    withArrayLen tokens $ \nTokens tokenPtr ->
      withCString filepath $ \cPath -> do
        written <-
          c_llama_state_seq_save_file
            (CLlamaContext ctxPtr)
            cPath
            seqId
            tokenPtr
            (fromIntegral nTokens)
        pure (fromIntegral written)

-- | Load the state of a single sequence from a file via
-- @llama_state_seq_load_file@, placing it in @destSeqId@, and return the
-- tokens stored alongside it.
--
-- The @tokens@ argument is used only for its length, which becomes the
-- capacity of the output token buffer. An empty list is returned when the
-- C call reports failure.
loadSequenceStateFromFile :: Context -> FilePath -> LlamaSeqId -> [LlamaToken] -> IO [LlamaToken]
loadSequenceStateFromFile (Context ctxFPtr) filepath destSeqId tokens =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    withCString filepath $ \cfilepath ->
      allocaArray capacity $ \tokensOutPtr ->
        -- Start the out-parameter at 0 so it is never read uninitialised.
        with 0 $ \nTokenCountOutPtr -> do
          result <-
            c_llama_state_seq_load_file
              (CLlamaContext ctxPtr)
              cfilepath
              destSeqId
              tokensOutPtr
              (fromIntegral capacity)
              nTokenCountOutPtr
          -- Upstream llama.cpp returns 0 on failure (bad file, version
          -- mismatch, too many tokens for the buffer) and may not write the
          -- count out-parameter in that case, so it must not be trusted.
          if result == 0
            then pure []
            else do
              tokenCount <- peek nTokenCountOutPtr
              -- Clamp to the buffer size so peekArray never reads past it.
              peekArray (min capacity (fromIntegral tokenCount)) tokensOutPtr
  where
    capacity = length tokens