Merged
Changes from all commits
421 commits
546efee
fix
qgallouedec Jul 10, 2025
7b077ae
nice
qgallouedec Jul 10, 2025
ebcae9a
where i am at
ArthurZucker Jul 10, 2025
528b3c8
Bro this works
ArthurZucker Jul 10, 2025
9c61a8c
Merge pull request #16 from huggingface/fix-attention
qgallouedec Jul 10, 2025
297e47e
Update src/transformers/integrations/tensor_parallel.py
ArthurZucker Jul 11, 2025
2f852e2
Merge pull request #11 from huggingface/tp_embed_parallel
ArthurZucker Jul 11, 2025
3d25cf7
cleanups
ArthurZucker Jul 11, 2025
ff0544b
Merge branch 'add-oai' into add-fast-flash-kernel
ArthurZucker Jul 11, 2025
29454d2
yups that was breaking
ArthurZucker Jul 11, 2025
b3582fc
Merge branch 'add-fast-flash-kernel' of github.com:huggingface/new-mo…
ArthurZucker Jul 11, 2025
f33a74d
Merge pull request #15 from huggingface/add-fast-flash-kernel
ArthurZucker Jul 11, 2025
1f3ae2b
Merge branch 'main' of github.com:huggingface/new-model-addition-open…
ArthurZucker Jul 11, 2025
15c85e0
Update src/transformers/models/openai_moe/modeling_openai_moe.py
ArthurZucker Jul 11, 2025
0c7379a
merge
ArthurZucker Jul 11, 2025
ad0fc38
gather on experts and not mlp
SunMarc Jul 17, 2025
4fb7345
add changes for latest convert branch
edbeeching Jul 18, 2025
968238c
adds options to get output_router_logits from config
edbeeching Jul 18, 2025
4bc5557
bring chat temlate + special tokens back into the script.
Vaibhavs10 Jul 22, 2025
68fd833
Merge pull request #22 from huggingface/vb/special-tok
ArthurZucker Jul 22, 2025
410435a
Merge pull request #21 from huggingface/ed-fix-modeling
ArthurZucker Jul 22, 2025
07bd34d
initial commmit
MekkCyber Jul 22, 2025
b7987d2
update
MekkCyber Jul 22, 2025
2c0fd4d
working with shards
MekkCyber Jul 22, 2025
1d03f3a
add model.safetensors.index.json
MekkCyber Jul 22, 2025
40e379d
fix
MekkCyber Jul 22, 2025
b68aa6b
fix
MekkCyber Jul 22, 2025
a87db4f
mxfp4 flag
MekkCyber Jul 22, 2025
c3c01f0
rm print
MekkCyber Jul 22, 2025
863630d
Fix PAD/EOS/BOS (#18)
qgallouedec Jul 10, 2025
eab251f
add some doc
MekkCyber Jul 23, 2025
928b9b6
Merge pull request #23 from huggingface/update_conversion_script
SunMarc Jul 23, 2025
9280e59
special tokens based on harmony.
Vaibhavs10 Jul 23, 2025
b382c5e
add in tokenizer config as well.
Vaibhavs10 Jul 23, 2025
7cdd0be
Merge pull request #25 from huggingface/vb/upd-conversion-script
Vaibhavs10 Jul 23, 2025
f8f3e40
prepare for rebase with main
ArthurZucker Jul 24, 2025
c9dc8f2
Merge branches 'add-oai' and 'add-oai' of github.com:huggingface/new-…
ArthurZucker Jul 24, 2025
0ce752c
merge with main
ArthurZucker Jul 24, 2025
60af841
Fix for initialize_tensor_parallelism now returning 4-tuple
edbeeching Jul 24, 2025
1ce172b
mxfp4
SunMarc Jun 10, 2025
c0bee22
mxfp4 draft
SunMarc Jun 12, 2025
fe896d3
fix
SunMarc Jun 12, 2025
174147d
fix import
SunMarc Jun 23, 2025
b8215dd
draft
SunMarc Jun 23, 2025
62f77e1
draft impl
SunMarc Jul 4, 2025
6e9d0c7
finally working !
SunMarc Jul 4, 2025
6b8b279
simplify
SunMarc Jul 8, 2025
ea5c364
add import
SunMarc Jul 8, 2025
1175ab4
working version
SunMarc Jul 8, 2025
d53cb49
consider blocks and scales
SunMarc Jul 10, 2025
8c43631
device mesh fix
SunMarc Jul 10, 2025
4f515eb
initial commit
MekkCyber Jul 16, 2025
0ff6727
add working dequant + quant logic
MekkCyber Jul 16, 2025
13cb07b
update
MekkCyber Jul 17, 2025
3988856
non nan, gibberish output
MekkCyber Jul 21, 2025
b9c8138
working EP + quantization finally !
MekkCyber Jul 22, 2025
5117d71
start cleaning
MekkCyber Jul 23, 2025
3733a34
remove reversing process
MekkCyber Jul 23, 2025
6587359
style
MekkCyber Jul 23, 2025
7961073
some cleaning
MekkCyber Jul 23, 2025
0de006a
initial commmit
MekkCyber Jul 22, 2025
12a9e80
more cleaning
MekkCyber Jul 23, 2025
3904783
more cleaning
MekkCyber Jul 24, 2025
75e0f21
simplify
MekkCyber Jul 24, 2025
c8ce047
more cleaning
MekkCyber Jul 24, 2025
8b162f7
rm duplicated function
MekkCyber Jul 24, 2025
8a00f60
changing tp_plan
MekkCyber Jul 24, 2025
d760f30
update tp plan check
MekkCyber Jul 24, 2025
b34570e
add loading attribute
MekkCyber Jul 24, 2025
a4950aa
dequantizing logic
MekkCyber Jul 24, 2025
89b0671
use subfunctions
MekkCyber Jul 24, 2025
7bfdca6
import cleaning
MekkCyber Jul 24, 2025
21872bd
update_param_name
MekkCyber Jul 24, 2025
b68ece8
adds clamped swiglu
edbeeching Jul 24, 2025
3e106d6
add clamping to training path
edbeeching Jul 28, 2025
1716e6d
simplify dequant logic
MekkCyber Jul 28, 2025
f49bcbb
Merge branch 'main' of github.com:huggingface/new-model-addition-open…
ArthurZucker Jul 28, 2025
b8b0023
update
ArthurZucker Jul 28, 2025
6400fb2
Merge branch 'add-oai' of github.com:huggingface/new-model-addition-o…
ArthurZucker Jul 28, 2025
6976169
Bad merge
ArthurZucker Jul 28, 2025
195cca6
more simplifications & tests
MekkCyber Jul 28, 2025
345afb1
fix !
ArthurZucker Jul 28, 2025
7b18304
Merge pull request #26 from huggingface/add-clamp-swiglu
ArthurZucker Jul 28, 2025
009355a
fix registering custom attention
ArthurZucker Jul 28, 2025
d237a90
fix order
MekkCyber Jul 29, 2025
ccffc0b
fixes
MekkCyber Jul 29, 2025
f92878a
some test nits
MekkCyber Jul 29, 2025
90522c4
nits
MekkCyber Jul 29, 2025
dbb8b20
nit
MekkCyber Jul 29, 2025
d5634bd
Merge branch 'add-oai' into adding_packing_format_option
MekkCyber Jul 29, 2025
587d8da
Merge pull request #20 from huggingface/adding_packing_format_option
MekkCyber Jul 29, 2025
edd9232
fix
MekkCyber Jul 29, 2025
c0ef156
Merge pull request #27 from huggingface/guard_kernels_imports
SunMarc Jul 29, 2025
dc2b16f
Clamp sink logits
lewtun Jul 29, 2025
b050830
Clean
lewtun Jul 30, 2025
e0e406e
Soft-max trick
lewtun Jul 30, 2025
54e8825
Clean up
lewtun Jul 30, 2025
0378ae8
p
lewtun Jul 30, 2025
a208980
Merge pull request #28 from huggingface/fix-train-bsz
ArthurZucker Jul 30, 2025
077cfee
fix deepspeed
MekkCyber Jul 30, 2025
bec11b7
update both modeling and modular for cleanup
ArthurZucker Jul 30, 2025
7d8ac2e
contiguous
MekkCyber Jul 30, 2025
42ab108
update tests
ArthurZucker Jul 30, 2025
e9f130a
fix top_k router call
ArthurZucker Jul 30, 2025
da77d5e
revert renaming
ArthurZucker Jul 30, 2025
5b0bd40
test nits
ArthurZucker Jul 30, 2025
9af87b2
Merge branch 'add-oai' of github.com:huggingface/new-model-addition-o…
ArthurZucker Jul 30, 2025
b43d2cd
small fixes for EP
ArthurZucker Jul 30, 2025
13ec4ef
fix path for our local tests
ArthurZucker Jul 30, 2025
0b5a0e9
Merge branch 'add-oai' of github.com:huggingface/new-model-addition-o…
ArthurZucker Jul 30, 2025
0276225
update as I should not have broken that!
ArthurZucker Jul 30, 2025
f1cf951
Merge branch 'add-oai' of github.com:huggingface/new-model-addition-o…
ArthurZucker Jul 30, 2025
a34b39c
fix the loss of mixtral
ArthurZucker Jul 30, 2025
e7cc591
revert part of the changes related to router_scores, kernel probably …
ArthurZucker Jul 30, 2025
b7a9e4a
Merge branch 'add-oai' of github.com:huggingface/new-model-addition-o…
ArthurZucker Jul 30, 2025
f1245b4
deleting a small nit
ArthurZucker Jul 30, 2025
8a6fbf9
Merge branches 'add-oai' and 'add-oai' of github.com:huggingface/new-…
ArthurZucker Jul 30, 2025
9b387ca
update arch
SunMarc Jul 30, 2025
6c0effa
fix post processing
MekkCyber Jul 30, 2025
ab0f929
update
SunMarc Jul 30, 2025
e030193
Merge pull request #30 from huggingface/fix-conversion-architecture
Vaibhavs10 Jul 30, 2025
c80bd44
running version but not expected output
SunMarc Jul 30, 2025
6c55b12
Merge pull request #29 from huggingface/fix_ds
SunMarc Jul 30, 2025
740f3aa
Merge remote-tracking branch 'origin/add-oai' into update-triton-kernels
SunMarc Jul 30, 2025
dc12518
moving to cuda
MekkCyber Jul 31, 2025
20dfa56
initial commit
MekkCyber Jul 31, 2025
228a982
revert
MekkCyber Jul 31, 2025
5a59733
erroring when loading on cpu
MekkCyber Jul 31, 2025
910ccfe
updates
MekkCyber Jul 31, 2025
212acd0
del blocks, scales
MekkCyber Jul 31, 2025
5c6d3b2
fix
SunMarc Jul 31, 2025
5ec240f
style
SunMarc Jul 31, 2025
2faa7ca
rm comm
SunMarc Jul 31, 2025
c5b8cec
comment
MekkCyber Jul 31, 2025
79dd4fc
add comment
SunMarc Jul 31, 2025
93f0816
Merge pull request #36 from huggingface/default_to_dequantize_training
SunMarc Jul 31, 2025
c5e7bfc
Merge branch 'add-oai' into update-triton-kernels
SunMarc Jul 31, 2025
d238ea4
style
SunMarc Jul 31, 2025
76f9088
Merge pull request #31 from huggingface/update-triton-kernels
SunMarc Jul 31, 2025
a7dd97f
remove duplicated lines
SunMarc Jul 31, 2025
cf4843b
Fix minor issue with weight_map conversion script
SunMarc Aug 1, 2025
8b7a73f
fix sampling params
zhuohan123 Aug 1, 2025
08b031b
rename to final name
ArthurZucker Aug 1, 2025
a39ebae
Merge branch 'add-oai' into zhuohan/fix-sampling-parmsl
pcuenca Aug 1, 2025
8430860
Merge pull request #37 from huggingface/zhuohan/fix-sampling-parmsl
Vaibhavs10 Aug 1, 2025
0d1a2da
upate pre-final version of template
Vaibhavs10 Aug 1, 2025
5f3de46
Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py
pcuenca Aug 1, 2025
ce4e912
Merge pull request #38 from huggingface/vb/upd-template
Vaibhavs10 Aug 1, 2025
bddc8c2
fix batched inference
MekkCyber Aug 1, 2025
b2b1ca5
Merge pull request #39 from huggingface/fix_batched_inference
MekkCyber Aug 1, 2025
06b35eb
serve fixes
Aug 1, 2025
0de8f62
swizzle !
SunMarc Aug 1, 2025
a29c5a2
Merge branch 'add-oai' into swizzle
SunMarc Aug 1, 2025
aca1e72
update final chat template by Matt.
Vaibhavs10 Aug 1, 2025
a8c3c49
fix responses; pin oai
Aug 1, 2025
33636c9
sinplify
SunMarc Aug 1, 2025
af6fb99
Thanks Matt for his tireless efforts!
Vaibhavs10 Aug 1, 2025
22e8236
`transformer serve` fixes for oai (mostly hide CoT)
gante Aug 1, 2025
6f91a55
Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py
Vaibhavs10 Aug 1, 2025
afe8912
fix
SunMarc Aug 1, 2025
b7dc08c
Merge pull request #42 from huggingface/swizzle
SunMarc Aug 1, 2025
e991ef4
Merge pull request #41 from huggingface/vb/up-template-2
Vaibhavs10 Aug 1, 2025
7e540fc
Use ROCm kernels from HUB
ahadnagy Aug 1, 2025
3e4ad36
Make kernel modes explicit
ahadnagy Aug 1, 2025
fa6eee9
Merge pull request #43 from huggingface/rocm-kernels-support
ahadnagy Aug 1, 2025
e946804
update final chat template by Matt. x2
Vaibhavs10 Aug 1, 2025
1a8728d
Thanks Matt for his tireless efforts!
Vaibhavs10 Aug 1, 2025
f322506
Merge pull request #44 from huggingface/vb/up-template-3
Vaibhavs10 Aug 1, 2025
50b8250
Fix installation
lewtun Aug 1, 2025
dec98d8
Update setup.py
lewtun Aug 1, 2025
0c6f911
allow no content
qgallouedec Aug 1, 2025
181c625
fix: update message handling in write_tokenizer function
qgallouedec Aug 1, 2025
fa7a66d
Merge pull request #45 from huggingface/fix-install
lewtun Aug 1, 2025
7c74123
Fix template logic for user message role
qgallouedec Aug 2, 2025
672bc17
Merge pull request #47 from huggingface/fix-chat-template
lewtun Aug 2, 2025
402976d
Merge branch 'main' of github.com:huggingface/new-model-addition-open…
ArthurZucker Aug 2, 2025
5509620
Merge branch 'add-oai' of github.com:huggingface/new-model-addition-o…
ArthurZucker Aug 2, 2025
9d27880
last nits for CB and flash_paged!
ArthurZucker Aug 2, 2025
4cf6186
there was one bad merge
ArthurZucker Aug 2, 2025
cac4c09
fix CB (hardcode for now, its just using kv groups instead)
ArthurZucker Aug 2, 2025
eeef8c8
fix
MekkCyber Aug 2, 2025
45fbc18
better fix for device_map
SunMarc Aug 2, 2025
92a2a49
Merge pull request #48 from huggingface/fix_target_device
SunMarc Aug 2, 2025
6dd3a72
minor device fix
SunMarc Aug 2, 2025
5ef7f3f
Fix flash paged
ArthurZucker Aug 3, 2025
47ae152
Merge branch 'add-oai' of github.com:huggingface/new-model-addition-o…
ArthurZucker Aug 3, 2025
d2303c7
updates
ArthurZucker Aug 3, 2025
ed511f2
Revert "remove dtensors, not explicit (#39840)"
ArthurZucker Aug 3, 2025
d8092b9
Merge pull request #46 from huggingface/fix-tool-chat-template
lewtun Aug 3, 2025
e9b3708
update
ArthurZucker Aug 3, 2025
70750d9
Revert "remove dtensors, not explicit (#39840)"
ArthurZucker Aug 3, 2025
3557689
fix merge
ArthurZucker Aug 3, 2025
fbc6815
Merge branch 'add-oai' of github.com:huggingface/new-model-addition-o…
ArthurZucker Aug 3, 2025
b939303
fix
MekkCyber Aug 3, 2025
d238182
Fix line break when custom model indentity
qgallouedec Aug 3, 2025
7c364da
Merge pull request #49 from huggingface/fix_import_triton_kernels
MekkCyber Aug 4, 2025
088a607
nits testing
ArthurZucker Aug 4, 2025
d91814b
to locals first and pass sliding window to flash paged
ArthurZucker Aug 4, 2025
b392bc5
Merge branch 'add-oai' of github.com:huggingface/new-model-addition-o…
ArthurZucker Aug 4, 2025
27bd828
register modes for MegaBlocksMoeMlp
ArthurZucker Aug 4, 2025
b667b7c
add integration test in fixtures -> now update the tests to use it!
ArthurZucker Aug 4, 2025
afffd58
update integration tests
ArthurZucker Aug 4, 2025
00d6703
initial fix
MekkCyber Aug 4, 2025
6a8710e
style and update tests
ArthurZucker Aug 4, 2025
4cb0a93
fix
MekkCyber Aug 4, 2025
b696531
Merge pull request #53 from huggingface/fix_warning
MekkCyber Aug 4, 2025
a9b7b39
Merge pull request #52 from huggingface/fix_kernels
MekkCyber Aug 4, 2025
b9f34dd
chore(gpt oss): remove mlp_bias from configuration
tengomucho Aug 4, 2025
eb942a6
stats
SunMarc Aug 4, 2025
94a85f0
Integration tests
LysandreJik Aug 4, 2025
210067a
whoops
LysandreJik Aug 4, 2025
e60807a
Shouldn't move model
LysandreJik Aug 4, 2025
2718a7c
Merge pull request #57 from huggingface/add-oai-integration-test-fixes
LysandreJik Aug 4, 2025
093ffd5
Merge pull request #50 from huggingface/fix-line-break
Vaibhavs10 Aug 4, 2025
c954ef7
Ensure assistant messages without thinking always go to "final" channel
Rocketknight1 Aug 4, 2025
13f6756
More checks to ensure expected format
Rocketknight1 Aug 4, 2025
6ef5c34
Merge pull request #54 from huggingface/remove-mlp_bias
ArthurZucker Aug 4, 2025
bee0515
Add pad_token_id to model configuration in write_model function (#51)
qgallouedec Aug 4, 2025
e1f46b4
Add oai fix fast tests (#59)
LysandreJik Aug 4, 2025
e29f659
Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py
Rocketknight1 Aug 4, 2025
5c6255e
Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py
Rocketknight1 Aug 4, 2025
889fe01
Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py
Rocketknight1 Aug 4, 2025
25e8bd8
Merge pull request #58 from huggingface/update-template
Rocketknight1 Aug 4, 2025
9844308
reasoning -> Reasoning
Vaibhavs10 Aug 4, 2025
563b5cf
Merge pull request #61 from huggingface/vb/upd-chat-temp-reasoning
Vaibhavs10 Aug 4, 2025
b222c6f
Add additional integration tests
LysandreJik Aug 4, 2025
8421054
fixup
LysandreJik Aug 4, 2025
6001771
Slight fixes
LysandreJik Aug 4, 2025
e360f17
align chat template with harmony
qgallouedec Aug 5, 2025
5fe06b9
simplify
qgallouedec Aug 5, 2025
ba792c9
Add comment
LysandreJik Aug 5, 2025
afc0fc4
torch testing assert close
LysandreJik Aug 5, 2025
7bddb91
torch testing assert close
LysandreJik Aug 5, 2025
4068437
torch testing assert close
LysandreJik Aug 5, 2025
94f11c5
torch testing assert close
LysandreJik Aug 5, 2025
3660b2b
torch testing assert close
LysandreJik Aug 5, 2025
974987f
torch testing assert close
LysandreJik Aug 5, 2025
768b582
Merge pull request #56 from huggingface/better-stats
SunMarc Aug 5, 2025
d881a20
Revert fixup
LysandreJik Aug 5, 2025
0c7db23
Merge pull request #62 from huggingface/add-new-integration-tests
LysandreJik Aug 5, 2025
6698004
skip 2 test remove todo
ArthurZucker Aug 5, 2025
208b83c
Merge branch 'add-oai' of github.com:huggingface/new-model-addition-o…
ArthurZucker Aug 5, 2025
54cf55f
merge
ArthurZucker Aug 5, 2025
f19e04b
padding side should be left for integration tests
ArthurZucker Aug 5, 2025
1f7cad0
fix modular wrt to changes made to modeling
ArthurZucker Aug 5, 2025
6973ba4
style
ArthurZucker Aug 5, 2025
9ab5897
Merge branch 'main' of github.com:huggingface/transformers into add-oai
ArthurZucker Aug 5, 2025
1f47841
isort
ArthurZucker Aug 5, 2025
865b368
fix opies for the loss
ArthurZucker Aug 5, 2025
75f13d0
mmmm
ArthurZucker Aug 5, 2025
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -617,6 +617,8 @@
title: OLMoE
- local: model_doc/open-llama
title: Open-Llama
- local: model_doc/openai_moe
title: OpenAIMoe
- local: model_doc/opt
title: OPT
- local: model_doc/pegasus
58 changes: 58 additions & 0 deletions docs/source/en/model_doc/openai_moe.md
@@ -0,0 +1,58 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
</div>

# OpenAIMoE

## Overview

The OpenAIMoE model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
<INSERT SHORT SUMMARY HERE>

The abstract from the paper is the following:

*<INSERT PAPER ABSTRACT HERE>*

Tips:

<INSERT TIPS ABOUT MODEL HERE>

This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).


## OpenAIMoeConfig

[[autodoc]] OpenAIMoeConfig

## OpenAIMoeModel

[[autodoc]] OpenAIMoeModel
- forward

## OpenAIMoeForCausalLM

[[autodoc]] OpenAIMoeForCausalLM
- forward
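
The new model doc above is still the template; as a quick orientation, a minimal generation sketch could look like the following (the checkpoint id and the `torch_dtype`/`device_map` settings are assumptions — the classes are renamed to `gpt_oss` later in this PR and the final hub ids are not part of this diff):

```python
# Minimal sketch, not from the PR: checkpoint id and settings are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "openai/gpt-oss-20b"  # hypothetical hub id for the converted weights
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

messages = [{"role": "user", "content": "Explain mixture-of-experts routing in one sentence."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)

outputs = model.generate(inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))
```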
4 changes: 2 additions & 2 deletions setup.py
@@ -128,7 +128,7 @@
# Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
"keras>2.9,<2.16",
"keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
"kernels>=0.6.1,<0.7",
"kernels>=0.6.1,<=0.9",
"librosa",
"natten>=0.14.6,<0.15.0",
"nltk<=3.8.1",
@@ -137,7 +137,7 @@
"onnxconverter-common",
"onnxruntime-tools>=1.4.2",
"onnxruntime>=1.4.0",
"openai",
"openai>=1.98.0",
"opencv-python",
"optimum-benchmark>=0.3.0",
"optuna",
1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -277,6 +277,7 @@
"GPTQConfig",
"HiggsConfig",
"HqqConfig",
"Mxfp4Config",
"QuantoConfig",
"QuarkConfig",
"FPQuantConfig",
72 changes: 69 additions & 3 deletions src/transformers/commands/serving.py
@@ -909,7 +909,16 @@ def generate_chat_completion(self, req: dict) -> Generator[str, None, None]:
inputs = inputs.to(model.device)
request_id = req.get("request_id", "req_0")

generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
# Temporary hack for GPTOSS 1: don't filter special tokens
skip_special_tokens = True
if "gptoss" in model.config.architectures[0].lower():
skip_special_tokens = False

generation_streamer = TextIteratorStreamer(
processor,
skip_special_tokens=skip_special_tokens,
skip_prompt=True,
)
generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)

last_kv_cache = None
@@ -925,12 +934,21 @@
}

def stream_chat_completion(streamer, _request_id):
# Temporary hack for GPTOSS 2: filter out the CoT tokens. Full solution here implies defining new output
# classes and piping the reasoning trace into a new field
filter_cot = False
cot_trace_end = None
if "gptoss" in model.config.architectures[0].lower():
filter_cot = True
cot_trace_end = "<|channel|>final<|message|>"

# Thin wrapper to save the KV cache after generation
def generate_with_cache(**kwargs):
generate_output = model.generate(**kwargs)
self.last_kv_cache = generate_output.past_key_values

thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
results = ""

try:
thread.start()
@@ -941,6 +959,20 @@
yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision)

for result in streamer:
# Temporary hack for GPTOSS 3: don't emit the final "<|return|>"
if "gptoss" in model.config.architectures[0].lower():
if result.endswith("<|return|>"):
result = result[: -len("<|return|>")]
results += result

# (related to temporary hack 2)
if filter_cot:
if cot_trace_end in results: # end of reasoning trace observed -> stop filtering
filter_cot = False
continue
else:
continue

# ====== TOOL CALL LOGIC ======
if tool_model_family is not None:
# Start of a tool call: reset state variables, set `inside_tool_call`
@@ -1064,7 +1096,16 @@ def generate_response(self, req: dict) -> Generator[str, None, None]:
inputs = inputs.to(model.device)
request_id = req.get("previous_response_id", "req_0")

generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
# Temporary hack for GPTOSS 1: don't filter special tokens
skip_special_tokens = True
if "gptoss" in model.config.architectures[0].lower():
skip_special_tokens = False

generation_streamer = TextIteratorStreamer(
processor,
skip_special_tokens=skip_special_tokens,
skip_prompt=True,
)
generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)

last_kv_cache = None
@@ -1081,6 +1122,14 @@ def generate_response(self, req: dict) -> Generator[str, None, None]:
}

def stream_response(streamer, _request_id):
# Temporary hack for GPTOSS 2: filter out the CoT tokens. Full solution here implies defining new output
# classes and piping the reasoning trace into a new field
filter_cot = False
cot_trace_end = None
if "gptoss" in model.config.architectures[0].lower():
filter_cot = True
cot_trace_end = "<|channel|>final<|message|>"

# Thin wrapper to save the KV cache after generation
def generate_with_cache(**kwargs):
generate_output = model.generate(**kwargs)
@@ -1167,14 +1216,29 @@ def generate_with_cache(**kwargs):
# Stream the actual generated text
results = ""
for result in streamer:
# Temporary hack for GPTOSS 3: don't emit the final "<|return|>"
if "gptoss" in model.config.architectures[0].lower():
if result.endswith("<|return|>"):
result = result[: -len("<|return|>")]
results += result

# (related to temporary hack 2)
if filter_cot:
if cot_trace_end in results: # end of reasoning trace observed -> stop filtering
filter_cot = False
results = "" # reset the results -> results will now track the final response
continue
else:
continue

response_output_text_delta = ResponseTextDeltaEvent(
type="response.output_text.delta",
item_id=f"msg_{request_id}",
sequence_number=sequence_number,
output_index=output_index,
content_index=content_index,
delta=result,
logprobs=[{"token": "", "logprob": 99.9}], # TODO: add actual logprobs
)
sequence_number += 1
yield self.build_response_event(response_output_text_delta)
Expand All @@ -1187,6 +1251,7 @@ def generate_with_cache(**kwargs):
output_index=output_index,
content_index=0,
text=results,
logprobs=[{"token": "", "logprob": 99.9}], # TODO: add actual logprobs
)
sequence_number += 1
yield self.build_response_event(response_output_text_done)
@@ -1446,9 +1511,10 @@ def _load_model_and_data_processor(self, model_id_and_revision: str):
"attn_implementation": args.attn_implementation,
"torch_dtype": torch_dtype,
"device_map": "auto",
"quantization_config": quantization_config,
"trust_remote_code": args.trust_remote_code,
}
if quantization_config is not None:
model_kwargs["quantization_config"] = quantization_config

config = AutoConfig.from_pretrained(model_id, **model_kwargs)
architecture = getattr(transformers, config.architectures[0])
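
Taken together, the serving changes above apply three GPT-OSS-specific hacks to the token stream: keep special tokens in the streamer, strip a trailing `<|return|>`, and suppress output until the `<|channel|>final<|message|>` marker ends the reasoning trace. A self-contained sketch of that filtering follows; the helper name and the toy stream are illustrative, not part of the PR:

```python
# Illustrative sketch of the CoT filtering used in stream_chat_completion / stream_response above.
COT_END = "<|channel|>final<|message|>"
RETURN_TOKEN = "<|return|>"

def filter_final_channel(stream):
    """Yield only the final-channel text, dropping the reasoning trace and the closing token."""
    results = ""
    filter_cot = True
    for result in stream:
        if result.endswith(RETURN_TOKEN):   # hack 3: never emit the final "<|return|>"
            result = result[: -len(RETURN_TOKEN)]
        results += result
        if filter_cot:
            if COT_END in results:          # end of reasoning trace observed -> stop filtering
                filter_cot = False
            continue                        # the marker chunk itself is not emitted either
        yield result

toy_stream = ["<|channel|>analysis<|message|>let me think...", "<|channel|>final<|message|>", "Hello", " world", "<|return|>"]
print("".join(filter_final_channel(toy_stream)))  # -> "Hello world"
```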
4 changes: 2 additions & 2 deletions src/transformers/dependency_versions_table.py
@@ -34,7 +34,7 @@
"kenlm": "kenlm",
"keras": "keras>2.9,<2.16",
"keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
"kernels": "kernels>=0.6.1,<0.7",
"kernels": "kernels>=0.6.1,<=0.9",
"librosa": "librosa",
"natten": "natten>=0.14.6,<0.15.0",
"nltk": "nltk<=3.8.1",
@@ -43,7 +43,7 @@
"onnxconverter-common": "onnxconverter-common",
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
"onnxruntime": "onnxruntime>=1.4.0",
"openai": "openai",
"openai": "openai>=1.98.0",
"opencv-python": "opencv-python",
"optimum-benchmark": "optimum-benchmark>=0.3.0",
"optuna": "optuna",
37 changes: 23 additions & 14 deletions src/transformers/generation/continuous_batching.py
@@ -182,27 +182,29 @@ def __init__(
f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
)
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
self.num_key_value_heads //= tp_size
# self.num_key_value_heads //= tp_size

self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
self.num_hidden_layers = config.num_hidden_layers

# Calculate optimal block size and number if not provided
num_blocks = getattr(generation_config, "num_blocks", None)
num_blocks = getattr(generation_config, "num_blocks", 1024)
block_size = getattr(generation_config, "block_size", 32)
max_memory_percent = getattr(generation_config, "max_memory", 0.9)
num_blocks, max_batch_tokens = compute_optimal_blocks(
generation_config.max_new_tokens,
block_size=block_size,
head_dim=self.head_dim,
num_layers=self.num_hidden_layers,
num_heads=self.num_key_value_heads,
max_memory_percent=max_memory_percent,
dtype=dtype,
num_blocks=num_blocks,
)
max_batch_tokens = getattr(generation_config, "max_batch_tokens", 256)
if num_blocks is None or max_batch_tokens is None:
num_blocks, max_batch_tokens = compute_optimal_blocks(
generation_config.max_new_tokens,
block_size=block_size,
head_dim=self.head_dim,
num_layers=self.num_hidden_layers,
num_heads=self.num_key_value_heads,
max_memory_percent=max_memory_percent,
dtype=dtype,
num_blocks=num_blocks,
)
logger.warning(
f"Using calculated num_blocks={num_blocks}, block_size={block_size}, max concurrent requests {max_batch_tokens}"
)
@@ -960,7 +962,14 @@ def _build_tensors(

@traced
def _sync(self):
return self.output_ids.tolist()[0] # should be the only synch we do
if self.output_ids is not None:
try:
out = self.output_ids.tolist()[0] # should be the only synch we do
except Exception:
out = [0, 1]
else:
out = [0, 0]
return out

@traced
def _maybe_send_output(self, state: RequestState, token: int):
@@ -1250,7 +1259,7 @@ def _run_generation_loop(self):
self.model.device,
self.model.dtype,
num_requests=len(self.input_queue.queue),
tp_size=getattr(self.model, "tp_size"),
tp_size=getattr(self.model, "_tp_size", 8),  # TODO: quantized converted checkpoints don't set this
)

scheduler = None
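
The cache-sizing change above now reads `num_blocks`, `block_size`, `max_batch_tokens` and `max_memory` from the generation config and only falls back to `compute_optimal_blocks` when one of them resolves to `None`. A sketch of how a caller could pin these values explicitly — the attribute names mirror the `getattr` lookups above, and the numbers are arbitrary examples:

```python
# Sketch: pinning the paged KV-cache geometry through GenerationConfig attributes.
from transformers import GenerationConfig

generation_config = GenerationConfig(max_new_tokens=128)
generation_config.num_blocks = 1024       # number of KV-cache blocks
generation_config.block_size = 32         # tokens per block
generation_config.max_batch_tokens = 256  # max tokens scheduled per step
generation_config.max_memory = 0.9        # fraction of memory the cache may use
```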
15 changes: 15 additions & 0 deletions src/transformers/integrations/__init__.py
@@ -119,6 +119,14 @@
"run_hp_search_sigopt",
"run_hp_search_wandb",
],
"mxfp4": [
"Mxfp4GptOssExperts",
"convert_moe_packed_tensors",
"dequantize",
"load_and_swizzle_mxfp4",
"quantize_to_mxfp4",
"replace_with_mxfp4_linear",
],
"peft": ["PeftAdapterMixin"],
"quanto": ["replace_with_quanto_layers"],
"spqr": ["replace_with_spqr_linear"],
@@ -255,6 +263,13 @@
run_hp_search_sigopt,
run_hp_search_wandb,
)
from .mxfp4 import (
Mxfp4GptOssExperts,
dequantize,
load_and_swizzle_mxfp4,
quantize_to_mxfp4,
replace_with_mxfp4_linear,
)
from .peft import PeftAdapterMixin
from .quanto import replace_with_quanto_layers
from .spqr import replace_with_spqr_linear
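
With the `mxfp4` helpers exported above (and `Mxfp4Config` added to the top-level `__init__` earlier in this diff), loading the quantized checkpoint could look roughly as follows. The `dequantize` flag is an assumption drawn from the `default_to_dequantize_training` commits, not something shown in this hunk, and the checkpoint id is a placeholder:

```python
# Rough sketch, assuming Mxfp4Config is the public entry point for the MXFP4 path.
from transformers import AutoModelForCausalLM, Mxfp4Config

model_id = "openai/gpt-oss-20b"  # hypothetical checkpoint id

# Keep the native MXFP4 weights (needs the triton kernels and a supported GPU).
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

# Or dequantize to a higher-precision dtype, e.g. for fine-tuning (assumed flag).
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=Mxfp4Config(dequantize=True),
    torch_dtype="auto",
    device_map="auto",
)
```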
6 changes: 4 additions & 2 deletions src/transformers/integrations/flash_paged.py
@@ -50,8 +50,10 @@ def paged_attention_forward(
"""
k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs)

sliding_window = (-1, -1) if not getattr(module, "sliding_window", False) else (module.sliding_window, 0)
if implementation is not None:
flash_attn_varlen_func = implementation.flash_attn_varlen_func
custom_kwargs = {"s_aux": kwargs.get("s_aux")}
attn_output = flash_attn_varlen_func(
q.transpose(1, 2).squeeze(0).contiguous(),
k.transpose(1, 2).squeeze(0).contiguous(),
@@ -62,9 +64,9 @@
max_seqlen_k,
softmax_scale=module.scaling,
causal=True, # kind of a must, it automatically aligns the mask for q < k
window_size=(-1, -1), # -1 means infinite context window
window_size=sliding_window, # -1 means infinite context window
# block_table=block_tables, -> torch.Tensor
# **kwargs,
**custom_kwargs,
)
if isinstance(attn_output, tuple):
attn_output = attn_output[0]
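
The paged-attention change derives flash-attention's `window_size` tuple from the module's `sliding_window` attribute, and the new `s_aux` kwarg presumably forwards the attention-sink term to the hub kernel (given the "Clamp sink logits" commits). A small sketch of the window convention used above — `(-1, -1)` means an unbounded window while `(w, 0)` limits attention to the previous `w` tokens:

```python
# Sketch of the window_size convention passed to flash_attn_varlen_func above.
def window_size_for(module) -> tuple[int, int]:
    sliding_window = getattr(module, "sliding_window", None)
    return (sliding_window, 0) if sliding_window else (-1, -1)
```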