module Llama.State (
getStateSize
, getStateData
, setStateData
, loadStateFromFile
, saveStateToFile
, getSequenceStateSize
, setSequenceStateData
, saveSequenceStateToFile
, loadSequenceStateFromFile
) where
import Llama.Internal.Foreign
import Llama.Internal.Types
import Foreign
import Foreign.C.String
import qualified Data.ByteString as BS
import Data.ByteString (ByteString)
-- | Size in bytes of a full snapshot of the context's state, as reported
-- by @c_llama_state_get_size@.
getStateSize :: Context -> IO Word64
getStateSize (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \rawCtx -> do
    nBytes <- c_llama_state_get_size (CLlamaContext rawCtx)
    pure (fromIntegral nBytes)
-- | Copy the context's full state into a strict 'ByteString'.
--
-- This allocates a buffer of @c_llama_state_get_size@ bytes and lets
-- @c_llama_state_get_data@ fill it. Only the number of bytes that call
-- reports as written is kept, clamped to the buffer size. Unfilled buffer
-- contents are therefore never returned.
getStateData :: Context -> IO ByteString
getStateData (Context ctxFPtr) =
  withForeignPtr ctxFPtr $ \ctxPtr -> do
    let cctx = CLlamaContext ctxPtr
    size <- c_llama_state_get_size cctx
    allocaBytes (fromIntegral size) $ \dstPtr -> do
      written <- c_llama_state_get_data cctx dstPtr size
      -- Defensive clamp: never read beyond the buffer we allocated.
      let nBytes = fromIntegral (min written size)
      -- Copy straight out of the buffer instead of via an intermediate [Word8].
      BS.packCStringLen (castPtr dstPtr, nBytes)
-- | Restore the context's state from a snapshot such as the one produced
-- by 'getStateData'.
--
-- NOTE(review): the 'CSize' returned by @c_llama_state_set_data@
-- (presumably bytes consumed) is discarded, so callers cannot detect a
-- malformed snapshot. Consider surfacing it.
setStateData :: Context -> ByteString -> IO ()
setStateData (Context ctxFPtr) bs =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    -- Hand the ByteString's bytes to C directly. The previous approach
    -- unpacked them into a [Word8] (one list cell per byte) and then
    -- re-marshalled that list with withArray.
    BS.useAsCStringLen bs $ \(srcPtr, srcLen) -> do
      _ <- c_llama_state_set_data
             (CLlamaContext ctxPtr) (castPtr srcPtr) (fromIntegral srcLen)
      return ()
-- | Load a session file into the context.
--
-- The @tokens@ argument is used only for its length, which is the capacity
-- of the output token buffer. Returns the tokens read from the file, or
-- @[]@ when @c_llama_state_load_file@ returns a zero (failure) result.
loadStateFromFile :: Context -> FilePath -> [LlamaToken] -> IO [LlamaToken]
loadStateFromFile (Context ctxFPtr) pathSession tokens =
  withForeignPtr ctxFPtr $ \ctxPtr ->
  withCString pathSession $ \cPath ->
  allocaArray capacity $ \outBuf ->
  alloca $ \countPtr -> do
    ok <- c_llama_state_load_file
            (CLlamaContext ctxPtr) cPath outBuf (fromIntegral capacity) countPtr
    if ok /= 0
      then do
        n <- peek countPtr
        peekArray (fromIntegral n) outBuf
      else pure []
  where
    capacity = length tokens
-- | Write the context's state, together with the given tokens, to a
-- session file.
--
-- Returns 'True' when @c_llama_state_save_file@ returns a non-zero result.
saveStateToFile :: Context -> FilePath -> [LlamaToken] -> IO Bool
saveStateToFile (Context ctxFPtr) pathSession tokens =
  withForeignPtr ctxFPtr $ \ctxPtr ->
  withCString pathSession $ \cPath ->
  withArrayLen tokens $ \nTokens tokBuf -> do
    rc <- c_llama_state_save_file
            (CLlamaContext ctxPtr) cPath tokBuf (fromIntegral nTokens)
    pure (toBool rc)
-- | Size in bytes of the state of sequence @seqId@, as reported by
-- @c_llama_state_seq_get_size@.
getSequenceStateSize :: Context -> LlamaSeqId -> IO Word64
getSequenceStateSize (Context ctxFPtr) seqId =
  withForeignPtr ctxFPtr $ \rawCtx -> do
    nBytes <- c_llama_state_seq_get_size (CLlamaContext rawCtx) seqId
    pure (fromIntegral nBytes)
-- | Load a single sequence's state from @bs@ into sequence @seqId@.
--
-- Returns the 'CSize' result of @c_llama_state_seq_set_data@ widened to
-- 'Word64'. This is presumably the number of bytes consumed, with 0
-- meaning failure; confirm against the llama.cpp header.
setSequenceStateData :: Context -> ByteString -> LlamaSeqId -> IO Word64
setSequenceStateData (Context ctxFPtr) bs seqId =
  withForeignPtr ctxFPtr $ \ctxPtr ->
    -- Pass the ByteString's bytes straight to C instead of building a
    -- [Word8] and copying it again with withArray.
    BS.useAsCStringLen bs $ \(srcPtr, srcLen) ->
      fromIntegral
        <$> c_llama_state_seq_set_data
              (CLlamaContext ctxPtr)
              (castPtr srcPtr)
              (fromIntegral srcLen)
              seqId
-- | Write the state of sequence @seqId@, together with the given tokens,
-- to @filepath@.
--
-- Returns the 'CSize' result of @c_llama_state_seq_save_file@ widened to
-- 'Word64'.
saveSequenceStateToFile :: Context -> FilePath -> LlamaSeqId -> [LlamaToken] -> IO Word64
saveSequenceStateToFile (Context ctxFPtr) filepath seqId tokens =
  withForeignPtr ctxFPtr $ \rawCtx ->
  withArrayLen tokens $ \nTokens tokBuf ->
  withCString filepath $ \cPath -> do
    written <- c_llama_state_seq_save_file
                 (CLlamaContext rawCtx) cPath seqId tokBuf (fromIntegral nTokens)
    pure (fromIntegral written)
-- | Load a sequence state file into sequence @destSeqId@.
--
-- The @tokens@ argument only sets the capacity of the output token buffer.
-- Returns the tokens stored in the file. Returns @[]@ when
-- @c_llama_state_seq_load_file@ returns 0 (failure), matching
-- 'loadStateFromFile'.
--
-- The previous version ignored the result and always peeked the
-- uninitialised count, which on failure could read past the buffer.
loadSequenceStateFromFile :: Context -> FilePath -> LlamaSeqId -> [LlamaToken] -> IO [LlamaToken]
loadSequenceStateFromFile (Context ctxFPtr) filepath destSeqId tokens =
  withForeignPtr ctxFPtr $ \ctxPtr ->
  withCString filepath $ \cfilepath ->
  allocaArray capacity $ \tokensOutPtr ->
  -- Start the out-count at 0 so no garbage is read if C never writes it.
  with 0 $ \nTokenCountOutPtr -> do
    nRead <- c_llama_state_seq_load_file
               (CLlamaContext ctxPtr)
               cfilepath
               destSeqId
               tokensOutPtr
               (fromIntegral capacity)
               nTokenCountOutPtr
    if nRead == 0
      then pure []
      else do
        tokenCount <- peek nTokenCountOutPtr
        -- Never read more tokens than the buffer we allocated can hold.
        peekArray (min capacity (fromIntegral tokenCount)) tokensOutPtr
  where
    capacity = length tokens