diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md
index 96ab4057171c..c3a82fa0e60d 100644
--- a/docs/source/en/internal/generation_utils.md
+++ b/docs/source/en/internal/generation_utils.md
@@ -153,6 +153,9 @@ generation.
 [[autodoc]] TemperatureLogitsWarper
     - __call__
 
+[[autodoc]] TopHLogitsWarper
+    - __call__
+
 [[autodoc]] TopKLogitsWarper
     - __call__
 
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index b7e302452b54..c609b9e7e006 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -422,6 +422,7 @@
         "SynthIDTextWatermarkingConfig",
         "SynthIDTextWatermarkLogitsProcessor",
         "TemperatureLogitsWarper",
+        "TopHLogitsWarper",
         "TopKLogitsWarper",
         "TopPLogitsWarper",
         "TypicalLogitsWarper",
@@ -587,6 +588,7 @@
     from .generation import TemperatureLogitsWarper as TemperatureLogitsWarper
     from .generation import TextIteratorStreamer as TextIteratorStreamer
     from .generation import TextStreamer as TextStreamer
+    from .generation import TopHLogitsWarper as TopHLogitsWarper
     from .generation import TopKLogitsWarper as TopKLogitsWarper
     from .generation import TopPLogitsWarper as TopPLogitsWarper
     from .generation import TypicalLogitsWarper as TypicalLogitsWarper
diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py
index ccde5d8bc19c..92ef3184e773 100644
--- a/src/transformers/generation/__init__.py
+++ b/src/transformers/generation/__init__.py
@@ -67,6 +67,7 @@
         "SuppressTokensAtBeginLogitsProcessor",
         "SynthIDTextWatermarkLogitsProcessor",
         "TemperatureLogitsWarper",
+        "TopHLogitsWarper",
         "TopKLogitsWarper",
         "TopPLogitsWarper",
         "TypicalLogitsWarper",
@@ -153,6 +154,7 @@
         SuppressTokensLogitsProcessor,
         SynthIDTextWatermarkLogitsProcessor,
         TemperatureLogitsWarper,
+        TopHLogitsWarper,
         TopKLogitsWarper,
         TopPLogitsWarper,
         TypicalLogitsWarper,
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 4c5530ebe759..ee24e34c678e 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -167,6 +167,12 @@ class GenerationConfig(PushToHubMixin):
             Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
             value between 0 and 1. Typical values are in the 0.01-0.2 range, comparably selective as setting `top_p` in
             the 0.99-0.8 range (use the opposite of normal `top_p` values).
+        top_h (`float`, *optional*):
+            Entropy budget scaling factor, which controls how much of the distribution's entropy is preserved when
+            sampling. Must be a value between 0 and 1. At each step, tokens are sorted by probability, and the largest
+            prefix of tokens is kept whose cumulative entropy is less than or equal to `top_h` times the entropy of the
+            full distribution. Smaller values (e.g., 0.2–0.5) lead to more focused, deterministic outputs, while values
+            closer to 1.0 allow more randomness and diversity. Typical values are in the 0.3–0.6 range.
         typical_p (`float`, *optional*, defaults to 1.0):
             Local typicality measures how similar the conditional probability of predicting a target token next is to
             the expected conditional probability of predicting a random token next, given the partial text already
@@ -357,6 +363,7 @@ def __init__(self, **kwargs):
         self.top_k = kwargs.pop("top_k", 50)
         self.top_p = kwargs.pop("top_p", 1.0)
         self.min_p = kwargs.pop("min_p", None)
+        self.top_h = kwargs.pop("top_h", None)
         self.typical_p = kwargs.pop("typical_p", 1.0)
         self.epsilon_cutoff = kwargs.pop("epsilon_cutoff", 0.0)
         self.eta_cutoff = kwargs.pop("eta_cutoff", 0.0)
@@ -581,6 +588,8 @@ def validate(self, strict=False):
                 minor_issues["top_p"] = greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p)
             if self.min_p is not None:
                 minor_issues["min_p"] = greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p)
+            if self.top_h is not None:
+                minor_issues["top_h"] = greedy_wrong_parameter_msg.format(flag_name="top_h", flag_value=self.top_h)
             if self.typical_p is not None and self.typical_p != 1.0:
                 minor_issues["typical_p"] = greedy_wrong_parameter_msg.format(
                     flag_name="typical_p", flag_value=self.typical_p
diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index ce150f790051..63940b17d819 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -581,6 +581,112 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
         return scores_processed
 
 
+class TopHLogitsWarper(LogitsProcessor):
+    """
+    [`LogitsProcessor`] that implements Top-H sampling, a decoding method which adaptively selects a subset of
+    high-probability tokens based on entropy and cumulative probability constraints.
+
+    This method dynamically determines how many tokens to keep by analyzing the entropy difference of the selected
+    distribution, thereby balancing exploration and exploitation. It ensures that generated text maintains both
+    diversity and coherence.
+
+    Reference:
+        For details, see *Top-H Decoding: Adapting the Creativity and Coherence with Bounded Entropy in Text Generation*
+        (NeurIPS 2025): https://arxiv.org/abs/2509.02510
+
+    Args:
+        top_h (`float`):
+            Scaling coefficient for the entropy-based threshold (`tau`). Must be in the range `(0, 1]`.
+
+        filter_value (`float`, *optional*, defaults to -inf):
+            All filtered values will be set to this float value.
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+    >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
+    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
+
+    >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt")
+
+    >>> outputs = model.generate(**inputs, do_sample=True, top_h=0.4)
+    >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
+    A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9
+    ```
+    """
+
+    def __init__(self, top_h: float, filter_value: float = -float("Inf")):
+        super().__init__()
+
+        # input checks
+        if not (0 < top_h <= 1):
+            raise ValueError("`top_h` must be in the range (0, 1].")
+
+        # Maximum number of top tokens to consider before applying the entropy-based filter.
+        # Acts as a cap for efficiency and numerical stability — increasing this allows more
+        # tokens to be evaluated but may slow down generation. Default is 100.
+        self.top_n = 100
+
+        self.top_h = top_h
+        self.filter_value = filter_value
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        """
+        Filters logits using Top-H sampling.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Input token IDs.
+            scores (`torch.FloatTensor` of shape `(batch_size, vocab_size)`):
+                Raw logits from the model.
+
+        Return:
+            `torch.FloatTensor` of shape `(batch_size, vocab_size)`:
+                Processed logits where invalid tokens are masked with `-inf`.
+        """
+        batch_size, vocab_size = scores.shape
+        device = scores.device
+        keep_mask = torch.zeros((batch_size, vocab_size), dtype=torch.bool, device=device)
+        top_n = min(self.top_n, vocab_size)
+
+        # 1. Get top-k logits and indices for the whole batch
+        top_logits, top_idx = torch.topk(scores, top_n, dim=-1, largest=True, sorted=True)
+
+        # 2. Create a batch of categorical distributions
+        dist = torch.distributions.Categorical(logits=top_logits)
+        probs = dist.probs
+        log_probs = torch.log(probs)  # dist.log_prob(idx)
+
+        # 3. Calculate the entropy-based threshold tau for the whole batch
+        # We unsqueeze tau to enable broadcasting against the cumulative entropy tensor.
+        tau = (dist.entropy() * self.top_h).unsqueeze(-1)
+
+        # 4. Calculate cumulative entropy using torch.cumsum
+        # The individual entropy terms (-p * log(p)) are calculated for all top_n tokens at once.
+        entropy_terms = -probs * log_probs
+        cumulative_entropy = torch.cumsum(entropy_terms, dim=-1)
+
+        # 5. Determine which tokens to keep based on the stopping condition
+        # Create a boolean mask for the top_n tokens.
+        # Stopping rule: keep adding tokens in order of probability until the cumulative entropy
+        # exceeds the threshold τ = H(p) * top_h. This ensures diversity (via entropy) while
+        # guaranteeing at least the most probable token is always included.
+        selection_mask = cumulative_entropy <= tau
+        selection_mask[:, 0] = True
+
+        # 6. Update the final keep_mask for the entire batch in one operation
+        # The scatter_ operation efficiently updates the keep_mask at the indices
+        # specified by top_idx with the boolean values from selection_mask.
+        keep_mask.scatter_(dim=1, index=top_idx, src=selection_mask)
+
+        # apply filtering
+        scores_processed = scores.clone()
+        scores_processed[~keep_mask] = self.filter_value
+        return scores_processed
+
+
 class MinPLogitsWarper(LogitsProcessor):
     """
     [`LogitsProcessor`] that performs min-p, i.e. keeps all tokens that are above a minimum probability, scaled by the
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 97b98f96c202..f08a24e5af96 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -93,6 +93,7 @@
     SuppressTokensAtBeginLogitsProcessor,
     SuppressTokensLogitsProcessor,
     TemperatureLogitsWarper,
+    TopHLogitsWarper,
    TopKLogitsWarper,
     TopPLogitsWarper,
     TypicalLogitsWarper,
@@ -1243,6 +1244,8 @@ def _get_logits_processor(
         # all samplers can be found in `generation_utils_samplers.py`
         if generation_config.temperature is not None and generation_config.temperature != 1.0:
             processors.append(TemperatureLogitsWarper(generation_config.temperature))
+        if generation_config.top_h is not None:
+            processors.append(TopHLogitsWarper(top_h=generation_config.top_h))
         if generation_config.top_k is not None and generation_config.top_k != 0:
             processors.append(
                 TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)
diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py
index 768e216ef534..06531b52f5a5 100644
--- a/tests/generation/test_logits_process.py
+++ b/tests/generation/test_logits_process.py
@@ -49,6 +49,7 @@
         SequenceBiasLogitsProcessor,
         SynthIDTextWatermarkLogitsProcessor,
         TemperatureLogitsWarper,
+        TopHLogitsWarper,
         TopKLogitsWarper,
         TopPLogitsWarper,
         TypicalLogitsWarper,
@@ -394,6 +395,95 @@ def test_top_p_dist_warper(self):
         # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
         self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])
 
+    def test_top_h_dist_warper(self):
+        """
+        We construct small distributions where the expected kept set is obvious for a given alpha.
+        We pass *log-probabilities* as "scores" so that softmax(scores) == original probabilities,
+        matching the style in other warper tests (e.g., MinP).
+        """
+
+        input_ids = None
+
+        # --- Case 1: Highly peaked distribution -> small alpha keeps only the top-1
+        dist1 = torch.log(
+            torch.tensor(
+                [[0.97, 0.01, 0.01, 0.01]],
+                device=torch_device,
+                dtype=torch.float,
+            )
+        )
+        top_h_warp = TopHLogitsWarper(top_h=0.3)
+        filtered_logits = top_h_warp(input_ids, dist1.clone())
+        filtered_dist = torch.exp(filtered_logits)  # exp(-inf) -> 0
+
+        EXPECTED1 = torch.tensor(
+            [[0.97, 0.0, 0.0, 0.0]],
+            device=torch_device,
+            dtype=torch.float,
+        )
+        torch.testing.assert_close(filtered_dist, EXPECTED1, rtol=1e-3, atol=1e-3)
+
+        # --- Case 2: Moderately skewed distribution -> alpha large enough to keep exactly top-2
+        dist2 = torch.log(
+            torch.tensor(
+                [[0.4, 0.3, 0.2, 0.1]],  # entropy budget with alpha=0.7 yields 2-token prefix
+                device=torch_device,
+                dtype=torch.float,
+            )
+        )
+        top_h_warp = TopHLogitsWarper(top_h=0.7)
+        filtered_logits = top_h_warp(input_ids, dist2.clone())
+        filtered_dist = torch.exp(filtered_logits)
+
+        EXPECTED2 = torch.tensor(
+            [[0.4, 0.3, 0.0, 0.0]],
+            device=torch_device,
+            dtype=torch.float,
+        )
+        torch.testing.assert_close(filtered_dist, EXPECTED2, rtol=1e-3, atol=1e-3)
+
+        # --- Case 3: Uniform distribution -> alpha=1.0 keeps all tokens
+        dist3 = torch.log(
+            torch.tensor(
+                [[0.25, 0.25, 0.25, 0.25]],
+                device=torch_device,
+                dtype=torch.float,
+            )
+        )
+        top_h_warp = TopHLogitsWarper(top_h=1.0)
+        filtered_logits = top_h_warp(input_ids, dist3.clone())
+        filtered_dist = torch.exp(filtered_logits)
+
+        EXPECTED3 = torch.tensor(
+            [[0.25, 0.25, 0.25, 0.25]],
+            device=torch_device,
+            dtype=torch.float,
+        )
+        torch.testing.assert_close(filtered_dist, EXPECTED3, rtol=1e-3, atol=1e-3)
+
+        # --- Case 4: Probabilities including 0 value
+        dist4 = torch.log(
+            torch.tensor(
+                [[0.75, 0.25, 0.0, 0.0]],
+                device=torch_device,
+                dtype=torch.float,
+            )
+        )
+        top_h_warp = TopHLogitsWarper(top_h=0.4)
+        filtered_logits = top_h_warp(input_ids, dist4.clone())
+        filtered_dist = torch.exp(filtered_logits)
+
+        EXPECTED4 = torch.tensor(
+            [[0.75, 0.0, 0.0, 0.0]],
+            device=torch_device,
+            dtype=torch.float,
+        )
+        torch.testing.assert_close(filtered_dist, EXPECTED4, rtol=1e-3, atol=1e-3)
+
+        # Processor should not change logits in-place
+        top_h_warp = TopHLogitsWarper(top_h=0.5)
+        out_again = top_h_warp(input_ids, dist3)
+        assert not torch.all(out_again == dist3)
+
     def test_min_p_dist_warper(self):
         input_ids = None
         vocab_size = 10
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index c838f8c885d5..829b8824142c 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -3059,6 +3059,34 @@ def test_synthid_text_watermark_generation_mean_expected_bias(self):
         )
         self.assertTrue(torch.all(is_close))
 
+    @slow
+    def test_TopH_example_integration(self):
+        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
+        model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B")
+        tokenizer.pad_token = tokenizer.eos_token
+        model.config.pad_token_id = tokenizer.pad_token_id
+        encoder_input_str = "Tell me a joke about a monkey."
+        input_ids = tokenizer(encoder_input_str, return_tensors="pt")
+
+        torch.manual_seed(0)
+
+        outputs = model.generate(
+            **input_ids,
+            eos_token_id=model.config.eos_token_id,
+            do_sample=True,
+            temperature=1.0,
+            top_h=0.4,
+            max_new_tokens=32,
+            pad_token_id=tokenizer.pad_token_id,
+        )
+        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        self.assertListEqual(
+            outputs,
+            [
+                'Tell me a joke about a monkey. Why did the monkey go to the doctor? Because he was feeling a little "tropic"!'
+            ],
+        )
+
     @slow
     def test_beam_search_example_integration(self):
         # exactly the example provided in the docstrings of beam search, which previously
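
Reviewer note (not part of the patch): the stopping rule implemented in `TopHLogitsWarper` can be sanity-checked outside of `generate` with a small standalone sketch. The snippet below mirrors the cumulative-entropy budget on already-sorted toy distributions; the helper name `top_h_kept_count` is hypothetical and written only for illustration.

```python
import torch


def top_h_kept_count(probs: torch.Tensor, top_h: float) -> int:
    """Number of tokens kept by the Top-H rule for a 1-D probability vector."""
    probs, _ = torch.sort(probs, descending=True)
    log_probs = torch.where(probs > 0, torch.log(probs), torch.zeros_like(probs))
    entropy_terms = -probs * log_probs                 # per-token -p * log(p)
    tau = entropy_terms.sum() * top_h                  # entropy budget: top_h * H(p)
    kept = torch.cumsum(entropy_terms, dim=-1) <= tau  # keep the prefix that fits the budget
    kept[0] = True                                     # the most probable token is always kept
    return int(kept.sum())


# A peaked distribution with a small budget keeps only the top token; raising the
# budget on the same skewed distribution keeps a longer prefix.
print(top_h_kept_count(torch.tensor([0.97, 0.01, 0.01, 0.01]), top_h=0.3))  # -> 1
print(top_h_kept_count(torch.tensor([0.4, 0.3, 0.2, 0.1]), top_h=0.7))      # -> 2
print(top_h_kept_count(torch.tensor([0.4, 0.3, 0.2, 0.1]), top_h=0.95))     # -> 3
```

The two calls on the `[0.4, 0.3, 0.2, 0.1]` distribution show how a larger `top_h` enlarges the kept prefix, which matches the `alpha=0.7` comment in `test_top_h_dist_warper` above.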