diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 71d94ebd8..64b5ffb30 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -682,6 +682,13 @@ def _init_sampler(
         mirostat_mode: int = 0,
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_penalty_last_n: int = -1,
+        dry_seq_breakers: Optional[List[str]] = None,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -749,11 +756,13 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
         else:
             n_probs = 0
             min_keep = max(1, n_probs)
+            sampler.add_dry(self._model, self._model.n_ctx_train(), dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, dry_seq_breakers or [])
             sampler.add_top_k(top_k)
             sampler.add_typical(typical_p, min_keep)
             sampler.add_top_p(top_p, min_keep)
             sampler.add_min_p(min_p, min_keep)
             sampler.add_temp(temp)
+            sampler.add_xtc(xtc_probability, xtc_threshold, min_keep, self._seed)
             sampler.add_dist(self._seed)
         return sampler
 
@@ -771,6 +780,13 @@ def sample(
         mirostat_mode: int = 0,
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_penalty_last_n: int = -1,
+        dry_seq_breakers: Optional[List[str]] = None,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -806,6 +822,13 @@ def sample(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_penalty_last_n=dry_penalty_last_n,
+            dry_seq_breakers=dry_seq_breakers,
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
             grammar=grammar,
@@ -835,6 +858,13 @@ def generate(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_penalty_last_n: int = -1,
+        dry_seq_breakers: Optional[List[str]] = None,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
@@ -874,6 +904,13 @@ def generate(
                     mirostat_mode=mirostat_mode,
                     mirostat_tau=mirostat_tau,
                     mirostat_eta=mirostat_eta,
+                    xtc_probability=xtc_probability,
+                    xtc_threshold=xtc_threshold,
+                    dry_multiplier=dry_multiplier,
+                    dry_allowed_length=dry_allowed_length,
+                    dry_base=dry_base,
+                    dry_penalty_last_n=dry_penalty_last_n,
+                    dry_seq_breakers=dry_seq_breakers,
                     penalize_nl=penalize_nl,
                     logits_processor=logits_processor,
                     grammar=grammar,
@@ -926,6 +963,13 @@ def generate(
                     mirostat_mode=mirostat_mode,
                     mirostat_tau=mirostat_tau,
                     mirostat_eta=mirostat_eta,
+                    xtc_probability=xtc_probability,
+                    xtc_threshold=xtc_threshold,
+                    dry_multiplier=dry_multiplier,
+                    dry_allowed_length=dry_allowed_length,
+                    dry_base=dry_base,
+                    dry_penalty_last_n=dry_penalty_last_n,
+                    dry_seq_breakers=dry_seq_breakers,
                     logits_processor=logits_processor,
                     grammar=grammar,
                     penalize_nl=penalize_nl,
@@ -1142,6 +1186,13 @@ def _create_completion(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_penalty_last_n: int = -1,
+        dry_seq_breakers: Optional[List[str]] = None,
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -1330,6 +1381,13 @@ def logit_bias_processor(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_penalty_last_n=dry_penalty_last_n,
+            dry_seq_breakers=dry_seq_breakers,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
@@ -1762,6 +1820,13 @@ def create_completion(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_penalty_last_n: int = -1,
+        dry_seq_breakers: Optional[List[str]] = None,
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -1791,6 +1856,13 @@ def create_completion(
             mirostat_mode: The mirostat sampling mode.
             mirostat_tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
             mirostat_eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+            xtc_probability: Probability of applying the XTC (exclude top choices) sampler on a given step. 0.0 disables XTC.
+            xtc_threshold: Probability threshold above which tokens count as "top choices" and become candidates for removal by XTC.
+            dry_multiplier: Strength of the DRY (don't repeat yourself) repetition penalty. 0.0 disables DRY sampling.
+            dry_allowed_length: Longest repeated token sequence that is not penalized by DRY; longer repetitions are penalized exponentially.
+            dry_base: Base of the exponential growth of the DRY penalty with repetition length.
+            dry_penalty_last_n: Number of recent tokens scanned for repetitions (-1 = entire context, 0 = disabled).
+            dry_seq_breakers: Strings at which DRY sequence matching is restarted (e.g. newlines, quotes).
             model: The name to use for the model in the completion object.
             stopping_criteria: A list of stopping criteria to use.
             logits_processor: A list of logits processors to use.
@@ -1825,6 +1897,13 @@ def create_completion(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_penalty_last_n=dry_penalty_last_n,
+            dry_seq_breakers=dry_seq_breakers,
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
@@ -1859,6 +1938,13 @@ def __call__(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_penalty_last_n: int = -1,
+        dry_seq_breakers: Optional[List[str]] = None,
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -1922,6 +2008,13 @@ def __call__(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_penalty_last_n=dry_penalty_last_n,
+            dry_seq_breakers=dry_seq_breakers,
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
@@ -1953,6 +2046,13 @@ def create_chat_completion(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_penalty_last_n: int = -1,
+        dry_seq_breakers: Optional[List[str]] = None,
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -2026,6 +2126,13 @@ def create_chat_completion(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_penalty_last_n=dry_penalty_last_n,
+            dry_seq_breakers=dry_seq_breakers,
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 711d42a6a..f1e4324b0 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -3884,6 +3884,41 @@ def llama_sampler_init_xtc(
 ) -> llama_sampler_p: ...
 
 
+# LLAMA_API struct llama_sampler * llama_sampler_init_dry(
+#            const struct llama_vocab *  vocab,
+#                              int32_t   n_ctx_train,
+#                                float   dry_multiplier,
+#                                float   dry_base,
+#                              int32_t   dry_allowed_length,
+#                              int32_t   dry_penalty_last_n,
+#                          const char ** seq_breakers,
+#                               size_t   num_breakers);
+@ctypes_function(
+    "llama_sampler_init_dry",
+    [
+        llama_vocab_p_ctypes,
+        ctypes.c_int32,
+        ctypes.c_float,
+        ctypes.c_float,
+        ctypes.c_int32,
+        ctypes.c_int32,
+        ctypes.POINTER(ctypes.c_char_p),
+        ctypes.c_size_t,
+    ],
+    llama_sampler_p_ctypes,
+)
+def llama_sampler_init_dry(
+    vocab: llama_vocab_p,
+    n_ctx_train: int,
+    dry_multiplier: float,
+    dry_base: float,
+    dry_allowed_length: int,
+    dry_penalty_last_n: int,
+    seq_breakers: CtypesArray[ctypes.c_char_p],
+    num_breakers: int,
+) -> llama_sampler_p: ...
+
+
 # /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641
 # LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n);
 