{- |
Module      : Llama.Sampler
Description : High level Sampler interface for llama-cpp
Copyright   : (c) 2025 Tushar Adhatrao
License     : MIT
Maintainer  : Tushar Adhatrao <tusharadhatrao@gmail.com>
-}
module Llama.Sampler (
    defaultSamplerChainParams
, initSampler
, getSamplerName
, acceptTokenWithSampler
, applySampler
, resetSampler
, cloneSampler
, initSamplerChain
, addSamplerToChain
, getSamplerFromChain
, getSamplerChainLength
, removeSamplerFromChain
, initGreedySampler
, initDistributedSampler
, initTopKSampler
, initTopPSampler
, initMinPSampler
, initTypicalSampler
, initTempSampler
, initTempExtSampler
, initXTCSampler
, initTopNSigmaSampler
, initMirostatSampler
, initMirostatV2Sampler
, initGrammarSampler
, initGrammarLazyPatternsSampler
, initPenaltiesSampler
, initDrySampler
, initLogitBiasSampler
, initInfillSampler
, getSamplerSeed
, sampleWithSampler
) where

import Foreign
import Foreign.C.String
import qualified Foreign.Concurrent as Conc

import Llama.Internal.Foreign
import Llama.Internal.Types
import Llama.Internal.Types.Params

-- | Marshal a 'Storable' value into temporary memory and run an action on
-- a pointer to it.
--
-- The memory is only valid for the duration of the callback. This is
-- exactly 'with' from "Foreign.Marshal.Utils"; the local name is kept so
-- existing call sites keep working.
withStorable :: Storable a => a -> (Ptr a -> IO b) -> IO b
withStorable = with

-- | Fetch the default sampler-chain parameters.
--
-- A temporary buffer is handed to the C side to be filled in, and the
-- result is read back into a Haskell value.
defaultSamplerChainParams :: IO LlamaSamplerChainParams
defaultSamplerChainParams =
  alloca $ \buf ->
    c_llama_sampler_chain_default_params_into buf *> peek buf

-- | Initialize a sampler from an interface table and a sampler context.
--
-- The interface table is copied into heap memory before it is handed to
-- the C side. Returns 'Left' with a message when the C call yields a null
-- pointer; otherwise the sampler is wrapped with a finalizer that releases
-- it via @p_llama_sampler_free@.
initSampler :: LlamaSamplerI -> LlamaSamplerContext -> IO (Either String Sampler)
initSampler iface_ ctx_ = do
  -- 'malloc' sizes the buffer from the pointee type, so no
  -- @sizeOf undefined@ is needed.
  ifacePtr <- malloc
  poke ifacePtr iface_
  samplerPtr <- c_llama_sampler_init ifacePtr ctx_
  if samplerPtr == nullPtr
    then do
      -- No sampler was created, so nothing can refer to the table: release
      -- it instead of leaking it.
      free ifacePtr
      return $ Left "Failed to initialize sampler"
    else do
      -- NOTE(review): the table is deliberately left allocated on success,
      -- presumably because the C sampler keeps the pointer rather than
      -- copying the struct -- confirm against llama.h. It is never freed.
      fp <- newForeignPtr p_llama_sampler_free samplerPtr
      return $ Right (Sampler fp)

-- | Look up a sampler's name.
--
-- The C string returned by the library is decoded into a Haskell 'String'.
getSamplerName :: Sampler -> IO String
getSamplerName (Sampler samplerFPtr) =
  withForeignPtr samplerFPtr $ \raw ->
    c_llama_sampler_name raw >>= peekCString

-- | Feed a token to a sampler through the library's accept call.
acceptTokenWithSampler :: Sampler -> LlamaToken -> IO ()
acceptTokenWithSampler (Sampler samplerFPtr) token_ =
  withForeignPtr samplerFPtr (`c_llama_sampler_accept` token_)

-- | Run a sampler over a token data array.
--
-- The array descriptor is marshalled into temporary memory for the call.
-- Nothing is read back from that temporary afterwards, so changes the C
-- side makes to the descriptor struct itself are not visible to the caller.
applySampler :: Sampler -> LlamaTokenDataArray -> IO ()
applySampler (Sampler samplerFPtr) tokenDataArray =
  withForeignPtr samplerFPtr $ \raw ->
    with tokenDataArray (c_llama_sampler_apply raw)

-- | Reset a sampler through the library's reset call.
resetSampler :: Sampler -> IO ()
resetSampler (Sampler samplerFPtr) =
  withForeignPtr samplerFPtr c_llama_sampler_reset

-- | Duplicate a sampler.
--
-- Yields 'Left' when the library returns a null pointer; otherwise the
-- clone gets its own finalizer (@p_llama_sampler_free@).
cloneSampler :: Sampler -> IO (Either String Sampler)
cloneSampler (Sampler samplerFPtr) =
  withForeignPtr samplerFPtr $ \raw -> do
    copy <- c_llama_sampler_clone raw
    if copy /= nullPtr
      then Right . Sampler <$> newForeignPtr p_llama_sampler_free copy
      else pure (Left "Failed to clone sampler")

-- | Create a sampler chain from the given parameters.
--
-- Yields 'Left' when the library returns a null pointer; otherwise the
-- chain is wrapped with the @p_llama_sampler_free@ finalizer.
initSamplerChain :: LlamaSamplerChainParams -> IO (Either String Sampler)
initSamplerChain params =
  c_llama_sampler_chain_init params >>= wrap
  where
    wrap raw
      | raw == nullPtr = pure (Left "Failed to initialize sampler chain")
      | otherwise = Right . Sampler <$> newForeignPtr p_llama_sampler_free raw

-- | Append a raw sampler pointer to a sampler chain.
--
-- NOTE(review): presumably the chain takes ownership of the added sampler,
-- which would explain why a bare 'Ptr' (no finalizer) is accepted here --
-- confirm against llama.h.
addSamplerToChain :: Sampler -> Ptr LlamaSampler -> IO ()
addSamplerToChain (Sampler chainFPtr) samplerPtr =
  withForeignPtr chainFPtr (`c_llama_sampler_chain_add` samplerPtr)

-- | Get the sampler at the given position of a sampler chain.
--
-- Returns 'Left' when the library yields a null pointer for that index.
--
-- The returned 'Sampler' is a /borrowed/ handle: the element still belongs
-- to the chain, so no finalizer is attached. Attaching
-- @p_llama_sampler_free@ here (as was done previously) would free the
-- element a second time when the chain itself is freed. The handle must
-- not be used after the chain has been released.
getSamplerFromChain :: Sampler -> Int -> IO (Either String Sampler)
getSamplerFromChain (Sampler chainFPtr) index =
  withForeignPtr chainFPtr $ \chainPtr -> do
    samplerPtr <- c_llama_sampler_chain_get chainPtr (fromIntegral index)
    if samplerPtr == nullPtr
      then return $ Left "Failed to get sampler from chain"
      else do
        -- 'newForeignPtr_' wraps the pointer without a finalizer; the chain
        -- remains responsible for freeing it.
        fp <- newForeignPtr_ samplerPtr
        return $ Right $ Sampler fp

-- | Report how many samplers a chain currently holds, as returned by the
-- library and converted to 'Int'.
getSamplerChainLength :: Sampler -> IO Int
getSamplerChainLength (Sampler chainFPtr) =
  withForeignPtr chainFPtr (fmap fromIntegral . c_llama_sampler_chain_n)

-- | Detach the sampler at the given position from a chain and return it.
--
-- Yields 'Left' when the library returns a null pointer. On success the
-- detached sampler is wrapped with the @p_llama_sampler_free@ finalizer.
-- NOTE(review): this assumes removal transfers ownership to the caller --
-- confirm against llama.h.
removeSamplerFromChain :: Sampler -> Int -> IO (Either String Sampler)
removeSamplerFromChain (Sampler chainFPtr) index =
  withForeignPtr chainFPtr $ \chainPtr -> do
    detached <- c_llama_sampler_chain_remove chainPtr (fromIntegral index)
    if detached /= nullPtr
      then Right . Sampler <$> newForeignPtr p_llama_sampler_free detached
      else pure (Left "Failed to remove sampler from chain")

-- | Create a greedy sampler.
--
-- Unlike most initialisers in this module, the raw pointer is returned
-- as-is with no finalizer attached. Yields 'Left' on a null pointer.
initGreedySampler :: IO (Either String (Ptr LlamaSampler))
initGreedySampler = classify <$> c_llama_sampler_init_greedy
  where
    classify raw
      | raw == nullPtr = Left "Failed to initialize greedy sampler"
      | otherwise = Right raw

-- | Create a distribution sampler from a seed.
--
-- Yields 'Left' on a null pointer; otherwise the sampler is wrapped with
-- the @p_llama_sampler_free@ finalizer.
initDistributedSampler :: Word32 -> IO (Either String Sampler)
initDistributedSampler seed = do
  raw <- c_llama_sampler_init_dist (fromIntegral seed)
  if raw /= nullPtr
    then Right . Sampler <$> newForeignPtr p_llama_sampler_free raw
    else pure (Left "Failed to initialize distributed sampler")

-- | Create a top-k sampler for the given @k@.
--
-- Yields 'Left' on a null pointer; otherwise the sampler is wrapped with
-- the @p_llama_sampler_free@ finalizer.
initTopKSampler :: Int -> IO (Either String Sampler)
initTopKSampler k = do
  raw <- c_llama_sampler_init_top_k (fromIntegral k)
  if raw /= nullPtr
    then Right . Sampler <$> newForeignPtr p_llama_sampler_free raw
    else pure (Left "Failed to initialize top-k sampler")

-- | Create a top-p sampler from a probability value and a minimum-keep
-- count.
--
-- Yields 'Left' on a null pointer; otherwise the sampler is wrapped with
-- the @p_llama_sampler_free@ finalizer.
initTopPSampler :: Float -> Int -> IO (Either String Sampler)
initTopPSampler p1 minKeep = do
  raw <- c_llama_sampler_init_top_p (realToFrac p1) (fromIntegral minKeep)
  if raw /= nullPtr
    then Right . Sampler <$> newForeignPtr p_llama_sampler_free raw
    else pure (Left "Failed to initialize top-p sampler")

-- | Create a min-p sampler from a probability value and a minimum-keep
-- count.
--
-- Yields 'Left' on a null pointer; otherwise the sampler is wrapped with
-- the @p_llama_sampler_free@ finalizer.
initMinPSampler :: Float -> Int -> IO (Either String Sampler)
initMinPSampler p1 minKeep = do
  raw <- c_llama_sampler_init_min_p (realToFrac p1) (fromIntegral minKeep)
  if raw /= nullPtr
    then Right . Sampler <$> newForeignPtr p_llama_sampler_free raw
    else pure (Left "Failed to initialize min-p sampler")

-- | Create a typical sampler from a probability value and a minimum-keep
-- count.
--
-- Yields 'Left' on a null pointer; otherwise the sampler is wrapped with
-- the @p_llama_sampler_free@ finalizer.
initTypicalSampler :: Float -> Int -> IO (Either String Sampler)
initTypicalSampler p_ minKeep = do
  raw <- c_llama_sampler_init_typical (realToFrac p_) (fromIntegral minKeep)
  if raw /= nullPtr
    then Right . Sampler <$> newForeignPtr p_llama_sampler_free raw
    else pure (Left "Failed to initialize typical sampler")

-- | Create a temperature sampler for the given temperature value.
--
-- Yields 'Left' on a null pointer; otherwise the sampler is wrapped with
-- the @p_llama_sampler_free@ finalizer.
initTempSampler :: Float -> IO (Either String Sampler)
initTempSampler t = do
  raw <- c_llama_sampler_init_temp (realToFrac t)
  if raw /= nullPtr
    then Right . Sampler <$> newForeignPtr p_llama_sampler_free raw
    else pure (Left "Failed to initialize temperature sampler")

-- | Create an extended temperature sampler from a temperature, a delta and
-- an exponent.
--
-- Yields 'Left' on a null pointer; otherwise the sampler is wrapped with
-- the @p_llama_sampler_free@ finalizer.
initTempExtSampler :: Float -> Float -> Float -> IO (Either String Sampler)
initTempExtSampler t delta exponent_ = do
  raw <-
    c_llama_sampler_init_temp_ext
      (realToFrac t)
      (realToFrac delta)
      (realToFrac exponent_)
  if raw /= nullPtr
    then Right . Sampler <$> newForeignPtr p_llama_sampler_free raw
    else pure (Left "Failed to initialize extended temperature sampler")

-- | Create an XTC sampler.
--
-- Arguments are passed to the library in order: probability, threshold,
-- minimum-keep count, seed. Yields 'Left' on a null pointer; otherwise the
-- sampler is wrapped with the @p_llama_sampler_free@ finalizer.
initXTCSampler :: Float -> Float -> Int -> Word32 -> IO (Either String Sampler)
initXTCSampler p1 t minKeep seed = do
  raw <-
    c_llama_sampler_init_xtc
      (realToFrac p1)
      (realToFrac t)
      (fromIntegral minKeep)
      (fromIntegral seed)
  if raw /= nullPtr
    then Right . Sampler <$> newForeignPtr p_llama_sampler_free raw
    else pure (Left "Failed to initialize XTC sampler")

-- | Create a top-N sigma sampler for the given @n@.
--
-- Yields 'Left' on a null pointer; otherwise the sampler is wrapped with
-- the @p_llama_sampler_free@ finalizer.
initTopNSigmaSampler :: Float -> IO (Either String Sampler)
initTopNSigmaSampler n = do
  raw <- c_llama_sampler_init_top_n_sigma (realToFrac n)
  if raw /= nullPtr
    then Right . Sampler <$> newForeignPtr p_llama_sampler_free raw
    else pure (Left "Failed to initialize top-N sigma sampler")

-- | Create a Mirostat sampler.
--
-- Arguments are passed to the library in order: vocabulary size, seed,
-- tau, eta, m. Yields 'Left' on a null pointer; otherwise the sampler is
-- wrapped with the @p_llama_sampler_free@ finalizer.
initMirostatSampler :: Int -> Word32 -> Float -> Float -> Int -> IO (Either String Sampler)
initMirostatSampler nVocab seed tau eta m = do
  raw <-
    c_llama_sampler_init_mirostat
      (fromIntegral nVocab)
      (fromIntegral seed)
      (realToFrac tau)
      (realToFrac eta)
      (fromIntegral m)
  if raw /= nullPtr
    then Right . Sampler <$> newForeignPtr p_llama_sampler_free raw
    else pure (Left "Failed to initialize Mirostat sampler")

-- | Create a Mirostat V2 sampler from a seed, tau and eta.
--
-- Yields 'Left' on a null pointer; otherwise the sampler is wrapped with
-- the @p_llama_sampler_free@ finalizer.
initMirostatV2Sampler :: Word32 -> Float -> Float -> IO (Either String Sampler)
initMirostatV2Sampler seed tau eta = do
  raw <-
    c_llama_sampler_init_mirostat_v2
      (fromIntegral seed)
      (realToFrac tau)
      (realToFrac eta)
  if raw /= nullPtr
    then Right . Sampler <$> newForeignPtr p_llama_sampler_free raw
    else pure (Left "Failed to initialize Mirostat V2 sampler")

-- | Create a grammar sampler for a vocabulary from a grammar text and the
-- name of its root rule.
--
-- Both strings are marshalled to temporary C strings that live only for
-- the duration of the initialisation call. Yields 'Left' on a null
-- pointer; otherwise the sampler is wrapped with the
-- @p_llama_sampler_free@ finalizer.
initGrammarSampler :: Vocab -> String -> String -> IO (Either String Sampler)
initGrammarSampler (Vocab vocab) grammarStr grammarRoot =
  withForeignPtr vocab $ \vocabPtr ->
    withCString grammarStr $ \cGrammar ->
      withCString grammarRoot $ \cRoot -> do
        raw <- c_llama_sampler_init_grammar (CLlamaVocab vocabPtr) cGrammar cRoot
        if raw /= nullPtr
          then Right . Sampler <$> newForeignPtr p_llama_sampler_free raw
          else pure (Left "Failed to initialize grammar sampler")

-- | Initialize a grammar sampler with lazy trigger patterns.
--
-- Arguments: vocabulary, grammar text, root rule name, trigger patterns,
-- trigger tokens. Returns 'Left' when the library yields a null pointer;
-- otherwise the sampler is wrapped with the @p_llama_sampler_free@
-- finalizer.
--
-- All strings and both arrays are marshalled into temporary memory that is
-- released when the call returns. This replaces an unfreed @mallocBytes@
-- token buffer and a pattern array whose string pointers were already
-- dangling by the time the C function ran.
--
-- NOTE(review): like the grammar strings, this assumes the C side copies
-- the patterns and tokens during initialisation -- confirm against llama.h.
initGrammarLazyPatternsSampler :: Vocab -> String -> String -> [String] -> [LlamaToken] -> IO (Either String Sampler)
initGrammarLazyPatternsSampler (Vocab vocab) grammarStr grammarRoot triggerPatterns triggerTokens =
  withForeignPtr vocab $ \vocabPtr ->
    withCString grammarStr $ \grammarStrPtr ->
      withCString grammarRoot $ \grammarRootPtr ->
        -- Keep every pattern's C string alive for the whole call.
        withMany withCString triggerPatterns $ \patternCStrs ->
          withArrayLen patternCStrs $ \nPatterns patternsPtr ->
            withArrayLen triggerTokens $ \nTokens tokensPtr -> do
              samplerPtr <- c_llama_sampler_init_grammar_lazy_patterns
                (CLlamaVocab vocabPtr)
                grammarStrPtr
                grammarRootPtr
                patternsPtr
                (fromIntegral nPatterns)
                tokensPtr
                (fromIntegral nTokens)
              if samplerPtr == nullPtr
                then return $ Left "Failed to initialize grammar lazy patterns sampler"
                else do
                  fp <- newForeignPtr p_llama_sampler_free samplerPtr
                  return $ Right $ Sampler fp

-- | Create a penalties sampler.
--
-- Arguments are passed to the library in order: last-N window, repeat
-- penalty, frequency penalty, presence penalty. Yields 'Left' on a null
-- pointer; otherwise the sampler is wrapped with the
-- @p_llama_sampler_free@ finalizer.
initPenaltiesSampler :: Int -> Float -> Float -> Float -> IO (Either String Sampler)
initPenaltiesSampler penaltyLastN penaltyRepeat penaltyFreq penaltyPresent = do
  raw <-
    c_llama_sampler_init_penalties
      (fromIntegral penaltyLastN)
      (realToFrac penaltyRepeat)
      (realToFrac penaltyFreq)
      (realToFrac penaltyPresent)
  if raw /= nullPtr
    then Right . Sampler <$> newForeignPtr p_llama_sampler_free raw
    else pure (Left "Failed to initialize penalties sampler")

-- | Marshal a list of strings into a heap-allocated array of C strings.
--
-- Each string is copied with 'newCString', so the pointers stored in the
-- array stay valid for as long as the returned 'ForeignPtr' is alive.
-- Previously the strings were created with 'withCString' and were already
-- freed by the time the array was used. A finalizer now releases every
-- string and then the array itself, where before the array was never
-- freed.
newArrayOfPtrs :: [String] -> IO (ForeignPtr CString)
newArrayOfPtrs xs = do
  cstrs <- mapM newCString xs
  arr <- newArray cstrs
  -- A Haskell-side finalizer is used because the cleanup has to free a
  -- variable number of allocations, which a single 'FunPtr' cannot express.
  Conc.newForeignPtr arr (mapM_ free cstrs >> free arr)

-- | Marshal a list of strings into a temporary array of C strings and run
-- an action on it.
--
-- Every C string and the array itself stay valid for exactly the duration
-- of the callback and are released afterwards, including when the callback
-- throws. The pointer must not escape the callback.
withCStringArray :: [String] -> (Ptr CString -> IO a) -> IO a
withCStringArray xs f =
  -- 'withMany' nests one 'withCString' per element, so all strings are
  -- alive at once while the array is in use.
  withMany withCString xs $ \cstrs ->
    withArray cstrs f

-- | Create a DRY sampler.
--
-- Arguments are passed to the library in order: vocabulary, training
-- context size, multiplier, base, allowed length, last-N penalty window,
-- and the sequence breakers with their count. Yields 'Left' on a null
-- pointer; otherwise the sampler is wrapped with the
-- @p_llama_sampler_free@ finalizer.
initDrySampler :: Vocab -> Int -> Float -> Float -> Int -> Int -> [String] -> IO (Either String Sampler)
initDrySampler (Vocab vocab) nCtxTrain dryMultiplier dryBase dryAllowedLength dryPenaltyLastN seqBreakers =
  withForeignPtr vocab $ \vocabPtr ->
    withCStringArray seqBreakers $ \breakersPtr -> do
      raw <-
        c_llama_sampler_init_dry
          (CLlamaVocab vocabPtr)
          (fromIntegral nCtxTrain)
          (realToFrac dryMultiplier)
          (realToFrac dryBase)
          (fromIntegral dryAllowedLength)
          (fromIntegral dryPenaltyLastN)
          breakersPtr
          (fromIntegral (length seqBreakers))
      if raw /= nullPtr
        then Right . Sampler <$> newForeignPtr p_llama_sampler_free raw
        else pure (Left "Failed to initialize dry sampler")

-- | Create a logit-bias sampler from a vocabulary size and a list of
-- biases.
--
-- The biases are marshalled into a temporary array that lives only for the
-- duration of the initialisation call. Yields 'Left' on a null pointer;
-- otherwise the sampler is wrapped with the @p_llama_sampler_free@
-- finalizer.
initLogitBiasSampler :: Int -> [LlamaLogitBias] -> IO (Either String Sampler)
initLogitBiasSampler nVocab logitBiases =
  withArrayLen logitBiases $ \nBiases biasesPtr -> do
    raw <-
      c_llama_sampler_init_logit_bias
        (fromIntegral nVocab)
        (fromIntegral nBiases)
        biasesPtr
    if raw /= nullPtr
      then Right . Sampler <$> newForeignPtr p_llama_sampler_free raw
      else pure (Left "Failed to initialize logit bias sampler")

-- | Create an infill sampler for a vocabulary.
--
-- Yields 'Left' on a null pointer; otherwise the sampler is wrapped with
-- the @p_llama_sampler_free@ finalizer.
initInfillSampler :: Vocab -> IO (Either String Sampler)
initInfillSampler (Vocab vocab) =
  withForeignPtr vocab $ \vocabPtr -> do
    raw <- c_llama_sampler_init_infill (CLlamaVocab vocabPtr)
    if raw /= nullPtr
      then Right . Sampler <$> newForeignPtr p_llama_sampler_free raw
      else pure (Left "Failed to initialize infill sampler")

-- | Read back the seed reported by the library for a sampler, converted
-- to 'Word32'.
getSamplerSeed :: Sampler -> IO Word32
getSamplerSeed (Sampler samplerFPtr) =
  withForeignPtr samplerFPtr $ \raw ->
    fromIntegral <$> c_llama_sampler_get_seed raw

-- | Sample a token with a sampler, given a context and an index into it.
--
-- Both the sampler and the context are kept alive for the duration of the
-- underlying call.
sampleWithSampler :: Sampler -> Context -> Int -> IO LlamaToken
sampleWithSampler (Sampler samplerFPtr) (Context ctxFPtr) idx =
  withForeignPtr samplerFPtr $ \rawSampler ->
    withForeignPtr ctxFPtr $ \rawCtx ->
      c_llama_sampler_sample rawSampler (CLlamaContext rawCtx) (fromIntegral idx)