module Llama.Sampler (
defaultSamplerChainParams
, initSampler
, getSamplerName
, acceptTokenWithSampler
, applySampler
, resetSampler
, cloneSampler
, initSamplerChain
, addSamplerToChain
, getSamplerFromChain
, getSamplerChainLength
, removeSamplerFromChain
, initGreedySampler
, initDistributedSampler
, initTopKSampler
, initTopPSampler
, initMinPSampler
, initTypicalSampler
, initTempSampler
, initTempExtSampler
, initXTCSampler
, initTopNSigmaSampler
, initMirostatSampler
, initMirostatV2Sampler
, initGrammarSampler
, initGrammarLazyPatternsSampler
, initPenaltiesSampler
, initDrySampler
, initLogitBiasSampler
, initInfillSampler
, getSamplerSeed
, sampleWithSampler
) where
import Foreign
import Foreign.C.String
import qualified Foreign.Concurrent as FC

import Llama.Internal.Foreign
import Llama.Internal.Types
import Llama.Internal.Types.Params
-- | Copy a 'Storable' value into temporary C memory and run an action with
-- a pointer to that copy. The memory is only valid inside the action.
--
-- This is exactly 'with' from "Foreign.Marshal.Utils" (alloca + poke).
withStorable :: Storable a => a -> (Ptr a -> IO b) -> IO b
withStorable = with
-- | Fetch llama.cpp's default sampler-chain parameters.
--
-- A temporary buffer is filled in by
-- @c_llama_sampler_chain_default_params_into@ and then read back with 'peek'.
defaultSamplerChainParams :: IO LlamaSamplerChainParams
defaultSamplerChainParams =
  alloca $ \out -> c_llama_sampler_chain_default_params_into out >> peek out
-- | Create a custom sampler from a Haskell-supplied interface table and
-- context via @llama_sampler_init@.
--
-- The interface record is copied into C-heap memory ('new'). Per llama.h,
-- @struct llama_sampler@ stores the @iface@ pointer rather than a copy, so
-- the table must outlive the sampler.
--
-- Returns @Left@ if the C call yields a null pointer. Otherwise the sampler
-- is wrapped in a 'ForeignPtr' finalized with 'p_llama_sampler_free'.
--
-- NOTE(review): on success the interface table is never freed, even after
-- the sampler is finalized. This needs a finalizer that frees it after
-- @llama_sampler_free@ runs.
initSampler :: LlamaSamplerI -> LlamaSamplerContext -> IO (Either String Sampler)
initSampler iface_ ctx_ = do
  ifacePtr <- new iface_
  samplerPtr <- c_llama_sampler_init ifacePtr ctx_
  if samplerPtr == nullPtr
    then do
      -- No sampler references the table, so release it instead of leaking.
      free ifacePtr
      return $ Left "Failed to initialize sampler"
    else do
      fp <- newForeignPtr p_llama_sampler_free samplerPtr
      return $ Right (Sampler fp)
-- | Read a sampler's name as reported by @llama_sampler_name@.
-- The returned C string is copied into a Haskell 'String'.
getSamplerName :: Sampler -> IO String
getSamplerName (Sampler samplerFPtr) =
  withForeignPtr samplerFPtr (\p -> c_llama_sampler_name p >>= peekCString)
-- | Inform a sampler that a token was accepted (@llama_sampler_accept@).
acceptTokenWithSampler :: Sampler -> LlamaToken -> IO ()
acceptTokenWithSampler (Sampler samplerFPtr) token_ =
  withForeignPtr samplerFPtr (`c_llama_sampler_accept` token_)
-- | Run @llama_sampler_apply@ on a token-data array.
--
-- The 'LlamaTokenDataArray' record is copied into a temporary C buffer
-- (via 'withStorable') that lives only for the duration of the call.
--
-- NOTE(review): the temporary struct is never read back after the call.
-- Any changes the C sampler makes to the struct's own fields are therefore
-- discarded. Only writes made through pointers stored inside it (if any)
-- remain visible. Confirm whether callers need the updated struct.
applySampler :: Sampler -> LlamaTokenDataArray -> IO ()
applySampler (Sampler samplerFPtr) tokenDataArray =
  withForeignPtr samplerFPtr $ \samplerPtr ->
    withStorable tokenDataArray (c_llama_sampler_apply samplerPtr)
-- | Reset a sampler's internal state via @llama_sampler_reset@.
resetSampler :: Sampler -> IO ()
resetSampler (Sampler samplerFPtr) =
  withForeignPtr samplerFPtr c_llama_sampler_reset
-- | Duplicate a sampler with @llama_sampler_clone@.
--
-- Returns @Left@ if the clone is a null pointer. Otherwise the clone is
-- wrapped in its own 'ForeignPtr' finalized with 'p_llama_sampler_free'.
cloneSampler :: Sampler -> IO (Either String Sampler)
cloneSampler (Sampler samplerFPtr) =
  withForeignPtr samplerFPtr $ \samplerPtr ->
    c_llama_sampler_clone samplerPtr >>= adopt
  where
    adopt clonePtr
      | clonePtr == nullPtr = pure (Left "Failed to clone sampler")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free clonePtr
-- | Create an empty sampler chain (@llama_sampler_chain_init@).
--
-- Returns @Left@ on a null result. Otherwise the chain is wrapped in a
-- 'ForeignPtr' finalized with 'p_llama_sampler_free'.
initSamplerChain :: LlamaSamplerChainParams -> IO (Either String Sampler)
initSamplerChain params = c_llama_sampler_chain_init params >>= wrap
  where
    wrap chainPtr
      | chainPtr == nullPtr = pure (Left "Failed to initialize sampler chain")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free chainPtr
-- | Append a raw sampler pointer to a chain (@llama_sampler_chain_add@).
--
-- Per llama.h, the chain takes ownership of the added sampler and frees it
-- when the chain itself is freed.
--
-- NOTE(review): the pointer should therefore not also be owned by a
-- finalizing 'ForeignPtr' (such as those from the @Sampler@-returning
-- initialisers in this module), or it may be freed twice. Confirm at call
-- sites.
addSamplerToChain :: Sampler -> Ptr LlamaSampler -> IO ()
addSamplerToChain (Sampler chainFPtr) samplerPtr =
  withForeignPtr chainFPtr (`c_llama_sampler_chain_add` samplerPtr)
-- | Look up the sampler at a given index of a chain
-- (@llama_sampler_chain_get@).
--
-- Returns @Left@ if the C call yields a null pointer.
--
-- The returned sampler is /borrowed/: per llama.h the chain still owns it
-- and frees it when the chain is freed. It is therefore wrapped without
-- 'p_llama_sampler_free' (attaching that finalizer caused a double free).
-- Its finalizer instead touches the chain's 'ForeignPtr', which keeps the
-- chain alive for as long as this handle is reachable.
getSamplerFromChain :: Sampler -> Int -> IO (Either String Sampler)
getSamplerFromChain (Sampler chainFPtr) index =
  withForeignPtr chainFPtr $ \chainPtr -> do
    samplerPtr <- c_llama_sampler_chain_get chainPtr (fromIntegral index)
    if samplerPtr == nullPtr
      then return $ Left "Failed to get sampler from chain"
      else do
        fp <- FC.newForeignPtr samplerPtr (touchForeignPtr chainFPtr)
        return $ Right $ Sampler fp
-- | Number of samplers currently in a chain (@llama_sampler_chain_n@).
getSamplerChainLength :: Sampler -> IO Int
getSamplerChainLength (Sampler chainFPtr) =
  fmap fromIntegral (withForeignPtr chainFPtr c_llama_sampler_chain_n)
-- | Detach the sampler at a given index from a chain
-- (@llama_sampler_chain_remove@).
--
-- Returns @Left@ on a null result. Otherwise the detached sampler is
-- wrapped in a 'ForeignPtr' finalized with 'p_llama_sampler_free'. Per
-- llama.h, the chain no longer owns a removed sampler.
removeSamplerFromChain :: Sampler -> Int -> IO (Either String Sampler)
removeSamplerFromChain (Sampler chainFPtr) index =
  withForeignPtr chainFPtr $ \chainPtr ->
    c_llama_sampler_chain_remove chainPtr (fromIntegral index) >>= adopt
  where
    adopt detached
      | detached == nullPtr = pure (Left "Failed to remove sampler from chain")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free detached
-- | Create a greedy sampler (@llama_sampler_init_greedy@).
--
-- Unlike the other initialisers in this module, this returns the raw
-- pointer with no finalizer attached. The caller is responsible for its
-- lifetime (for example, handing it to 'addSamplerToChain').
initGreedySampler :: IO (Either String (Ptr LlamaSampler))
initGreedySampler = classify <$> c_llama_sampler_init_greedy
  where
    classify ptr
      | ptr == nullPtr = Left "Failed to initialize greedy sampler"
      | otherwise = Right ptr
-- | Create a distribution sampler (@llama_sampler_init_dist@) with the
-- given seed.
--
-- Returns @Left@ on a null result. Otherwise the sampler is owned by a
-- 'ForeignPtr' finalized with 'p_llama_sampler_free'.
initDistributedSampler :: Word32 -> IO (Either String Sampler)
initDistributedSampler seed =
  c_llama_sampler_init_dist (fromIntegral seed) >>= wrap
  where
    wrap ptr
      | ptr == nullPtr = pure (Left "Failed to initialize distributed sampler")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free ptr
-- | Create a top-k sampler (@llama_sampler_init_top_k@).
--
-- Returns @Left@ on a null result. Otherwise the sampler is owned by a
-- 'ForeignPtr' finalized with 'p_llama_sampler_free'.
initTopKSampler :: Int -> IO (Either String Sampler)
initTopKSampler k =
  c_llama_sampler_init_top_k (fromIntegral k) >>= wrap
  where
    wrap ptr
      | ptr == nullPtr = pure (Left "Failed to initialize top-k sampler")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free ptr
-- | Create a top-p sampler (@llama_sampler_init_top_p@).
--
-- Parameters are forwarded as @p@ and @min_keep@. Returns @Left@ on a null
-- result. Otherwise the sampler is owned by a 'ForeignPtr' finalized with
-- 'p_llama_sampler_free'.
initTopPSampler :: Float -> Int -> IO (Either String Sampler)
initTopPSampler p1 minKeep =
  c_llama_sampler_init_top_p (realToFrac p1) (fromIntegral minKeep) >>= wrap
  where
    wrap ptr
      | ptr == nullPtr = pure (Left "Failed to initialize top-p sampler")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free ptr
-- | Create a min-p sampler (@llama_sampler_init_min_p@).
--
-- Parameters are forwarded as @p@ and @min_keep@. Returns @Left@ on a null
-- result. Otherwise the sampler is owned by a 'ForeignPtr' finalized with
-- 'p_llama_sampler_free'.
initMinPSampler :: Float -> Int -> IO (Either String Sampler)
initMinPSampler p1 minKeep =
  c_llama_sampler_init_min_p (realToFrac p1) (fromIntegral minKeep) >>= wrap
  where
    wrap ptr
      | ptr == nullPtr = pure (Left "Failed to initialize min-p sampler")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free ptr
-- | Create a typical sampler (@llama_sampler_init_typical@).
--
-- Parameters are forwarded as @p@ and @min_keep@. Returns @Left@ on a null
-- result. Otherwise the sampler is owned by a 'ForeignPtr' finalized with
-- 'p_llama_sampler_free'.
initTypicalSampler :: Float -> Int -> IO (Either String Sampler)
initTypicalSampler p_ minKeep =
  c_llama_sampler_init_typical (realToFrac p_) (fromIntegral minKeep) >>= wrap
  where
    wrap ptr
      | ptr == nullPtr = pure (Left "Failed to initialize typical sampler")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free ptr
-- | Create a temperature sampler (@llama_sampler_init_temp@).
--
-- Returns @Left@ on a null result. Otherwise the sampler is owned by a
-- 'ForeignPtr' finalized with 'p_llama_sampler_free'.
initTempSampler :: Float -> IO (Either String Sampler)
initTempSampler t =
  c_llama_sampler_init_temp (realToFrac t) >>= wrap
  where
    wrap ptr
      | ptr == nullPtr = pure (Left "Failed to initialize temperature sampler")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free ptr
-- | Create an extended temperature sampler (@llama_sampler_init_temp_ext@).
--
-- The three floats are forwarded in order as @t@, @delta@ and @exponent@.
-- Returns @Left@ on a null result. Otherwise the sampler is owned by a
-- 'ForeignPtr' finalized with 'p_llama_sampler_free'.
initTempExtSampler :: Float -> Float -> Float -> IO (Either String Sampler)
initTempExtSampler t delta exponent_ =
  c_llama_sampler_init_temp_ext (realToFrac t) (realToFrac delta) (realToFrac exponent_)
    >>= wrap
  where
    wrap ptr
      | ptr == nullPtr = pure (Left "Failed to initialize extended temperature sampler")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free ptr
-- | Create an XTC sampler (@llama_sampler_init_xtc@).
--
-- Arguments are forwarded in order: probability, threshold, @min_keep@ and
-- seed. Returns @Left@ on a null result. Otherwise the sampler is owned by
-- a 'ForeignPtr' finalized with 'p_llama_sampler_free'.
initXTCSampler :: Float -> Float -> Int -> Word32 -> IO (Either String Sampler)
initXTCSampler p1 t minKeep seed =
  c_llama_sampler_init_xtc
    (realToFrac p1)
    (realToFrac t)
    (fromIntegral minKeep)
    (fromIntegral seed)
    >>= wrap
  where
    wrap ptr
      | ptr == nullPtr = pure (Left "Failed to initialize XTC sampler")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free ptr
-- | Create a top-n-sigma sampler (@llama_sampler_init_top_n_sigma@).
--
-- Returns @Left@ on a null result. Otherwise the sampler is owned by a
-- 'ForeignPtr' finalized with 'p_llama_sampler_free'.
initTopNSigmaSampler :: Float -> IO (Either String Sampler)
initTopNSigmaSampler n =
  c_llama_sampler_init_top_n_sigma (realToFrac n) >>= wrap
  where
    wrap ptr
      | ptr == nullPtr = pure (Left "Failed to initialize top-N sigma sampler")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free ptr
-- | Create a Mirostat (v1) sampler (@llama_sampler_init_mirostat@).
--
-- Arguments are forwarded in order: vocabulary size, seed, @tau@, @eta@
-- and @m@. Returns @Left@ on a null result. Otherwise the sampler is owned
-- by a 'ForeignPtr' finalized with 'p_llama_sampler_free'.
initMirostatSampler :: Int -> Word32 -> Float -> Float -> Int -> IO (Either String Sampler)
initMirostatSampler nVocab seed tau eta m =
  c_llama_sampler_init_mirostat
    (fromIntegral nVocab)
    (fromIntegral seed)
    (realToFrac tau)
    (realToFrac eta)
    (fromIntegral m)
    >>= wrap
  where
    wrap ptr
      | ptr == nullPtr = pure (Left "Failed to initialize Mirostat sampler")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free ptr
-- | Create a Mirostat v2 sampler (@llama_sampler_init_mirostat_v2@).
--
-- Arguments are forwarded in order: seed, @tau@ and @eta@. Returns @Left@
-- on a null result. Otherwise the sampler is owned by a 'ForeignPtr'
-- finalized with 'p_llama_sampler_free'.
initMirostatV2Sampler :: Word32 -> Float -> Float -> IO (Either String Sampler)
initMirostatV2Sampler seed tau eta =
  c_llama_sampler_init_mirostat_v2 (fromIntegral seed) (realToFrac tau) (realToFrac eta)
    >>= wrap
  where
    wrap ptr
      | ptr == nullPtr = pure (Left "Failed to initialize Mirostat V2 sampler")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free ptr
-- | Create a grammar-constrained sampler (@llama_sampler_init_grammar@).
--
-- The grammar text and root rule name are marshalled with 'withCString',
-- so they are only valid for the duration of the C call. Returns @Left@ on
-- a null result. Otherwise the sampler is owned by a 'ForeignPtr'
-- finalized with 'p_llama_sampler_free'.
initGrammarSampler :: Vocab -> String -> String -> IO (Either String Sampler)
initGrammarSampler (Vocab vocab) grammarStr grammarRoot =
  withForeignPtr vocab $ \vocabPtr ->
  withCString grammarStr $ \grammarStrPtr ->
  withCString grammarRoot $ \grammarRootPtr ->
    c_llama_sampler_init_grammar (CLlamaVocab vocabPtr) grammarStrPtr grammarRootPtr
      >>= wrap
  where
    wrap ptr
      | ptr == nullPtr = pure (Left "Failed to initialize grammar sampler")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free ptr
-- | Create a lazily-activated grammar sampler
-- (@llama_sampler_init_grammar_lazy_patterns@).
--
-- All inputs are marshalled into temporary, scoped C memory: the grammar
-- text, the root name, each trigger pattern, the array of pattern pointers,
-- and the trigger-token array. Everything is released when the call
-- returns, including on exceptions.
--
-- Earlier code had two problems. The pattern array was built from strings
-- already freed by 'withCString' (dangling pointers). The trigger-token
-- buffer was 'mallocBytes'-allocated and never freed.
--
-- This assumes the C side copies its inputs during initialisation, as the
-- scoped grammar strings already required.
--
-- Returns @Left@ on a null result. Otherwise the sampler is owned by a
-- 'ForeignPtr' finalized with 'p_llama_sampler_free'.
initGrammarLazyPatternsSampler
  :: Vocab
  -> String
  -> String
  -> [String]
  -> [LlamaToken]
  -> IO (Either String Sampler)
initGrammarLazyPatternsSampler (Vocab vocab) grammarStr grammarRoot triggerPatterns triggerTokens =
  withForeignPtr vocab $ \vocabPtr ->
  withCString grammarStr $ \grammarStrPtr ->
  withCString grammarRoot $ \grammarRootPtr ->
  -- Keep every pattern's CString alive while the pointer array is in use.
  withMany withCString triggerPatterns $ \patternPtrs ->
  withArrayLen patternPtrs $ \nPatterns patternsArr ->
  withArrayLen triggerTokens $ \nTokens tokensArr -> do
    samplerPtr <- c_llama_sampler_init_grammar_lazy_patterns
      (CLlamaVocab vocabPtr)
      grammarStrPtr
      grammarRootPtr
      patternsArr
      (fromIntegral nPatterns)
      tokensArr
      (fromIntegral nTokens)
    if samplerPtr == nullPtr
      then return $ Left "Failed to initialize grammar lazy patterns sampler"
      else do
        fp <- newForeignPtr p_llama_sampler_free samplerPtr
        return $ Right $ Sampler fp
-- | Create a repetition-penalties sampler (@llama_sampler_init_penalties@).
--
-- Arguments are forwarded in order: last-N window, repeat penalty,
-- frequency penalty and presence penalty. Returns @Left@ on a null result.
-- Otherwise the sampler is owned by a 'ForeignPtr' finalized with
-- 'p_llama_sampler_free'.
initPenaltiesSampler :: Int -> Float -> Float -> Float -> IO (Either String Sampler)
initPenaltiesSampler penaltyLastN penaltyRepeat penaltyFreq penaltyPresent =
  c_llama_sampler_init_penalties
    (fromIntegral penaltyLastN)
    (realToFrac penaltyRepeat)
    (realToFrac penaltyFreq)
    (realToFrac penaltyPresent)
    >>= wrap
  where
    wrap ptr
      | ptr == nullPtr = pure (Left "Failed to initialize penalties sampler")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free ptr
-- | Marshal a list of strings into a heap-allocated C array of @char*@.
--
-- Each string is copied with 'newCString', not 'withCString'. The latter
-- frees its buffer as soon as its callback returns, which previously left
-- the array full of dangling pointers.
--
-- The returned 'ForeignPtr' owns both the pointer array and every string.
-- A Haskell finalizer frees all of them once the 'ForeignPtr' becomes
-- unreachable. (Previously the array had no finalizer and leaked.)
--
-- NOTE(review): if allocating one of the strings throws, the strings
-- allocated before it are not freed.
newArrayOfPtrs :: [String] -> IO (ForeignPtr CString)
newArrayOfPtrs xs = do
  cstrs <- mapM newCString xs
  arr <- newArray cstrs
  FC.newForeignPtr arr (mapM_ free cstrs >> free arr)
-- | Run an action with a temporary C array of NUL-terminated strings.
--
-- Every 'CString' and the pointer array itself are scoped allocations:
-- they stay valid for the whole action and are released afterwards, even
-- if the action throws. The earlier version built the array from pointers
-- that 'withCString' had already freed.
withCStringArray :: [String] -> (Ptr CString -> IO a) -> IO a
withCStringArray xs f =
  withMany withCString xs $ \cstrs -> withArray cstrs f
-- | Create a DRY sampler (@llama_sampler_init_dry@).
--
-- Arguments are forwarded in order: training context length, multiplier,
-- base, allowed length and penalty last-N. The sequence breakers are passed
-- as a C string array (via 'withCStringArray') together with their count.
--
-- Returns @Left@ on a null result. Otherwise the sampler is owned by a
-- 'ForeignPtr' finalized with 'p_llama_sampler_free'.
initDrySampler
  :: Vocab
  -> Int
  -> Float
  -> Float
  -> Int
  -> Int
  -> [String]
  -> IO (Either String Sampler)
initDrySampler (Vocab vocab) nCtxTrain dryMultiplier dryBase dryAllowedLength dryPenaltyLastN seqBreakers =
  withForeignPtr vocab $ \vocabPtr ->
  withCStringArray seqBreakers $ \seqBreakersPtr ->
    c_llama_sampler_init_dry
      (CLlamaVocab vocabPtr)
      (fromIntegral nCtxTrain)
      (realToFrac dryMultiplier)
      (realToFrac dryBase)
      (fromIntegral dryAllowedLength)
      (fromIntegral dryPenaltyLastN)
      seqBreakersPtr
      (fromIntegral breakerCount)
      >>= wrap
  where
    breakerCount = length seqBreakers
    wrap ptr
      | ptr == nullPtr = pure (Left "Failed to initialize dry sampler")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free ptr
-- | Create a logit-bias sampler (@llama_sampler_init_logit_bias@).
--
-- The bias list is copied into a temporary C array ('withArrayLen') that
-- only lives for the duration of the call. Its length is passed alongside
-- the vocabulary size.
--
-- Returns @Left@ on a null result. Otherwise the sampler is owned by a
-- 'ForeignPtr' finalized with 'p_llama_sampler_free'.
initLogitBiasSampler :: Int -> [LlamaLogitBias] -> IO (Either String Sampler)
initLogitBiasSampler nVocab logitBiases =
  withArrayLen logitBiases $ \nBiases biasPtr ->
    c_llama_sampler_init_logit_bias (fromIntegral nVocab) (fromIntegral nBiases) biasPtr
      >>= wrap
  where
    wrap ptr
      | ptr == nullPtr = pure (Left "Failed to initialize logit bias sampler")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free ptr
-- | Create an infill sampler for a vocabulary (@llama_sampler_init_infill@).
--
-- Returns @Left@ on a null result. Otherwise the sampler is owned by a
-- 'ForeignPtr' finalized with 'p_llama_sampler_free'.
initInfillSampler :: Vocab -> IO (Either String Sampler)
initInfillSampler (Vocab vocab) =
  withForeignPtr vocab $ \vocabPtr ->
    c_llama_sampler_init_infill (CLlamaVocab vocabPtr) >>= wrap
  where
    wrap ptr
      | ptr == nullPtr = pure (Left "Failed to initialize infill sampler")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free ptr
-- | Read a sampler's seed (@llama_sampler_get_seed@), converted to 'Word32'.
getSamplerSeed :: Sampler -> IO Word32
getSamplerSeed (Sampler samplerFPtr) =
  withForeignPtr samplerFPtr $ \samplerPtr -> do
    rawSeed <- c_llama_sampler_get_seed samplerPtr
    pure (fromIntegral rawSeed)
-- | Sample a token with @llama_sampler_sample@, using the given context
-- and output index.
--
-- Both the sampler and the context are kept alive for the duration of the
-- call.
sampleWithSampler :: Sampler -> Context -> Int -> IO LlamaToken
sampleWithSampler (Sampler samplerFPtr) (Context ctxFPtr) idx =
  withForeignPtr samplerFPtr $ \samplerPtr ->
  withForeignPtr ctxFPtr $ \ctxPtr ->
    c_llama_sampler_sample samplerPtr (CLlamaContext ctxPtr) (fromIntegral idx)