Merged
Changes from all commits · 207 commits
a855b64
init
ErfanBaghaei Sep 9, 2025
2109ccf
added TopH
ErfanBaghaei Sep 9, 2025
9902115
Update TopH logits_process.py
ErfanBaghaei Sep 11, 2025
519d675
Update logits_process.py
ErfanBaghaei Sep 11, 2025
d56a261
Update test_logits_process.py
ErfanBaghaei Sep 11, 2025
cd30869
Update test_logits_process.py
ArminAzizi98 Sep 11, 2025
c7c1472
added test No. 4
ErfanBaghaei Sep 12, 2025
febfc04
Merge branch 'main' into Top-H-Decoding
ErfanBaghaei Sep 12, 2025
91bc1b7
Resolving __init__.py issues
ErfanBaghaei Sep 12, 2025
009aa73
Resolving configuration_utils.py Issues
ErfanBaghaei Sep 12, 2025
872bd47
Resolving logits_process.py Issues
ErfanBaghaei Sep 12, 2025
2054fb6
Resolving utils.py Issues
ErfanBaghaei Sep 12, 2025
5bc900d
Resolving test_logits_process.py Issues
ErfanBaghaei Sep 12, 2025
768bda6
Resolving __init__.py issues
ErfanBaghaei Sep 12, 2025
d843f1c
Resolving logits_process.py Issues
ErfanBaghaei Sep 12, 2025
290f97d
Resolving __init__.py issues
ErfanBaghaei Sep 12, 2025
2b785ad
Updated Docs
ErfanBaghaei Sep 12, 2025
f35b6ce
Updated Docstring
ErfanBaghaei Sep 12, 2025
a566561
style: autoformat with make fixup
ErfanBaghaei Sep 12, 2025
49a611d
Fixing Docstring
ErfanBaghaei Sep 12, 2025
3fb3a87
Update logits_process.py removed defaults
ErfanBaghaei Sep 16, 2025
4917572
Variable H name -> cumulative_entropy
ErfanBaghaei Sep 22, 2025
11ef0a2
Using torch.distributions.Categorical
ErfanBaghaei Sep 25, 2025
90a3d94
Improve torch_dtype checks (#40808)
cyyever Sep 12, 2025
2db9152
Add VideoProcessors to auto-backend requirements (#40843)
Cyrilvallez Sep 12, 2025
e71afc5
Adds Causal Conv 1D kernel for mamba models (#40765)
MekkCyber Sep 12, 2025
c19ca3e
Update no split modules in T5Gemma model (#40810)
npuichigo Sep 12, 2025
1817410
Replace image classification loss functions to `self.loss_function` (…
qubvel Sep 12, 2025
fb2795e
Fix the misalignment between the l2norm in GDN of Qwen3-Next and the …
bozheng-hit Sep 12, 2025
a300d04
Fixes for continuous batching (#40828)
remi-or Sep 12, 2025
d5ab59f
[tests] re-enable aria fast tests (#40846)
gante Sep 12, 2025
e25fcbf
[SAM2] Fix inconsistent results with original implementation with inp…
yonigozlan Sep 12, 2025
1e83816
[Sam2Video] Fix video inference with batched boxes and add test (#40797)
yonigozlan Sep 12, 2025
55d3458
add: differential privacy research model (#40851)
RyanMullins Sep 12, 2025
1814fa6
[test] Fix test_eager_matches_sdpa incorrectly skipped (#40852)
eustlb Sep 12, 2025
d78b3a9
[tests] move generative tests away from `test_modeling_common.py` (#4…
gante Sep 12, 2025
a790005
[generate] Always use decoder config to init cache (#40772)
gante Sep 12, 2025
e135660
Use checkpoint in auto_class_docstring (#40844)
cyyever Sep 13, 2025
c274330
Fix TrainingArguments.parallelism_config NameError with accelerate<1.…
albertvillanova Sep 14, 2025
33c51a8
Redirect MI355 CI results to dummy dataset (#40862)
ahadnagy Sep 14, 2025
c8416c8
[Bug fix #40813] Fix base_model_tp_plan of Starcoder2 model. (#40814)
greg-kwasniewski1 Sep 15, 2025
6cf9c59
[docstrings / type hints] Update outdated annotations for `past_key_v…
gante Sep 15, 2025
23e87bb
fix florence kwargs (#40826)
SunMarc Sep 15, 2025
7e29410
fix: XIELU act parameters not being casted to correct dtype (#40812)
NanoCode012 Sep 15, 2025
bb5b768
Update model tags and integration references in bug report (#40881)
ArthurZucker Sep 15, 2025
d69d754
[Qwen3 Next] Use numerically stable `rsqrt` (#40848)
thalahors Sep 15, 2025
4d3d07f
Adding Support for Qwen3-VL Series (#40795)
JJJYmmm Sep 15, 2025
f8b3311
[`VaultGemma`] Update expectations in integration tests (#40855)
vasqu Sep 15, 2025
dd64685
Fix modular consistency (#40883)
Cyrilvallez Sep 15, 2025
d8a69ff
🔴 Move variable output controls to `_prepare_generation_config ` (#40…
manueldeprada Sep 15, 2025
777b559
Clarify passing is_causal in sdpa_attention_paged_forward (#40838)
cyyever Sep 15, 2025
b13c6d8
Use torch.expm1 and torch.log1p for better numerical results (#40860)
cyyever Sep 15, 2025
332286f
Add Fast PromptDepthAnything Processor (#40602)
SamuelBarryCS Sep 15, 2025
a4d417c
Fix deta loading & dataclass (#40878)
Cyrilvallez Sep 15, 2025
493dd21
Remove dict branch of attention_mask in sdpa_attention_paged_forward …
cyyever Sep 15, 2025
aa8fea4
🌐 [i18n-KO] Translated smolvlm.md to Korean (#40414)
HyunZ118 Sep 15, 2025
cbe9f2e
🌐 [i18n-KO] Translated `imageprocessor.md` to Korean (#39557)
HyunZ118 Sep 15, 2025
23772dc
[generate] remove docs of a feature that no longer exists (#40895)
gante Sep 15, 2025
b4d7f5f
Make debugging failing tests (check and update expect output values) …
ydshieh Sep 16, 2025
60c9553
Fixing the call to kernelize (#40628)
MekkCyber Sep 16, 2025
294ec23
Fix getter regression (#40824)
molbap Sep 16, 2025
dcb52bf
Fix flaky `Gemma3nAudioFeatureExtractionTest::test_dither` (#40902)
ydshieh Sep 16, 2025
b947b60
[cache] Merge static sliding and static chunked layer (#40893)
Cyrilvallez Sep 16, 2025
f096c5b
Harmonize CacheLayer names (#40892)
Cyrilvallez Sep 16, 2025
15d5f49
[cache] Only use scalars in `get_mask_sizes` (#40907)
Cyrilvallez Sep 16, 2025
288352d
Set seed for `Glm4vIntegrationTest` (#40905)
ydshieh Sep 16, 2025
a418ac8
Add Olmo3 model (#40778)
2015aroras Sep 16, 2025
8534f2d
remove dummy EncodingFast (#40864)
cyyever Sep 16, 2025
1f0df5f
Improve module name handling for local custom code (#40809)
XuehaiPan Sep 16, 2025
b067650
Remove `runner_map` (#40880)
ydshieh Sep 16, 2025
5c7684e
disable `test_fast_is_faster_than_slow` (#40909)
ydshieh Sep 16, 2025
f8fb8a5
[gemma3] `Gemma3ForConditionalGeneration` compatible with assisted ge…
gante Sep 16, 2025
4248a67
[generate] misc fixes (#40906)
gante Sep 16, 2025
cc4f313
🔴Make `center_crop` fast equivalent to slow (#40856)
yonigozlan Sep 16, 2025
c689f16
Fix dtype in Paligemma (#40912)
zucchini-nlp Sep 16, 2025
cf7356b
[Docs] Adding documentation of MXFP4 Quantization (#40885)
ariG23498 Sep 16, 2025
053228a
Processor load with multi-processing (#40786)
zucchini-nlp Sep 17, 2025
030af75
[Llama4] Remove `image_sizes` arg and deprecate `vision_feature_layer…
yaswanth19 Sep 17, 2025
3e2e555
Fix #40067: Add dedicated UMT5 support to GGUF loader (config, tokeni…
akshay-babbar Sep 17, 2025
1575c03
[torchao safetensors] renaming get_state_dict function (#40774)
liangel-02 Sep 17, 2025
901e5d7
Adding activation kernels (#40890)
MekkCyber Sep 17, 2025
8b942de
Minor fix for #40727 (#40929)
ydshieh Sep 17, 2025
1935c22
Add support for Florence-2 training (#40914)
ducviet00 Sep 17, 2025
2e287d1
Add LongCat-Flash (#40730)
molbap Sep 17, 2025
9baa3d6
[DOC] Add missing dates in model cards (#40922)
yonigozlan Sep 17, 2025
dccd2df
[models] remove unused `import torch.utils.checkpoint` (#40934)
gante Sep 17, 2025
da501ec
Intel CPU dockerfile (#40806)
jiqing-feng Sep 17, 2025
385aeb6
docs(i18n): Correct the descriptive text in the README_zh-hans.md (#4…
lilin-1 Sep 17, 2025
e7a14d9
Fix trainer tests (#40823)
SunMarc Sep 17, 2025
d8d78c6
Fix `Glm4vMoeIntegrationTest` (#40930)
ydshieh Sep 17, 2025
f0150ad
Raise error instead of warning when using meta device in from_pretrai…
Cyrilvallez Sep 17, 2025
8b8b353
Consistent naming for images kwargs (#40834)
zucchini-nlp Sep 17, 2025
5301d16
Remove nested import logic for torchvision (#40940)
yonigozlan Sep 17, 2025
b5cbfd5
Fix `Glm4vModelTest::test_eager_matches_fa2_generate` (#40947)
ydshieh Sep 17, 2025
b8207cb
Update expected values for some `test_speculative_generation` (#40949)
ydshieh Sep 17, 2025
f962aaf
Standardize audio embedding function name for audio multimodal models…
jackzhxng Sep 18, 2025
cd1a661
Add FlexOlmo model (#40921)
2015aroras Sep 18, 2025
3ab94a1
Don't list dropout in eager_paged_attention_forward (#40924)
cyyever Sep 18, 2025
9f65eab
Update expected values for one more `test_speculative_generation` aft…
ydshieh Sep 18, 2025
4a5f348
FIX(trainer): ensure final checkpoint is saved when resuming training…
rangehow Sep 18, 2025
b38d52a
Add new model LFM2-VL (#40624)
zucchini-nlp Sep 18, 2025
4d4932e
Fix outdated version checks of accelerator (#40969)
cyyever Sep 18, 2025
9104de8
Use `skip_predictor=True` in vjepa2 `get_vision_features` (#40966)
hamishs Sep 18, 2025
b9ad602
[Trainer] Fix DP loss (#40799)
SunMarc Sep 18, 2025
55e48bf
[timm_wrapper] better handling of "Unknown model" exception in timm (…
harshaljanjani Sep 18, 2025
ca8eed3
Fix Issue #39030: AutoTokenizer.from_pretrained does not propagate to…
brandenkmurray Sep 18, 2025
3373554
[tests] Really use small models in all fast tests (#40945)
Cyrilvallez Sep 18, 2025
1e8b8d3
Add captured actual outputs to CI artifacts (#40965)
ydshieh Sep 18, 2025
e5da669
Revert change in `compile_friendly_resize` (#40645)
qubvel Sep 18, 2025
740ff67
Track the CI (model) jobs that don't produce test output files (proce…
ydshieh Sep 18, 2025
c9b01c3
Using torch.distributions.Categorical
ErfanBaghaei Sep 25, 2025
345c86a
Remove `set_model_tester_for_less_flaky_tests` (#40982)
Cyrilvallez Sep 18, 2025
b16d054
Benchmarking v2 GH workflows (#40716)
ahadnagy Sep 19, 2025
e0fb372
🔴[`Attention`] Bert-based Models Attention Refactor (#38301)
vasqu Sep 19, 2025
0dbfde2
Remove [[autodoc]] refs to TF/Flax objects (#40996)
Cyrilvallez Sep 19, 2025
46922b3
ENH: Enable readline support for transformers chat (#40911)
BenjaminBossan Sep 19, 2025
dbc0952
[testing] test `num_hidden_layers` being small in model tester (#40992)
ydshieh Sep 19, 2025
17be25b
blt wip (#38579)
itazap Sep 19, 2025
4a17be0
[docs] rm stray tf/flax autodocs references (#40999)
gante Sep 19, 2025
e08f64c
[`RMSNorm`] Fix rms norm init for models that center around 1 (#40796)
vasqu Sep 19, 2025
40dcb51
Make `EfficientLoFTRModelTest` faster (#41000)
ydshieh Sep 19, 2025
85702fd
Fix typoes in src and tests (#40845)
cyyever Sep 19, 2025
d471b2e
Fix more dates in model cards and wrong modalities in _toctree.yml (#…
yonigozlan Sep 19, 2025
ae88512
RUFF fix on CI scripts (#40805)
cyyever Sep 19, 2025
c52a158
fix dict like init for ModelOutput (#41002)
SunMarc Sep 19, 2025
425b2b4
🚨 [v5] remove generate output retrocompatibility aliases (#40998)
gante Sep 19, 2025
4e05e80
[tests] update `test_left_padding_compatibility` (and minimize overwr…
gante Sep 19, 2025
387fb9a
Patch more `unittest.case.TestCase.assertXXX` methods (#41008)
ydshieh Sep 19, 2025
e1c13bc
🚨 [v5] remove deprecated entry point (#40997)
gante Sep 19, 2025
9896a3f
🚨 [lightglue] fix: matches order changed because of early stopped ind…
sbucaille Sep 19, 2025
b16b156
Fix `PhimoeIntegrationTest` (#41007)
ydshieh Sep 19, 2025
002d853
Fix Glm4v test (#41011)
Cyrilvallez Sep 19, 2025
0f598ff
Update after #41007 (#41014)
ydshieh Sep 19, 2025
00aa6c7
Fix benchmark runner argument name (#41012)
ahadnagy Sep 20, 2025
ceefb54
Adding support for Qwen3Omni (#41025)
BakerBunker Sep 21, 2025
2f2d193
Making compute_loss_func always take priority in Trainer (#40632)
Flakes342 Sep 22, 2025
21031f5
Modify Qwen3Omni parameter name since VL changed it (#41045)
BakerBunker Sep 22, 2025
17f5a92
Fix Qwen video tests (#41049)
zucchini-nlp Sep 22, 2025
2e07406
[testing] Fix `qwen2_audio` (#41018)
ydshieh Sep 22, 2025
73f6379
Fix typing of tuples (#41028)
cyyever Sep 22, 2025
a945d26
Remove optax (#41030)
cyyever Sep 22, 2025
755a1e5
Fix typos in English/Chinese documentation (#41031)
cyyever Sep 22, 2025
586c487
Use torch.autocast (#40975)
cyyever Sep 22, 2025
0cfc691
docs: improved RoPE function Docstrings (#41004)
RyanMullins Sep 22, 2025
e832420
Fix condition for emitting warning when generation exceeds max model …
yannicks1 Sep 22, 2025
8b26d9f
Fix outdated torch version check (#40925)
cyyever Sep 22, 2025
e5e269e
Remove doc of tf and flax (#41029)
cyyever Sep 22, 2025
0bedf8a
Add Whole Word Masking and Padding Strategy to DataCollatorForLanguag…
rjgleaton Sep 22, 2025
54810d7
[testing] Fix `seed_oss` (#41052)
ydshieh Sep 22, 2025
973b3fc
Remove repeated import (#40937)
cyyever Sep 22, 2025
c036a71
Simplify unnecessary Optional typing (#40839)
cyyever Sep 22, 2025
a062de7
Add write token for uploading benchmark results to the Hub (#41047)
ahadnagy Sep 22, 2025
edf22db
Ci utils (#40978)
remi-or Sep 22, 2025
126962e
Remove <frameworkcontent> and <pt> tags from documentation (#41055)
cyyever Sep 22, 2025
fa3c2d7
Fix CI jobs being all red 🔴 (false positive) (#41059)
ydshieh Sep 22, 2025
7d90855
Update quantization CI (#41068)
SunMarc Sep 22, 2025
0f21b54
[i18n-bn] Add Bengali language README file (#40935)
saidurpulok Sep 22, 2025
f84f441
Improve documentation and errors in Mamba2-based models (#41063)
mapmeld Sep 22, 2025
b4f0c46
Update team member list for some CI workflows (#41094)
ydshieh Sep 23, 2025
9b9fb23
fix crash when using chat to send 2+ request to gptoss (#40536)
sywangyi Sep 23, 2025
6a8b33a
Minor addition, no split modules for VideoMAEE (#41051)
DuyguA Sep 23, 2025
33aaccc
Switch to `python:3.10-slim` for CircleCI docker images (#41067)
ydshieh Sep 23, 2025
8115fbd
Fix argument name in benchmarking script (#41086)
ahadnagy Sep 23, 2025
eb22858
Remove mention of TensorFlow/Flax/JAX from English documentation (#41…
cyyever Sep 23, 2025
6c08b04
Fix typos in documentation (#41087)
cyyever Sep 23, 2025
71a8ad0
Fix typing (#40788)
cyyever Sep 23, 2025
6766e81
Remove unused arguments (#40916)
cyyever Sep 23, 2025
cd36b9b
Remove tf and flax from Chinese documentation (#41057)
cyyever Sep 23, 2025
f82b096
fix wrong height and width when read video use torchvision (#41091)
Juude Sep 23, 2025
824415f
docs: Fix Tool Use links and remove dead RAG links (#41104)
RyanMullins Sep 23, 2025
6a94124
🚨 [generate] update paligemma mask updates (and other assisted genera…
gante Sep 23, 2025
78c6f7a
[tests] gpt2 + `CausalLMModelTester` (#41003)
gante Sep 23, 2025
384b671
Fix `_get_test_info` for inherited tests (#41106)
ydshieh Sep 23, 2025
fe09b8a
Remove bad test skips (#41109)
Cyrilvallez Sep 23, 2025
e1b55ff
Format empty lines and white space in markdown files. (#41100)
cyyever Sep 23, 2025
2dd5e73
Update ruff to 0.13.1 + target Python 3.10 + apply fixes (#37809)
cyyever Sep 24, 2025
e450e0d
🚨 [V5] Remove deprecated training arguments (#41017)
cyyever Sep 24, 2025
20a4c45
Support loading LFM2 GGUF (#41111)
HaroldBenoit Sep 24, 2025
f0b7d24
[torchao safetensors] integrate torchao safetensors support with tran…
liangel-02 Sep 24, 2025
34fd896
[Qwen3-next] Fix dimension mismatch in torch_chunk_gated_delta_rule a…
notkisk Sep 24, 2025
ffa6a76
Fix the error where a keyword argument appearing before *args (#41099)
cyyever Sep 24, 2025
6558e75
Fix broken `` expressions in markdown files (#41113)
cyyever Sep 24, 2025
0ab9d77
Remove self-assignment (#41062)
cyyever Sep 24, 2025
7d70f39
🚨Refactor: Update text2text generation pipelines to use max_new_token…
lilin-1 Sep 24, 2025
13f9a7d
Fixed MXFP4 model storage issue (#41118)
YangKai0616 Sep 24, 2025
0f312b2
Fixed loading LongT5 from legacy checkpoints (#40724)
Szustarol Sep 24, 2025
295cf0b
dummy commit (#41133)
ydshieh Sep 24, 2025
212e827
Fix loading logic flaw with regards to unexpected and missing keys (#…
LysandreJik Sep 24, 2025
18941ba
Using torch.distributions.Categorical
ErfanBaghaei Sep 25, 2025
94336c5
Resolving logits_process.py Issues
ErfanBaghaei Sep 12, 2025
643d9c2
style: autoformat with make fixup
ErfanBaghaei Sep 12, 2025
2cc41c6
Update logits_process.py removed defaults
ErfanBaghaei Sep 16, 2025
5255a72
Variable H name -> cumulative_entropy
ErfanBaghaei Sep 22, 2025
75c809c
Merge branch 'main' into Top-H-Decoding
ErfanBaghaei Sep 25, 2025
70214c1
Resolving format error
ErfanBaghaei Sep 25, 2025
9dad329
Correction of the loop variables in logit processor
ErfanBaghaei Sep 25, 2025
bf23aef
Vectorized the loop in logits_process
ErfanBaghaei Sep 26, 2025
5829189
formatted logits_process
ErfanBaghaei Sep 27, 2025
cd9f22e
paper reference and stopping rule comment logits_process
ErfanBaghaei Sep 27, 2025
116c55d
Trigger CI rerun
ErfanBaghaei Sep 27, 2025
6b3eea3
Update logits_process.py
ArminAzizi98 Sep 27, 2025
0ebb99d
added test_TopH_example_integration
ErfanBaghaei Sep 28, 2025
f4ea5e4
added test_TopH_example_integration
ErfanBaghaei Sep 28, 2025
5e7a92d
Update README.md
souvikku Sep 30, 2025
0c83d0e
Restore CI config to match main (remove accidental changes)
ErfanBaghaei Oct 8, 2025
aa15f5d
Restore CI config to match upstream main (no diffs)
ErfanBaghaei Oct 8, 2025
0d977a9
Merge branch 'main' into Top-H-Decoding
ErfanBaghaei Oct 8, 2025
3 changes: 3 additions & 0 deletions docs/source/en/internal/generation_utils.md
@@ -153,6 +153,9 @@ generation.
[[autodoc]] TemperatureLogitsWarper
- __call__

[[autodoc]] TopHLogitsWarper
- __call__

[[autodoc]] TopKLogitsWarper
- __call__

2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -422,6 +422,7 @@
"SynthIDTextWatermarkingConfig",
"SynthIDTextWatermarkLogitsProcessor",
"TemperatureLogitsWarper",
"TopHLogitsWarper",
"TopKLogitsWarper",
"TopPLogitsWarper",
"TypicalLogitsWarper",
@@ -587,6 +588,7 @@
from .generation import TemperatureLogitsWarper as TemperatureLogitsWarper
from .generation import TextIteratorStreamer as TextIteratorStreamer
from .generation import TextStreamer as TextStreamer
from .generation import TopHLogitsWarper as TopHLogitsWarper
from .generation import TopKLogitsWarper as TopKLogitsWarper
from .generation import TopPLogitsWarper as TopPLogitsWarper
from .generation import TypicalLogitsWarper as TypicalLogitsWarper
2 changes: 2 additions & 0 deletions src/transformers/generation/__init__.py
@@ -67,6 +67,7 @@
"SuppressTokensAtBeginLogitsProcessor",
"SynthIDTextWatermarkLogitsProcessor",
"TemperatureLogitsWarper",
"TopHLogitsWarper",
"TopKLogitsWarper",
"TopPLogitsWarper",
"TypicalLogitsWarper",
@@ -153,6 +154,7 @@
SuppressTokensLogitsProcessor,
SynthIDTextWatermarkLogitsProcessor,
TemperatureLogitsWarper,
TopHLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
9 changes: 9 additions & 0 deletions src/transformers/generation/configuration_utils.py
@@ -167,6 +167,12 @@ class GenerationConfig(PushToHubMixin):
Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
value between 0 and 1. Typical values are in the 0.01-0.2 range, comparably selective as setting `top_p` in
the 0.99-0.8 range (use the opposite of normal `top_p` values).
top_h (`float`, *optional*):
Entropy budget scaling factor, which controls how much of the distribution’s entropy is preserved when sampling.
Must be a value between 0 and 1. At each step, tokens are sorted by probability, and the largest prefix of tokens
is kept whose cumulative entropy (the running sum of `-p * log(p)` over the renormalized candidates) stays at or
below `top_h` times the entropy of the full distribution. Smaller values (e.g., 0.2–0.5) lead to more focused,
deterministic outputs, while values closer to 1.0 allow more randomness and diversity. Typical values are in the
0.3–0.6 range.
typical_p (`float`, *optional*, defaults to 1.0):
Local typicality measures how similar the conditional probability of predicting a target token next is to
the expected conditional probability of predicting a random token next, given the partial text already
@@ -357,6 +363,7 @@ def __init__(self, **kwargs):
self.top_k = kwargs.pop("top_k", 50)
self.top_p = kwargs.pop("top_p", 1.0)
self.min_p = kwargs.pop("min_p", None)
self.top_h = kwargs.pop("top_h", None)
self.typical_p = kwargs.pop("typical_p", 1.0)
self.epsilon_cutoff = kwargs.pop("epsilon_cutoff", 0.0)
self.eta_cutoff = kwargs.pop("eta_cutoff", 0.0)
@@ -581,6 +588,8 @@ def validate(self, strict=False):
minor_issues["top_p"] = greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p)
if self.min_p is not None:
minor_issues["min_p"] = greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p)
if self.top_h is not None:
minor_issues["top_h"] = greedy_wrong_parameter_msg.format(flag_name="top_h", flag_value=self.top_h)
if self.typical_p is not None and self.typical_p != 1.0:
minor_issues["typical_p"] = greedy_wrong_parameter_msg.format(
flag_name="typical_p", flag_value=self.typical_p
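For orientation, a minimal usage sketch of the new flag as wired above: `top_h` is popped from the `generate()` kwargs into `GenerationConfig`, and `validate()` flags it as a sampling-only parameter when `do_sample=False`. This is an editor's illustration, not code from the PR; the prompt is arbitrary, and the checkpoint is borrowed from this PR's integration test.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any causal LM works; Qwen/Qwen2.5-3B is the checkpoint this PR's integration test uses.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B")

inputs = tokenizer("The capital of France is", return_tensors="pt")

# `do_sample=True` is required: with greedy decoding, `validate()` reports
# `top_h` as a sampling flag that would have no effect.
outputs = model.generate(**inputs, do_sample=True, top_h=0.4, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```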
106 changes: 106 additions & 0 deletions src/transformers/generation/logits_process.py
@@ -581,6 +581,112 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
return scores_processed


class TopHLogitsWarper(LogitsProcessor):
"""
[`LogitsProcessor`] that implements Top-H sampling, a decoding method which adaptively selects a subset of
high-probability tokens based on entropy and cumulative probability constraints.

This method dynamically determines how many tokens to keep by analyzing the entropy difference of the selected
distribution, thereby balancing exploration and exploitation. It ensures that generated text maintains both
diversity and coherence.

Reference:
For details, see *Top-H Decoding: Adapting the Creativity and Coherence with Bounded Entropy in Text Generation*
(NeurIPS 2025): https://arxiv.org/abs/2509.02510

Args:
top_h (`float`):
Scaling coefficient for the entropy-based threshold (`tau`). Must be in the range `(0, 1]`.

filter_value (`float`, *optional*, defaults to -inf):
All filtered values will be set to this float value.

Example:

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")

>>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt")

>>> outputs = model.generate(**inputs, do_sample=True, top_h=0.4)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9
```
"""

def __init__(self, top_h: float, filter_value: float = -float("Inf")):
super().__init__()

# input checks
if not (0 < top_h <= 1):
raise ValueError("`top_h` must be in the range (0, 1].")

# Maximum number of top tokens to consider before applying the entropy-based filter.
# Acts as a cap for efficiency and numerical stability — increasing this allows more
# tokens to be evaluated but may slow down generation. Default is 100.
self.top_n = 100

self.top_h = top_h
self.filter_value = filter_value

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
"""
Filters logits using Top-H sampling.

Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Input token IDs.
scores (`torch.FloatTensor` of shape `(batch_size, vocab_size)`):
Raw logits from the model.

Return:
`torch.FloatTensor` of shape `(batch_size, vocab_size)`:
Processed logits where invalid tokens are masked with `-inf`.
"""
batch_size, vocab_size = scores.shape
device = scores.device
keep_mask = torch.zeros((batch_size, vocab_size), dtype=torch.bool, device=device)
top_n = min(self.top_n, vocab_size)

# 1. Get top-k logits and indices for the whole batch
top_logits, top_idx = torch.topk(scores, top_n, dim=-1, largest=True, sorted=True)

# 2. Create a batch of categorical distributions
dist = torch.distributions.Categorical(logits=top_logits)
probs = dist.probs
log_probs = torch.log(probs)  # log of the renormalized top-n probabilities (equivalent to `dist.logits`)

# 3. Calculate the entropy-based threshold tau for the whole batch
# We unsqueeze tau to enable broadcasting against the cumulative entropy tensor.
tau = (dist.entropy() * self.top_h).unsqueeze(-1)

# 4. Calculate cumulative entropy using torch.cumsum
# The individual entropy terms (-p * log(p)) are calculated for all top_n tokens at once.
entropy_terms = -probs * log_probs
cumulative_entropy = torch.cumsum(entropy_terms, dim=-1)

# 5. Determine which tokens to keep based on the stopping condition
# Create a boolean mask for the top_n tokens.
# Stopping rule: keep adding tokens in order of probability until the cumulative entropy
# exceeds the threshold τ = H(p) * top_h. This ensures diversity (via entropy) while
# guaranteeing at least the most probable token is always included.
selection_mask = cumulative_entropy <= tau
selection_mask[:, 0] = True

# 6. Update the final keep_mask for the entire batch in one operation
# The scatter_ operation efficiently updates the keep_mask at the indices
# specified by top_idx with the boolean values from selection_mask.
keep_mask.scatter_(dim=1, index=top_idx, src=selection_mask)

# apply filtering
scores_processed = scores.clone()
scores_processed[~keep_mask] = self.filter_value
return scores_processed


class MinPLogitsWarper(LogitsProcessor):
"""
[`LogitsProcessor`] that performs min-p, i.e. keeps all tokens that are above a minimum probability, scaled by the
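To make the stopping rule concrete, here is a self-contained restatement of the warper's computation on a toy batch. This is an editor's sketch for illustration, not code from the PR; it mirrors steps 1-6 above.

```python
import torch


def top_h_filter(scores: torch.Tensor, top_h: float, top_n: int = 100) -> torch.Tensor:
    """Keep the largest prefix of top tokens whose cumulative entropy fits the budget."""
    top_n = min(top_n, scores.shape[-1])
    top_logits, top_idx = torch.topk(scores, top_n, dim=-1, largest=True, sorted=True)

    dist = torch.distributions.Categorical(logits=top_logits)
    probs = dist.probs  # renormalized over the top_n candidates

    # Running sum of -p * log(p), compared against tau = H(p) * top_h
    cumulative_entropy = torch.cumsum(-probs * torch.log(probs), dim=-1)
    tau = (dist.entropy() * top_h).unsqueeze(-1)

    selection_mask = cumulative_entropy <= tau
    selection_mask[:, 0] = True  # the most probable token is always kept

    keep_mask = torch.zeros_like(scores, dtype=torch.bool)
    keep_mask.scatter_(dim=1, index=top_idx, src=selection_mask)
    return scores.masked_fill(~keep_mask, -float("inf"))


# With p = [0.4, 0.3, 0.2, 0.1] and top_h = 0.7, exactly the two most
# likely tokens survive (see the unit test later in this diff).
scores = torch.log(torch.tensor([[0.4, 0.3, 0.2, 0.1]]))
print(top_h_filter(scores, top_h=0.7))
```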
3 changes: 3 additions & 0 deletions src/transformers/generation/utils.py
@@ -93,6 +93,7 @@
SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor,
TemperatureLogitsWarper,
TopHLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
@@ -1243,6 +1244,8 @@ def _get_logits_processor(
# all samplers can be found in `generation_utils_samplers.py`
if generation_config.temperature is not None and generation_config.temperature != 1.0:
processors.append(TemperatureLogitsWarper(generation_config.temperature))
if generation_config.top_h is not None:
processors.append(TopHLogitsWarper(top_h=generation_config.top_h))
if generation_config.top_k is not None and generation_config.top_k != 0:
processors.append(
TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)
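Note the ordering `_get_logits_processor` produces: temperature rescales the logits before Top-H measures their entropy, and Top-H runs before the top-k/top-p truncations. An equivalent pipeline can be built by hand from the public classes; the sketch below is illustrative, with made-up parameter values and dummy logits.

```python
import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopHLogitsWarper, TopKLogitsWarper

warpers = LogitsProcessorList(
    [
        TemperatureLogitsWarper(0.8),  # appended first, as in _get_logits_processor
        TopHLogitsWarper(top_h=0.4),
        TopKLogitsWarper(top_k=50),
    ]
)

scores = torch.randn(1, 32000)  # dummy logits over an assumed vocab size
input_ids = torch.tensor([[1]])
filtered = warpers(input_ids, scores)
print((filtered > -float("inf")).sum())  # number of tokens left to sample from
```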
90 changes: 90 additions & 0 deletions tests/generation/test_logits_process.py
@@ -49,6 +49,7 @@
SequenceBiasLogitsProcessor,
SynthIDTextWatermarkLogitsProcessor,
TemperatureLogitsWarper,
TopHLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
@@ -394,6 +395,95 @@ def test_top_p_dist_warper(self):
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])

def test_top_h_dist_warper(self):
"""
We construct small distributions where the expected kept set is obvious for a given alpha.
We pass *log-probabilities* as "scores" so that softmax(scores) == original probabilities,
matching the style in other warper tests (e.g., MinP).
"""

input_ids = None

# --- Case 1: Highly peaked distribution -> small alpha keeps only the top-1
dist1 = torch.log(
torch.tensor(
[[0.97, 0.01, 0.01, 0.01]],
device=torch_device,
dtype=torch.float,
)
)
top_h_warp = TopHLogitsWarper(top_h=0.3)
filtered_logits = top_h_warp(input_ids, dist1.clone())
filtered_dist = torch.exp(filtered_logits) # exp(-inf) -> 0

EXPECTED1 = torch.tensor(
[[0.97, 0.0, 0.0, 0.0]],
device=torch_device,
dtype=torch.float,
)
torch.testing.assert_close(filtered_dist, EXPECTED1, rtol=1e-3, atol=1e-3)

# --- Case 2: Moderately skewed distribution -> alpha large enough to keep exactly top-2
dist2 = torch.log(
torch.tensor(
[[0.4, 0.3, 0.2, 0.1]], # entropy budget with alpha=0.7 yields 2-token prefix
device=torch_device,
dtype=torch.float,
)
)
top_h_warp = TopHLogitsWarper(top_h=0.7)
filtered_logits = top_h_warp(input_ids, dist2.clone())
filtered_dist = torch.exp(filtered_logits)

EXPECTED2 = torch.tensor(
[[0.4, 0.3, 0.0, 0.0]],
device=torch_device,
dtype=torch.float,
)
torch.testing.assert_close(filtered_dist, EXPECTED2, rtol=1e-3, atol=1e-3)

# --- Case 3: Uniform distribution -> alpha=1.0 keeps all tokens
dist3 = torch.log(
torch.tensor(
[[0.25, 0.25, 0.25, 0.25]],
device=torch_device,
dtype=torch.float,
)
)
top_h_warp = TopHLogitsWarper(top_h=1.0)
filtered_logits = top_h_warp(input_ids, dist3.clone())
filtered_dist = torch.exp(filtered_logits)

EXPECTED3 = torch.tensor(
[[0.25, 0.25, 0.25, 0.25]],
device=torch_device,
dtype=torch.float,
)
torch.testing.assert_close(filtered_dist, EXPECTED3, rtol=1e-3, atol=1e-3)

# --- Case 4: Probabilities including 0 value
dist4 = torch.log(
torch.tensor(
[[0.75, 0.25, 0.0, 0.0]],
device=torch_device,
dtype=torch.float,
)
)
top_h_warp = TopHLogitsWarper(top_h=0.4)
filtered_logits = top_h_warp(input_ids, dist4.clone())
filtered_dist = torch.exp(filtered_logits)

EXPECTED4 = torch.tensor(
[[0.75, 0.0, 0.0, 0.0]],
device=torch_device,
dtype=torch.float,
)
torch.testing.assert_close(filtered_dist, EXPECTED4, rtol=1e-3, atol=1e-3)
# Processor must not modify its input logits in-place: if it did, `dist3` would now equal
# the filtered output, and the assertion below would fail.
top_h_warp = TopHLogitsWarper(top_h=0.5)
out_again = top_h_warp(input_ids, dist3)
assert not torch.all(out_again == dist3)
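As a quick sanity check of Case 2 above, the entropy budget can be verified by hand (an editor's illustration, not part of the test file):

```python
import math

# For p = [0.4, 0.3, 0.2, 0.1], the -p * log(p) terms are about
# [0.3665, 0.3612, 0.3219, 0.2303], so H(p) ≈ 1.2799 and tau = 0.7 * H ≈ 0.8959.
p = [0.4, 0.3, 0.2, 0.1]
terms = [-pi * math.log(pi) for pi in p]
tau = 0.7 * sum(terms)

kept, cumulative = 0, 0.0
for term in terms:
    cumulative += term
    if cumulative > tau and kept > 0:  # the top-1 token is always kept
        break
    kept += 1

print(kept)  # 2: cumulative entropy 0.7277 fits the budget, 1.0496 does not
```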

def test_min_p_dist_warper(self):
input_ids = None
vocab_size = 10
28 changes: 28 additions & 0 deletions tests/generation/test_utils.py
@@ -3059,6 +3059,34 @@ def test_synthid_text_watermark_generation_mean_expected_bias(self):
)
self.assertTrue(torch.all(is_close))

@slow
def test_TopH_example_integration(self):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B")
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id
encoder_input_str = "Tell me a joke about a monkey."
input_ids = tokenizer(encoder_input_str, return_tensors="pt")

torch.manual_seed(0)

outputs = model.generate(
**input_ids,
eos_token_id=model.config.eos_token_id,
do_sample=True,
temperature=1.0,
top_h=0.4,
max_new_tokens=32,
pad_token_id=tokenizer.pad_token_id,
)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
outputs,
[
'Tell me a joke about a monkey. Why did the monkey go to the doctor? Because he was feeling a little "tropic"!'
],
)

@slow
def test_beam_search_example_integration(self):
# exactly the example provided in the docstrings of beam search, which previously