diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index f82d6c4a6d03..996029b00b89 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -617,6 +617,8 @@
title: OLMoE
- local: model_doc/open-llama
title: Open-Llama
+ - local: model_doc/openai_moe
+ title: OpenAIMoe
- local: model_doc/opt
title: OPT
- local: model_doc/pegasus
diff --git a/docs/source/en/model_doc/openai_moe.md b/docs/source/en/model_doc/openai_moe.md
new file mode 100644
index 000000000000..2c0b39013dc4
--- /dev/null
+++ b/docs/source/en/model_doc/openai_moe.md
@@ -0,0 +1,58 @@
+
+
+
+
+# OpenAIMoE
+
+## Overview
+
+The OpenAIMoE model was proposed in []() by .
+
+
+The abstract from the paper is the following:
+
+**
+
+Tips:
+
+
+
+This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/).
+The original code can be found [here]().
+
+
+## OpenAIMoeConfig
+
+[[autodoc]] OpenAIMoeConfig
+
+## OpenAIMoeModel
+
+[[autodoc]] OpenAIMoeModel
+ - forward
+
+## OpenAIMoeForCausalLM
+
+[[autodoc]] OpenAIMoeForCausalLM
+ - forward
diff --git a/setup.py b/setup.py
index 2a81f11cbab1..920e2adbbef8 100644
--- a/setup.py
+++ b/setup.py
@@ -128,7 +128,7 @@
# Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
"keras>2.9,<2.16",
"keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
- "kernels>=0.6.1,<0.7",
+ "kernels>=0.6.1,<=0.9",
"librosa",
"natten>=0.14.6,<0.15.0",
"nltk<=3.8.1",
@@ -137,7 +137,7 @@
"onnxconverter-common",
"onnxruntime-tools>=1.4.2",
"onnxruntime>=1.4.0",
- "openai",
+ "openai>=1.98.0",
"opencv-python",
"optimum-benchmark>=0.3.0",
"optuna",
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 40593bc1fdca..f929e4af9eb3 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -277,6 +277,7 @@
"GPTQConfig",
"HiggsConfig",
"HqqConfig",
+ "Mxfp4Config",
"QuantoConfig",
"QuarkConfig",
"FPQuantConfig",
diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py
index dbfd87cdf69a..c3f9bc2838b5 100644
--- a/src/transformers/commands/serving.py
+++ b/src/transformers/commands/serving.py
@@ -909,7 +909,16 @@ def generate_chat_completion(self, req: dict) -> Generator[str, None, None]:
inputs = inputs.to(model.device)
request_id = req.get("request_id", "req_0")
- generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
+ # Temporary hack for GPTOSS 1: don't filter special tokens
+ skip_special_tokens = True
+ if "gptoss" in model.config.architectures[0].lower():
+ skip_special_tokens = False
+
+ generation_streamer = TextIteratorStreamer(
+ processor,
+ skip_special_tokens=skip_special_tokens,
+ skip_prompt=True,
+ )
generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
last_kv_cache = None
@@ -925,12 +934,21 @@ def generate_chat_completion(self, req: dict) -> Generator[str, None, None]:
}
def stream_chat_completion(streamer, _request_id):
+ # Temporary hack for GPTOS 2: filter out the CoT tokens. Full solution here implies defining new output
+ # classes and piping the reasoning trace into a new field
+ filter_cot = False
+ cot_trace_end = None
+ if "gptoss" in model.config.architectures[0].lower():
+ filter_cot = True
+ cot_trace_end = "<|channel|>final<|message|>"
+
# Thin wrapper to save the KV cache after generation
def generate_with_cache(**kwargs):
generate_output = model.generate(**kwargs)
self.last_kv_cache = generate_output.past_key_values
thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
+ results = ""
try:
thread.start()
@@ -941,6 +959,20 @@ def generate_with_cache(**kwargs):
yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision)
for result in streamer:
+ # Temporary hack for GPTOS 3: don't emit the final "<|return|>"
+ if "gptoss" in model.config.architectures[0].lower():
+ if result.endswith("<|return|>"):
+ result = result[: -len("<|return|>")]
+ results += result
+
+ # (related to temporary hack 2)
+ if filter_cot:
+ if cot_trace_end in results: # end of reasoning trace observed -> stop filtering
+ filter_cot = False
+ continue
+ else:
+ continue
+
# ====== TOOL CALL LOGIC ======
if tool_model_family is not None:
# Start of a tool call: reset state variables, set `inside_tool_call`
@@ -1064,7 +1096,16 @@ def generate_response(self, req: dict) -> Generator[str, None, None]:
inputs = inputs.to(model.device)
request_id = req.get("previous_response_id", "req_0")
- generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
+ # Temporary hack for GPTOSS 1: don't filter special tokens
+ skip_special_tokens = True
+ if "gptoss" in model.config.architectures[0].lower():
+ skip_special_tokens = False
+
+ generation_streamer = TextIteratorStreamer(
+ processor,
+ skip_special_tokens=skip_special_tokens,
+ skip_prompt=True,
+ )
generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
last_kv_cache = None
@@ -1081,6 +1122,14 @@ def generate_response(self, req: dict) -> Generator[str, None, None]:
}
def stream_response(streamer, _request_id):
+ # Temporary hack for GPTOS 2: filter out the CoT tokens. Full solution here implies defining new output
+ # classes and piping the reasoning trace into a new field
+ filter_cot = False
+ cot_trace_end = None
+ if "gptoss" in model.config.architectures[0].lower():
+ filter_cot = True
+ cot_trace_end = "<|channel|>final<|message|>"
+
# Thin wrapper to save the KV cache after generation
def generate_with_cache(**kwargs):
generate_output = model.generate(**kwargs)
@@ -1167,7 +1216,21 @@ def generate_with_cache(**kwargs):
# Stream the actual generated text
results = ""
for result in streamer:
+ # Temporary hack for GPTOS 3: don't emit the final "<|return|>"
+ if "gptoss" in model.config.architectures[0].lower():
+ if result.endswith("<|return|>"):
+ result = result[: -len("<|return|>")]
results += result
+
+ # (related to temporary hack 2)
+ if filter_cot:
+ if cot_trace_end in results: # end of reasoning trace observed -> stop filtering
+ filter_cot = False
+ results = "" # reset the results -> results will now track the final response
+ continue
+ else:
+ continue
+
response_output_text_delta = ResponseTextDeltaEvent(
type="response.output_text.delta",
item_id=f"msg_{request_id}",
@@ -1175,6 +1238,7 @@ def generate_with_cache(**kwargs):
output_index=output_index,
content_index=content_index,
delta=result,
+ logprobs=[{"token": "", "logprob": 99.9}], # TODO: add actual logprobs
)
sequence_number += 1
yield self.build_response_event(response_output_text_delta)
@@ -1187,6 +1251,7 @@ def generate_with_cache(**kwargs):
output_index=output_index,
content_index=0,
text=results,
+ logprobs=[{"token": "", "logprob": 99.9}], # TODO: add actual logprobs
)
sequence_number += 1
yield self.build_response_event(response_output_text_done)
@@ -1446,9 +1511,10 @@ def _load_model_and_data_processor(self, model_id_and_revision: str):
"attn_implementation": args.attn_implementation,
"torch_dtype": torch_dtype,
"device_map": "auto",
- "quantization_config": quantization_config,
"trust_remote_code": args.trust_remote_code,
}
+ if quantization_config is not None:
+ model_kwargs["quantization_config"] = quantization_config
config = AutoConfig.from_pretrained(model_id, **model_kwargs)
architecture = getattr(transformers, config.architectures[0])
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index 7f6a723b2f11..758b4c590239 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -34,7 +34,7 @@
"kenlm": "kenlm",
"keras": "keras>2.9,<2.16",
"keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
- "kernels": "kernels>=0.6.1,<0.7",
+ "kernels": "kernels>=0.6.1,<=0.9",
"librosa": "librosa",
"natten": "natten>=0.14.6,<0.15.0",
"nltk": "nltk<=3.8.1",
@@ -43,7 +43,7 @@
"onnxconverter-common": "onnxconverter-common",
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
"onnxruntime": "onnxruntime>=1.4.0",
- "openai": "openai",
+ "openai": "openai>=1.98.0",
"opencv-python": "opencv-python",
"optimum-benchmark": "optimum-benchmark>=0.3.0",
"optuna": "optuna",
diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py
index 9470b3f337c6..57e57c959c26 100644
--- a/src/transformers/generation/continuous_batching.py
+++ b/src/transformers/generation/continuous_batching.py
@@ -182,7 +182,7 @@ def __init__(
f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
)
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
- self.num_key_value_heads //= tp_size
+ # self.num_key_value_heads //= tp_size
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
@@ -190,19 +190,21 @@ def __init__(
self.num_hidden_layers = config.num_hidden_layers
# Calculate optimal block size and number if not provided
- num_blocks = getattr(generation_config, "num_blocks", None)
+ num_blocks = getattr(generation_config, "num_blocks", 1024)
block_size = getattr(generation_config, "block_size", 32)
max_memory_percent = getattr(generation_config, "max_memory", 0.9)
- num_blocks, max_batch_tokens = compute_optimal_blocks(
- generation_config.max_new_tokens,
- block_size=block_size,
- head_dim=self.head_dim,
- num_layers=self.num_hidden_layers,
- num_heads=self.num_key_value_heads,
- max_memory_percent=max_memory_percent,
- dtype=dtype,
- num_blocks=num_blocks,
- )
+ max_batch_tokens = getattr(generation_config, "max_batch_tokens", 256)
+ if num_blocks is None or max_batch_tokens is None:
+ num_blocks, max_batch_tokens = compute_optimal_blocks(
+ generation_config.max_new_tokens,
+ block_size=block_size,
+ head_dim=self.head_dim,
+ num_layers=self.num_hidden_layers,
+ num_heads=self.num_key_value_heads,
+ max_memory_percent=max_memory_percent,
+ dtype=dtype,
+ num_blocks=num_blocks,
+ )
logger.warning(
f"Using calculated num_blocks={num_blocks}, block_size={block_size}, max concurrent requests {max_batch_tokens}"
)
@@ -960,7 +962,14 @@ def _build_tensors(
@traced
def _sync(self):
- return self.output_ids.tolist()[0] # should be the only synch we do
+ if self.output_ids is not None:
+ try:
+ out = self.output_ids.tolist()[0] # should be the only synch we do
+ except Exception:
+ out = [0, 1]
+ else:
+ out = [0, 0]
+ return out
@traced
def _maybe_send_output(self, state: RequestState, token: int):
@@ -1250,7 +1259,7 @@ def _run_generation_loop(self):
self.model.device,
self.model.dtype,
num_requests=len(self.input_queue.queue),
- tp_size=getattr(self.model, "tp_size"),
+ tp_size=getattr(self.model, "_tp_size", 8), # TODO quantized converted don't set this
)
scheduler = None
diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py
index 0c4d169380b5..390db81867fd 100755
--- a/src/transformers/integrations/__init__.py
+++ b/src/transformers/integrations/__init__.py
@@ -119,6 +119,14 @@
"run_hp_search_sigopt",
"run_hp_search_wandb",
],
+ "mxfp4": [
+ "Mxfp4GptOssExperts",
+ "convert_moe_packed_tensors",
+ "dequantize",
+ "load_and_swizzle_mxfp4",
+ "quantize_to_mxfp4",
+ "replace_with_mxfp4_linear",
+ ],
"peft": ["PeftAdapterMixin"],
"quanto": ["replace_with_quanto_layers"],
"spqr": ["replace_with_spqr_linear"],
@@ -255,6 +263,13 @@
run_hp_search_sigopt,
run_hp_search_wandb,
)
+ from .mxfp4 import (
+ Mxfp4GptOssExperts,
+ dequantize,
+ load_and_swizzle_mxfp4,
+ quantize_to_mxfp4,
+ replace_with_mxfp4_linear,
+ )
from .peft import PeftAdapterMixin
from .quanto import replace_with_quanto_layers
from .spqr import replace_with_spqr_linear
diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py
index a7bf5ae57717..096e3fdf9522 100644
--- a/src/transformers/integrations/flash_paged.py
+++ b/src/transformers/integrations/flash_paged.py
@@ -50,8 +50,10 @@ def paged_attention_forward(
"""
k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs)
+ sliding_window = (-1, -1) if not getattr(module, "sliding_window", False) else (module.sliding_window, 0)
if implementation is not None:
flash_attn_varlen_func = implementation.flash_attn_varlen_func
+ custom_kwargs = {"s_aux": kwargs.get("s_aux")}
attn_output = flash_attn_varlen_func(
q.transpose(1, 2).squeeze(0).contiguous(),
k.transpose(1, 2).squeeze(0).contiguous(),
@@ -62,9 +64,9 @@ def paged_attention_forward(
max_seqlen_k,
softmax_scale=module.scaling,
causal=True, # kind of a must, it automatically aligns the mask for q < k
- window_size=(-1, -1), # -1 means infinite context window
+ window_size=sliding_window, # -1 means infinite context window
# block_table=block_tables, -> torch.Tensor
- # **kwargs,
+ **custom_kwargs,
)
if isinstance(attn_output, tuple):
attn_output = attn_output[0]
diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index e310ff8ac5db..31d3b3b14d6d 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -198,8 +198,8 @@ def default_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
mask_mod_maybe_combined = causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
if offsets is not None:
- q_offset = offsets[0]
- kv_offset = offsets[1]
+ q_offset = offsets[0].to(device)
+ kv_offset = offsets[1].to(device)
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
offset_q = q_idx + q_offset
@@ -241,6 +241,7 @@ def flex_attention_forward(
scaling: Optional[float] = None,
softcap: Optional[float] = None,
head_mask: Optional[torch.Tensor] = None,
+ s_aux: Optional[torch.Tensor] = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
if head_mask is not None:
@@ -271,6 +272,12 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
score = score + score_mask[batch_idx][0][q_idx][kv_idx]
if head_mask is not None:
score = score + head_mask[batch_idx][head_idx][0][0]
+ if s_aux is not None:
+ logits_max = torch.max(score, dim=-1, keepdim=True).values
+ sinks = torch.exp(s_aux - logits_max)
+ unnormalized_scores = torch.exp(score - logits_max)
+ normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
+ score = unnormalized_scores / normalizer
return score
enable_gqa = True
diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py
index e824a5ab1f0e..ad5e08d8da4d 100644
--- a/src/transformers/integrations/hub_kernels.py
+++ b/src/transformers/integrations/hub_kernels.py
@@ -18,6 +18,7 @@
from kernels import (
Device,
LayerRepository,
+ Mode,
register_kernel_mapping,
replace_kernel_forward_from_hub,
use_kernel_forward_from_hub,
@@ -44,7 +45,14 @@
repo_id="kernels-community/liger_kernels",
layer_name="LigerRMSNorm",
# revision="pure-layer-test",
- )
+ ),
+ "rocm": {
+ Mode.INFERENCE: LayerRepository(
+ repo_id="kernels-community/liger_kernels",
+ layer_name="LigerRMSNorm",
+ # revision="pure-layer-test",
+ )
+ },
},
"MLP": {
"cuda": LayerRepository(
@@ -53,10 +61,22 @@
)
},
"MegaBlocksMoeMLP": {
- "cuda": LayerRepository(
- repo_id="kernels-community/megablocks",
- layer_name="MegaBlocksMoeMLP",
- )
+ "cuda": {
+ Mode.TRAINING: LayerRepository(
+ repo_id="kernels-community/megablocks",
+ layer_name="MegaBlocksMoeMLP",
+ ),
+ Mode.INFERENCE: LayerRepository(
+ repo_id="kernels-community/megablocks",
+ layer_name="MegaBlocksMoeMLP",
+ ),
+ },
+ "rocm": {
+ Mode.INFERENCE: LayerRepository(
+ repo_id="ahadnagy/megablocks",
+ layer_name="MegaBlocksMoeMLP",
+ )
+ },
},
}
diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py
new file mode 100644
index 000000000000..86517671b5f3
--- /dev/null
+++ b/src/transformers/integrations/mxfp4.py
@@ -0,0 +1,470 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..utils import is_accelerate_available, is_torch_available, logging
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+if is_accelerate_available():
+ from accelerate import init_empty_weights
+
+import re
+
+
+logger = logging.get_logger(__name__)
+
+FP4_VALUES = [
+ +0.0,
+ +0.5,
+ +1.0,
+ +1.5,
+ +2.0,
+ +3.0,
+ +4.0,
+ +6.0,
+ -0.0,
+ -0.5,
+ -1.0,
+ -1.5,
+ -2.0,
+ -3.0,
+ -4.0,
+ -6.0,
+]
+
+
+# Copied from GPT_OSS repo and vllm
+def quantize_to_mxfp4(w):
+ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
+
+ w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
+ w, w_scale = swizzle_mxfp4(w, w_scale)
+ return w, w_scale
+
+
+def swizzle_mxfp4(w, w_scale):
+ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
+ from triton_kernels.tensor_details import layout
+ from triton_kernels.tensor_details.layout import StridedLayout
+
+ value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
+ w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts)
+ # TODO : add that when we are actually sure that it works on B200
+ # if torch.cuda.get_device_capability()[0] == 10:
+ # constraints = {
+ # "is_persistent": True,
+ # "epilogue_subtile": 1,
+ # }
+ # opt_flags.update_opt_flags_constraints(constraints)
+ # # transpose the tensor so that the quantization axis is on dim1
+
+ # TODO: there is still an issue with the scales on hopper
+ # scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=8)
+ # w_scale = convert_layout(wrap_torch_tensor(w_scale), scale_layout, **scale_layout_opts)
+ w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout)
+ return w, w_scale
+
+
+# Copied from GPT_OSS repo
+# TODO: Add absolute link when the repo is public
+def convert_moe_packed_tensors(
+ blocks,
+ scales,
+ *,
+ dtype: torch.dtype = torch.bfloat16,
+ rows_per_chunk: int = 32768 * 1024,
+) -> torch.Tensor:
+ import math
+
+ # Check if blocks and scales are on CPU, and move to GPU if so
+ if not blocks.is_cuda and torch.cuda.is_available():
+ blocks = blocks.cuda()
+ scales = scales.cuda()
+
+ scales = scales.to(torch.int32) - 127
+
+ assert blocks.shape[:-1] == scales.shape, f"{blocks.shape=} does not match {scales.shape=}"
+
+ lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)
+
+ *prefix_shape, G, B = blocks.shape
+ rows_total = math.prod(prefix_shape) * G
+
+ blocks = blocks.reshape(rows_total, B)
+ scales = scales.reshape(rows_total, 1)
+
+ out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)
+
+ for r0 in range(0, rows_total, rows_per_chunk):
+ r1 = min(r0 + rows_per_chunk, rows_total)
+
+ blk = blocks[r0:r1]
+ exp = scales[r0:r1]
+
+ # nibble indices -> int64
+ idx_lo = (blk & 0x0F).to(torch.long)
+ idx_hi = (blk >> 4).to(torch.long)
+
+ sub = out[r0:r1]
+ sub[:, 0::2] = lut[idx_lo]
+ sub[:, 1::2] = lut[idx_hi]
+
+ torch.ldexp(sub, exp, out=sub)
+ del idx_lo, idx_hi, blk, exp, sub
+
+ out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
+
+ # TODO: Delete after making sure this is not necessary! since we go back to cpu in the end in create_quantized_param using .to(target_device)
+ # Move back to CPU if needed
+ # if need_to_move_back:
+ # out = out.cpu()
+ del blocks, scales, lut
+ return out
+
+
+class Mxfp4GptOssExperts(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.num_experts = config.num_local_experts
+ self.intermediate_size = config.intermediate_size
+ self.hidden_size = config.hidden_size
+
+ self.gate_up_proj_blocks = nn.Parameter(
+ torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8),
+ requires_grad=False,
+ )
+ self.gate_up_proj_scales = nn.Parameter(
+ torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8),
+ requires_grad=False,
+ )
+ self.gate_up_proj_bias = nn.Parameter(
+ torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False
+ )
+
+ self.down_proj_blocks = nn.Parameter(
+ torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8),
+ requires_grad=False,
+ )
+ self.down_proj_scales = nn.Parameter(
+ torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8),
+ requires_grad=False,
+ )
+ self.down_proj_bias = nn.Parameter(
+ torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False
+ )
+ self.alpha = 1.702
+
+ self.gate_up_proj_precision_config = None
+ self.down_proj_precision_config = None
+
+ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor:
+ from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs
+ from triton_kernels.swiglu import swiglu_fn
+
+ with torch.cuda.device(hidden_states.device):
+ act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, None), 2)
+
+ intermediate_cache1 = matmul_ogs(
+ hidden_states,
+ self.gate_up_proj,
+ self.gate_up_proj_bias.to(torch.float32),
+ routing_data,
+ gather_indx=gather_idx,
+ precision_config=self.gate_up_proj_precision_config,
+ gammas=None,
+ fused_activation=act,
+ )
+
+ intermediate_cache3 = matmul_ogs(
+ intermediate_cache1,
+ self.down_proj,
+ self.down_proj_bias.to(torch.float32),
+ routing_data,
+ scatter_indx=scatter_idx,
+ precision_config=self.down_proj_precision_config,
+ gammas=routing_data.gate_scal,
+ )
+
+ return intermediate_cache3
+
+
+# Adapted from GPT_OSS repo
+# TODO: Add absolute link when the repo is public
+def routing_torch_dist(
+ logits,
+ n_expts_act,
+):
+ import os
+
+ from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch
+
+ with torch.cuda.device(logits.device):
+ world_size = torch.distributed.get_world_size()
+ rank = int(os.environ.get("LOCAL_RANK", 0))
+ replace_value = -1
+
+ n_tokens = logits.shape[0]
+ n_expts_tot = logits.shape[1]
+
+ n_local_experts = n_expts_tot // world_size
+ local_expert_start = rank * n_local_experts
+ local_expert_end = (rank + 1) * n_local_experts
+
+ n_gates_pad = n_tokens * n_expts_act
+
+ def topk(vals, k):
+ tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k]
+ tk_indx = tk_indx.long()
+ tk_val = torch.take_along_dim(vals, tk_indx, dim=1)
+ return tk_val, tk_indx.int()
+
+ expt_scal, expt_indx = topk(logits, n_expts_act)
+ expt_scal = torch.softmax(expt_scal, dim=-1)
+ expt_indx, sort_indices = torch.sort(expt_indx, dim=1)
+ expt_scal = torch.gather(expt_scal, 1, sort_indices)
+
+ # Flatten and mask for local experts
+ expt_scal = expt_scal.reshape(-1)
+
+ hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1)[local_expert_start:local_expert_end]
+
+ expt_indx = expt_indx.view(-1).to(torch.int32)
+
+ # we use a large value to replace the indices that are not in the local expert range
+ var = 1000
+ expt_indx = torch.where(expt_indx < local_expert_start, var, expt_indx)
+ topk_indx = torch.argsort(expt_indx, stable=True).to(torch.int32)
+ gate_indx = torch.argsort(topk_indx).to(torch.int32)
+ expt_indx = torch.where(expt_indx < local_expert_end, expt_indx, replace_value)
+ expt_indx = torch.where(local_expert_start <= expt_indx, expt_indx, replace_value)
+
+ gate_indx = torch.where(expt_indx == replace_value, replace_value, gate_indx)
+ gate_scal = expt_scal[topk_indx]
+
+ topk_indx = torch.where(gate_indx[topk_indx] == replace_value, replace_value, topk_indx)
+
+ # # Routing metadata for local expert computation
+ gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int())
+ scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int())
+
+ expt_data = compute_expt_data_torch(hist, n_local_experts, n_gates_pad)
+
+ hitted_experts = n_expts_act
+ return RoutingData(gate_scal, hist, n_local_experts, hitted_experts, expt_data), gather_indx, scatter_indx
+
+
+def mlp_forward(self, hidden_states):
+ import torch.distributed as dist
+
+ if dist.is_available() and dist.is_initialized():
+ routing = routing_torch_dist
+ else:
+ from triton_kernels.routing import routing
+
+ routing = routing
+ batch_size = hidden_states.shape[0]
+ hidden_states = hidden_states.reshape(-1, self.router.hidden_dim)
+ router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias)
+ routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k)
+ routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx)
+ routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim)
+ return routed_out, router_logits
+
+
+def should_convert_module(current_key_name, patterns):
+ current_key_name_str = ".".join(current_key_name)
+ if not any(
+ re.match(f"{key}\\.", current_key_name_str) or re.match(f"{key}", current_key_name_str) for key in patterns
+ ):
+ return True
+ return False
+
+
+def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs):
+ from ..integrations.tensor_parallel import shard_and_distribute_module
+
+ model = kwargs.get("model", None)
+ empty_param = kwargs.get("empty_param", None)
+ casting_dtype = kwargs.get("casting_dtype", None)
+ to_contiguous = kwargs.get("to_contiguous", None)
+ rank = kwargs.get("rank", None)
+ device_mesh = kwargs.get("device_mesh", None)
+
+ for proj in ["gate_up_proj", "down_proj"]:
+ if proj in param_name:
+ if device_mesh is not None:
+ param_value = shard_and_distribute_module(
+ model,
+ param_value,
+ empty_param,
+ dq_param_name,
+ casting_dtype,
+ to_contiguous,
+ rank,
+ device_mesh,
+ set_param=False,
+ )
+ blocks_attr = f"{proj}_blocks"
+ scales_attr = f"{proj}_scales"
+ setattr(module, param_name.rsplit(".", 1)[1], param_value)
+ if hasattr(module, blocks_attr) and hasattr(module, scales_attr):
+ dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr))
+ dequantized = dequantized.transpose(1, 2).contiguous().to(target_device)
+ # TODO: this is perhaps necessary since if target_device is cpu, and the param was on gpu
+ if target_device == "cpu" and torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ setattr(module, proj, torch.nn.Parameter(dequantized))
+ delattr(module, blocks_attr)
+ delattr(module, scales_attr)
+
+
+def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, **kwargs):
+ from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig
+
+ from ..integrations.tensor_parallel import shard_and_distribute_module
+
+ model = kwargs.get("model", None)
+ empty_param = kwargs.get("empty_param", None)
+ casting_dtype = kwargs.get("casting_dtype", None)
+ to_contiguous = kwargs.get("to_contiguous", None)
+ rank = kwargs.get("rank", None)
+ device_mesh = kwargs.get("device_mesh", None)
+
+ for proj in ["gate_up_proj", "down_proj"]:
+ if proj in param_name:
+ if device_mesh is not None:
+ shard_and_distribute_module(
+ model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh
+ )
+ else:
+ setattr(module, param_name.rsplit(".", 1)[1], torch.nn.Parameter(param_value, requires_grad=False))
+ blocks_attr = f"{proj}_blocks"
+ scales_attr = f"{proj}_scales"
+ blocks = getattr(module, blocks_attr)
+ scales = getattr(module, scales_attr)
+ # Check if both blocks and scales both not on on meta device
+ if blocks.device.type != "meta" and scales.device.type != "meta":
+ # need it for ep
+ local_experts = blocks.size(0)
+ if proj == "gate_up_proj":
+ blocks = blocks.view(local_experts, module.intermediate_size * 2, -1)
+ else:
+ blocks = blocks.view(local_experts, -1, module.intermediate_size // 2)
+ # TODO: we need to have the weights on cuda, refactor later
+ if getattr(target_device, "type", target_device) == "cpu":
+ target_device = "cuda"
+ # TODO: check why we still do move the tensors despite the context manager
+ blocks = blocks.to(target_device)
+ scales = scales.to(target_device)
+ with torch.cuda.device(target_device):
+ triton_weight_tensor, weight_scale = swizzle_mxfp4(
+ blocks.transpose(-2, -1), scales.transpose(-2, -1)
+ )
+
+ # need to overwrite the shapes for the kernels
+ if proj == "gate_up_proj":
+ triton_weight_tensor.shape = torch.Size(
+ [local_experts, module.hidden_size, module.intermediate_size * 2]
+ )
+ else:
+ triton_weight_tensor.shape = torch.Size(
+ [local_experts, module.intermediate_size, module.hidden_size]
+ )
+
+ # triton_weight_tensor is what needs to be passed in oai kernels. It stores the data, the shapes and any more objects. It is like a subtensor
+ setattr(module, proj, triton_weight_tensor)
+ setattr(
+ module,
+ f"{proj}_precision_config",
+ PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
+ )
+
+ # delete blocks and scales
+ delattr(module, scales_attr)
+ delattr(module, blocks_attr)
+ # setattr(module, blocks_attr, torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False))
+ del blocks
+
+
+def _replace_with_mxfp4_linear(
+ model,
+ modules_to_not_convert=None,
+ current_key_name=None,
+ quantization_config=None,
+ has_been_replaced=False,
+ config=None,
+):
+ if current_key_name is None:
+ current_key_name = []
+
+ for name, module in model.named_children():
+ current_key_name.append(name)
+ if not should_convert_module(current_key_name, modules_to_not_convert):
+ current_key_name.pop(-1)
+ continue
+ if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize:
+ with init_empty_weights():
+ model._modules[name] = Mxfp4GptOssExperts(config)
+ has_been_replaced = True
+ if module.__class__.__name__ == "GptOssMLP" and not quantization_config.dequantize:
+ from types import MethodType
+
+ module.forward = MethodType(mlp_forward, module)
+ if len(list(module.children())) > 0:
+ _, has_been_replaced = _replace_with_mxfp4_linear(
+ module,
+ modules_to_not_convert,
+ current_key_name,
+ quantization_config,
+ has_been_replaced=has_been_replaced,
+ config=config,
+ )
+ current_key_name.pop(-1)
+ return model, has_been_replaced
+
+
+def replace_with_mxfp4_linear(
+ model,
+ modules_to_not_convert=None,
+ current_key_name=None,
+ quantization_config=None,
+ config=None,
+):
+ if quantization_config.dequantize:
+ return model
+
+ modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
+
+ if quantization_config.modules_to_not_convert is not None:
+ modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
+ modules_to_not_convert = list(set(modules_to_not_convert))
+ model, has_been_replaced = _replace_with_mxfp4_linear(
+ model,
+ modules_to_not_convert,
+ current_key_name,
+ quantization_config,
+ config=config,
+ )
+ if not has_been_replaced:
+ logger.warning(
+ "You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model."
+ " Please double check your model architecture, or submit an issue on github if you think this is"
+ " a bug."
+ )
+
+ return model
diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py
index 353cc1d08174..ff429b9dc744 100644
--- a/src/transformers/integrations/tensor_parallel.py
+++ b/src/transformers/integrations/tensor_parallel.py
@@ -657,7 +657,7 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
if hasattr(mod, "bias") and mod.bias is not None:
- mod._bias = mod.bias
+ mod._bias = mod.bias.to_local()
mod.bias = None
input_tensor = inputs[0]
@@ -997,7 +997,7 @@ def add_tensor_parallel_hooks_to_module(
def shard_and_distribute_module(
- model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh
+ model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh, set_param=True
): # TODO: rename to shard_and_distribute_param
r"""
This function is called in `from_pretrained` when loading a model's checkpoints.
diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py
index 6a73d0dd5def..901572917561 100644
--- a/src/transformers/masking_utils.py
+++ b/src/transformers/masking_utils.py
@@ -48,7 +48,7 @@ def and_masks(*mask_functions: list[Callable]) -> Callable:
def and_mask(batch_idx, head_idx, q_idx, kv_idx):
result = q_idx.new_ones((), dtype=torch.bool)
for mask in mask_functions:
- result = result & mask(batch_idx, head_idx, q_idx, kv_idx)
+ result = result & mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
return result
return and_mask
@@ -62,7 +62,7 @@ def or_masks(*mask_functions: list[Callable]) -> Callable:
def or_mask(batch_idx, head_idx, q_idx, kv_idx):
result = q_idx.new_zeros((), dtype=torch.bool)
for mask in mask_functions:
- result = result | mask(batch_idx, head_idx, q_idx, kv_idx)
+ result = result | mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
return result
return or_mask
diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index 47744eaca3f2..bfab34703971 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -389,7 +389,8 @@ def _flash_attention_forward(
flash_kwargs["deterministic"] = det
if softcap is not None:
flash_kwargs["softcap"] = softcap
-
+ if "s_aux" in kwargs:
+ flash_kwargs["s_aux"] = kwargs.get("s_aux")
query_states, key_states, value_states = fa_peft_integration_check(
query_states, key_states, value_states, target_dtype
)
diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py
index bdb1dd64ce9a..04b2f5ccccfe 100644
--- a/src/transformers/modeling_rope_utils.py
+++ b/src/transformers/modeling_rope_utils.py
@@ -252,10 +252,13 @@ def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
"""Inverse dimension formula to find the dimension based on the number of rotations"""
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
- def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
+ def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate):
"""Find dimension range bounds based on rotations"""
- low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
- high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
+ low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
+ high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
+ if truncate:
+ low = low = math.floor(low)
+ high = math.ceil(high)
return max(low, 0), min(high, dim - 1)
def linear_ramp_factor(min, max, dim):
@@ -272,7 +275,8 @@ def linear_ramp_factor(min, max, dim):
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
- low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
+ truncate = config.rope_scaling.get("truncate", True)
+ low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate)
# Get n-dimensional rotational scaling corrected for extrapolation
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
@@ -465,6 +469,7 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se
"original_max_position_embeddings",
"mscale",
"mscale_all_dim",
+ "truncate",
}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 03e9cf531470..0eab1cbab9d8 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -51,6 +51,7 @@
from torchao.quantization import Int4WeightOnlyConfig
from .configuration_utils import PretrainedConfig
+from .distributed import DistributedConfig
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
@@ -709,6 +710,7 @@ def _infer_parameter_dtype(
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.QUARK,
+ QuantizationMethod.MXFP4,
}:
return True, None
else:
@@ -778,9 +780,8 @@ def _load_state_dict_into_meta_model(
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
for param_name, empty_param in state_dict.items():
- if param_name not in expected_keys:
+ if param_name not in expected_keys: # when loading from ckpt, we skip param if doesnt exist in modeling
continue
-
# we need to use serialized_param_name as file pointer is untouched
if is_meta_state_dict:
# This is the name of the parameter as it appears on disk file
@@ -788,7 +789,6 @@ def _load_state_dict_into_meta_model(
param = file_pointer.get_slice(serialized_param_name)
else:
param = empty_param.to(tensor_device) # It is actually not empty!
-
to_contiguous, casting_dtype = _infer_parameter_dtype(
model,
param_name,
@@ -797,17 +797,47 @@ def _load_state_dict_into_meta_model(
hf_quantizer,
)
- if device_mesh is not None: # In this case, the param is already on the correct device!
- shard_and_distribute_module(
- model,
- param,
- empty_param,
- param_name,
- casting_dtype,
- to_contiguous,
- device_mesh.get_local_rank(),
- device_mesh,
- )
+ if device_mesh is not None:
+ if (
+ not is_quantized
+ or (not hf_quantizer.requires_parameters_quantization)
+ or (
+ not hf_quantizer.check_quantized_param(
+ model,
+ param,
+ param_name,
+ state_dict,
+ device_map=device_map,
+ )
+ )
+ ): # In this case, the param is already on the correct device!
+ shard_and_distribute_module(
+ model,
+ param,
+ empty_param,
+ param_name,
+ casting_dtype,
+ to_contiguous,
+ device_mesh.get_local_rank(),
+ device_mesh,
+ )
+ else: # we have a device mesh but the param needs to be quantized, so we shard inside create_quantized_param:
+ sharding_kwargs = {
+ "empty_param": empty_param,
+ "casting_dtype": casting_dtype,
+ "to_contiguous": to_contiguous,
+ "rank": device_mesh.get_local_rank(),
+ "device_mesh": device_mesh,
+ }
+ hf_quantizer.create_quantized_param(
+ model,
+ param,
+ param_name,
+ device_mesh.get_local_rank(),
+ state_dict,
+ unexpected_keys,
+ **sharding_kwargs,
+ )
else:
param = param[...]
if casting_dtype is not None:
@@ -852,17 +882,24 @@ def _load_state_dict_into_meta_model(
hf_quantizer.create_quantized_param(
model, param, param_name, param_device, state_dict, unexpected_keys
)
+
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
# and then cast it to CPU to avoid excessive memory usage on each GPU
# in comparison to the sharded model across GPUs.
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
+ param_name = hf_quantizer.update_param_name(param_name)
module, param_type = get_module_from_name(model, param_name)
value = getattr(module, param_type)
+ # special case for GptOssForCausalLM, we wait for the param to be leave the meta device before casting it to cpu
+ if model.__class__.__name__ == "GptOssForCausalLM" and value.device.type == "meta":
+ continue
param_to = "cpu"
if is_fsdp_enabled() and not is_local_dist_rank_0():
param_to = "meta"
val_kwargs = {}
- if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
+ if (hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params") or (
+ value.dtype == torch.uint8 or value.dtype == torch.int8
+ ):
val_kwargs["requires_grad"] = False
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
setattr(module, param_type, value)
@@ -4578,7 +4615,6 @@ def from_pretrained(
load_in_8bit = kwargs.pop("load_in_8bit", False)
load_in_4bit = kwargs.pop("load_in_4bit", False)
quantization_config = kwargs.pop("quantization_config", None)
- distributed_config = kwargs.pop("distributed_config", None)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)
@@ -4588,6 +4624,7 @@ def from_pretrained(
gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None)
tp_size = kwargs.pop("tp_size", None)
+ distributed_config: DistributedConfig = kwargs.pop("distributed_config", None)
device_mesh = kwargs.pop("device_mesh", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
use_kernels = kwargs.pop("use_kernels", False)
@@ -4851,10 +4888,9 @@ def from_pretrained(
config = hf_quantizer.update_tp_plan(config)
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
- if hasattr(hf_quantizer.quantization_config.quant_method, "value"):
- user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
- else:
- user_agent["quant"] = hf_quantizer.quantization_config.quant_method
+ if not getattr(hf_quantizer.quantization_config, "dequantize", False):
+ quant_method = hf_quantizer.quantization_config.quant_method
+ user_agent["quant"] = getattr(quant_method, "value", quant_method)
if gguf_file is not None and hf_quantizer is not None:
raise ValueError(
@@ -4949,9 +4985,6 @@ def from_pretrained(
# Let's make sure we don't run the init function of buffer modules
model = cls(config, *model_args, **model_kwargs)
- if _torch_distributed_available and device_mesh is not None:
- model = distribute_model(model, distributed_config, device_mesh, tp_size)
-
# Make sure to tie the weights correctly
model.tie_weights()
@@ -4980,7 +5013,11 @@ def from_pretrained(
if hf_quantizer is not None:
hf_quantizer.preprocess_model(
- model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules, config=config
+ model=model,
+ device_map=device_map,
+ keep_in_fp32_modules=model._keep_in_fp32_modules,
+ config=config,
+ use_kernels=use_kernels,
)
# We store the original dtype for quantized models as we cannot easily retrieve it
# once the weights have been quantized
@@ -4997,6 +5034,9 @@ def _assign_original_dtype(module):
config._pre_quantization_dtype = original_dtype
_assign_original_dtype(model)
+ if _torch_distributed_available and device_mesh is not None:
+ model = distribute_model(model, distributed_config, device_mesh, tp_size)
+
# Prepare the full device map
if device_map is not None:
device_map = _get_device_map(model, device_map, max_memory, hf_quantizer, torch_dtype, keep_in_fp32_regex)
@@ -5043,6 +5083,11 @@ def _assign_original_dtype(module):
# check if using kernels
if use_kernels:
+ if not is_kernels_available():
+ raise ValueError(
+ "Kernels are not available. To use kernels, please install kernels using `pip install kernels`"
+ )
+
from kernels import Device, kernelize
kernelize(model, device=Device(type=model.device.type))
@@ -5116,8 +5161,8 @@ def _assign_original_dtype(module):
dispatch_model(model, **device_map_kwargs)
if hf_quantizer is not None:
- hf_quantizer.postprocess_model(model, config=config)
model.hf_quantizer = hf_quantizer
+ hf_quantizer.postprocess_model(model, config=config)
if _adapter_model_path is not None:
adapter_kwargs["key_mapping"] = key_mapping
@@ -6029,7 +6074,16 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
if param_name in tied_param_names:
continue
- param = model.get_parameter_or_buffer(param_name)
+ # For example in the case of MXFP4 quantization, we need to update the param name to the original param name
+ # because the checkpoint contains blocks, and scales, but since we are dequantizing, we need to use the original param name
+ if hf_quantizer is not None:
+ param_name = hf_quantizer.update_param_name(param_name)
+
+ try:
+ param = model.get_parameter_or_buffer(param_name)
+ except AttributeError:
+ raise AttributeError(f"Parameter {param_name} not found in model")
+
# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
param_byte_count = param.numel() * param.element_size()
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 56c2f3fcdcf7..704c4950895a 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -140,6 +140,7 @@
from .gpt_neo import *
from .gpt_neox import *
from .gpt_neox_japanese import *
+ from .gpt_oss import *
from .gpt_sw3 import *
from .gptj import *
from .granite import *
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 129c5ea300b0..15d10c756618 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -172,6 +172,7 @@
("gpt_neo", "GPTNeoConfig"),
("gpt_neox", "GPTNeoXConfig"),
("gpt_neox_japanese", "GPTNeoXJapaneseConfig"),
+ ("gpt_oss", "GptOssConfig"),
("gptj", "GPTJConfig"),
("gptsan-japanese", "GPTSanJapaneseConfig"),
("granite", "GraniteConfig"),
@@ -577,6 +578,7 @@
("gpt_neo", "GPT Neo"),
("gpt_neox", "GPT NeoX"),
("gpt_neox_japanese", "GPT NeoX Japanese"),
+ ("gpt_oss", "GptOss"),
("gptj", "GPT-J"),
("gptsan-japanese", "GPTSAN-japanese"),
("granite", "Granite"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 84672e80a75d..5554de103cbb 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -174,6 +174,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("gpt_neo", "GPTNeoModel"),
("gpt_neox", "GPTNeoXModel"),
("gpt_neox_japanese", "GPTNeoXJapaneseModel"),
+ ("gpt_oss", "GptOssModel"),
("gptj", "GPTJModel"),
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
("granite", "GraniteModel"),
@@ -642,6 +643,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("gpt_neo", "GPTNeoForCausalLM"),
("gpt_neox", "GPTNeoXForCausalLM"),
("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
+ ("gpt_oss", "GptOssForCausalLM"),
("gptj", "GPTJForCausalLM"),
("granite", "GraniteForCausalLM"),
("granitemoe", "GraniteMoeForCausalLM"),
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index b623010e7b1c..232221782f78 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -300,6 +300,7 @@
("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("gpt_neox", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
("gpt_neox_japanese", ("GPTNeoXJapaneseTokenizer", None)),
+ ("gpt_oss", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)),
("granite", ("GPT2Tokenizer", None)),
diff --git a/src/transformers/models/gpt_oss/__init__.py b/src/transformers/models/gpt_oss/__init__.py
new file mode 100644
index 000000000000..19e12e75ef8f
--- /dev/null
+++ b/src/transformers/models/gpt_oss/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_gpt_oss import *
+ from .modeling_gpt_oss import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/gpt_oss/configuration_gpt_oss.py b/src/transformers/models/gpt_oss/configuration_gpt_oss.py
new file mode 100644
index 000000000000..0a120e7ec970
--- /dev/null
+++ b/src/transformers/models/gpt_oss/configuration_gpt_oss.py
@@ -0,0 +1,118 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""openai model configuration"""
+
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...modeling_rope_utils import rope_config_validation
+
+
+class GptOssConfig(PretrainedConfig):
+ r"""
+ This will yield a configuration to that of the BERT
+ [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture.
+
+ """
+
+ model_type = "gpt_oss"
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.self_attn.sinks": "local_rowwise",
+ "layers.*.mlp.experts": "gather",
+ "layers.*.mlp.router": "ep_router",
+ "layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
+ "layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm",
+ "layers.*.mlp.experts.down_proj": "grouped_gemm",
+ "layers.*.mlp.experts.down_proj_bias": "grouped_gemm",
+ }
+
+ def __init__(
+ self,
+ num_hidden_layers: int = 36,
+ num_local_experts: int = 128,
+ vocab_size: int = 201088,
+ hidden_size: int = 2880,
+ intermediate_size: int = 2880,
+ head_dim: int = 64,
+ num_attention_heads: int = 64,
+ num_key_value_heads: int = 8,
+ sliding_window: int = 128,
+ rope_theta: float = 150000.0,
+ tie_word_embeddings=False,
+ hidden_act: str = "silu",
+ initializer_range: float = 0.02,
+ max_position_embeddings=131072,
+ rms_norm_eps: float = 1e-5,
+ rope_scaling={"rope_type": "yarn", "factor": 32.0, "beta_fast": 32.0, "beta_slow": 1.0, "truncate": False},
+ attention_dropout: float = 0.0,
+ num_experts_per_tok=4,
+ router_aux_loss_coef: float = 0.9,
+ output_router_logits=False,
+ use_cache=True,
+ layer_types=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_local_experts = num_local_experts
+ self.sliding_window = sliding_window
+ self.num_experts_per_tok = num_experts_per_tok
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_dropout = attention_dropout
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+ self.layer_types = layer_types
+ if self.layer_types is None:
+ self.layer_types = [
+ "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
+ ]
+ layer_type_validation(self.layer_types)
+
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, copy it it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
+ self.attention_bias = True
+ self.max_position_embeddings = max_position_embeddings
+ self.router_aux_loss_coef = router_aux_loss_coef
+ self.output_router_logits = output_router_logits
+ self.use_cache = use_cache
+ super().__init__(
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+__all__ = ["GptOssConfig"]
diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py
new file mode 100644
index 000000000000..34bcba3b2515
--- /dev/null
+++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py
@@ -0,0 +1,820 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import gc
+import json
+import os
+from pathlib import Path
+from typing import Optional
+
+import regex as re
+import tiktoken
+import torch
+from safetensors.torch import load_file as safe_load
+
+from transformers import (
+ GenerationConfig,
+ GptOssConfig,
+ GptOssForCausalLM,
+ PreTrainedTokenizerFast,
+)
+from transformers.convert_slow_tokenizer import TikTokenConverter
+
+
+# fmt: off
+# If a weight needs to be split in two or more keys, use `|` to indicate it. ex:
+# r"layers.(\d+).attention.wqkv.weight": r"layers.\1.self_attn.q|k|v|_proj.weight"
+ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
+ r"norm.weight": r"norm.weight",
+ r"\nnorm.scale": r"\nnorm.weight",
+ r"unembedding.weight": r"lm_head.weight",
+ r"embedding": r"embed_tokens",
+ # special key, wqkv needs to be split afterwards
+ r"block.(\d+).attn.qkv": r"layers.\1.self_attn.qkv_proj",
+ r"block.(\d+).attn.out": r"layers.\1.self_attn.o_proj",
+ r"block.(\d+).attn.sinks": r"layers.\1.self_attn.sinks",
+ r"block.(\d+).attn.norm.scale": r"layers.\1.input_layernorm.weight",
+
+ r"block.(\d+).mlp.mlp1_weight": r"layers.\1.mlp.experts.gate_up_proj",
+ r"block.(\d+).mlp.mlp1_bias": r"layers.\1.mlp.experts.gate_up_proj_bias",
+ r"block.(\d+).mlp.mlp2_weight": r"layers.\1.mlp.experts.down_proj",
+ r"block.(\d+).mlp.mlp2_bias": r"layers.\1.mlp.experts.down_proj_bias",
+ r"block.(\d+).mlp.norm.scale": r"layers.\1.post_attention_layernorm.weight",
+ r"block.(\d+).mlp.gate": r"layers.\1.mlp.router",
+}
+# fmt: on
+
+
+def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None):
+ """
+ This function should be applied only once, on the concatenated keys to efficiently rename using
+ the key mappings.
+ """
+ output_dict = {}
+ if state_dict_keys is not None:
+ old_text = "\n".join(state_dict_keys)
+ new_text = old_text
+ for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
+ if replacement is None:
+ new_text = re.sub(pattern, "", new_text) # an empty line
+ continue
+ new_text = re.sub(pattern, replacement, new_text)
+ output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
+ return output_dict
+
+
+FP4_VALUES = [
+ +0.0,
+ +0.5,
+ +1.0,
+ +1.5,
+ +2.0,
+ +3.0,
+ +4.0,
+ +6.0,
+ -0.0,
+ -0.5,
+ -1.0,
+ -1.5,
+ -2.0,
+ -3.0,
+ -4.0,
+ -6.0,
+]
+
+
+def convert_moe_packed_tensors(
+ blocks,
+ scales,
+ *,
+ dtype: torch.dtype = torch.bfloat16,
+ rows_per_chunk: int = 32768 * 1024,
+) -> torch.Tensor:
+ import math
+
+ scales = scales.to(torch.int32) - 127
+
+ assert blocks.shape[:-1] == scales.shape, f"{blocks.shape=} does not match {scales.shape=}"
+
+ lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)
+
+ *prefix_shape, G, B = blocks.shape
+ rows_total = math.prod(prefix_shape) * G
+
+ blocks = blocks.reshape(rows_total, B)
+ scales = scales.reshape(rows_total, 1)
+
+ out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)
+
+ for r0 in range(0, rows_total, rows_per_chunk):
+ r1 = min(r0 + rows_per_chunk, rows_total)
+
+ blk = blocks[r0:r1]
+ exp = scales[r0:r1]
+
+ # nibble indices -> int64
+ idx_lo = (blk & 0x0F).to(torch.long)
+ idx_hi = (blk >> 4).to(torch.long)
+
+ sub = out[r0:r1]
+ sub[:, 0::2] = lut[idx_lo]
+ sub[:, 1::2] = lut[idx_hi]
+
+ torch.ldexp(sub, exp, out=sub)
+ del idx_lo, idx_hi, blk, exp
+
+ out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
+ # to match for now existing implementation
+ return out.to(torch.float8_e5m2)
+
+
+def write_model(
+ model_path,
+ input_base_path,
+ safe_serialization=True,
+ instruct=False,
+ mxfp4=False,
+):
+ os.makedirs(model_path, exist_ok=True)
+ eos_token_id = 199999 if not instruct else 200002
+ pad_token_id = 199999
+
+ original_config = json.loads((Path(input_base_path) / "config.json").read_text())
+
+ num_local_experts = original_config.pop("num_experts")
+ rope_scaling = {
+ "beta_fast": float(original_config.pop("rope_ntk_beta")),
+ "beta_slow": float(original_config.pop("rope_ntk_alpha")),
+ "factor": float(original_config.pop("rope_scaling_factor")),
+ "rope_type": "yarn",
+ "truncate": False,
+ "original_max_position_embeddings": 4096,
+ }
+
+ config = GptOssConfig(
+ num_local_experts=num_local_experts,
+ rope_scaling=rope_scaling,
+ eos_token_id=eos_token_id,
+ pad_token_id=pad_token_id,
+ **original_config,
+ )
+
+ print(f"Fetching all parameters from the checkpoint at {input_base_path}...")
+ final_ = {}
+ for file in list(os.listdir(input_base_path)):
+ if file.endswith(".safetensors"):
+ final_.update(safe_load(os.path.join(input_base_path, file)))
+
+ print("Converting ..")
+ all_keys = final_.keys()
+ new_keys = convert_old_keys_to_new_keys(all_keys)
+
+ state_dict = {}
+ for key in all_keys:
+ # Post-process the current_parameter.
+ new_key = new_keys.get(key, key)
+ if "lm_head" not in new_key:
+ new_key = "model." + new_key
+ print(f"Processing key: {key} -> {new_key}")
+ if re.search("qkv_proj", new_key):
+ q_len = config.head_dim * config.num_attention_heads
+ k_len = config.head_dim * config.num_key_value_heads
+ q, k, v = (
+ final_[key][:q_len, ...],
+ final_[key][q_len : k_len + q_len, ...],
+ final_[key][k_len + q_len :, ...],
+ )
+ q_key = re.sub(r"qkv_proj", "q_proj", new_key)
+ k_key = re.sub(r"qkv_proj", "k_proj", new_key)
+ v_key = re.sub(r"qkv_proj", "v_proj", new_key)
+ state_dict[q_key] = q.contiguous().to(torch.bfloat16)
+ state_dict[k_key] = k.contiguous().to(torch.bfloat16)
+ state_dict[v_key] = v.contiguous().to(torch.bfloat16)
+ elif re.search("gate_up_proj|down_proj", new_key) and "bias" not in new_key:
+ if not mxfp4:
+ if "scales" in new_key:
+ continue
+ elif "blocks" in new_key:
+ # deal with packed weights
+ blocks = final_[key]
+ scales = final_[key.replace("blocks", "scales")]
+ new_key = new_key.replace(".blocks", "")
+ unpacked_tensors = convert_moe_packed_tensors(blocks, scales, dtype=torch.bfloat16)
+ unpacked_tensors = unpacked_tensors.permute(0, 2, 1).contiguous() # einsum in orignal, I use bmm
+ state_dict[new_key] = unpacked_tensors
+ else:
+ raise (f"Unidentified {key}, please double check the state dict")
+ else:
+ if "scales" in new_key:
+ new_key = new_key.replace(".scales", "_scales")
+ state_dict[new_key] = final_[key].contiguous()
+ elif "blocks" in new_key:
+ new_key = new_key.replace(".blocks", "_blocks")
+ state_dict[new_key] = final_[key].contiguous()
+ else:
+ raise (f"Unidentified {key}, please double check the state dict")
+ else:
+ weight = final_[key]
+ if not re.search("norm", new_key):
+ weight = weight.to(torch.bfloat16) # norms are the only ones in float32
+ state_dict[new_key] = weight
+
+ del final_
+ gc.collect()
+
+ if not mxfp4:
+ print("Loading the checkpoint in a GptOss model for unpacked format")
+ with torch.device("meta"):
+ model = GptOssForCausalLM(config)
+ model.load_state_dict(state_dict, strict=True, assign=True)
+ print("Checkpoint loaded successfully.")
+ del config._name_or_path
+
+ print("Saving the model")
+ model.save_pretrained(model_path, safe_serialization=safe_serialization)
+ del state_dict, model
+
+ else:
+ print("Saving the checkpoint in mxfp4 format")
+ config.quantization_config = {
+ "quant_method": "mxfp4",
+ "modules_to_not_convert": [
+ "model.layers.*.self_attn",
+ "model.layers.*.mlp.router",
+ "model.embed_tokens",
+ "lm_head",
+ ],
+ }
+ # required as we don't save the model with save_pretrained
+ config.architectures = ["GptOssForCausalLM"]
+ config.save_pretrained(model_path)
+ save_sharded_model(state_dict, model_path)
+ del state_dict
+
+ gc.collect()
+ print("Reloading the model to check if it's saved correctly.")
+ GptOssForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto")
+ print("Model reloaded successfully.")
+
+ # generation config
+ if instruct:
+ print("Saving generation config...")
+ generation_config = GenerationConfig(
+ bos_token_id=199998, # <|startoftext|>
+ do_sample=True,
+ eos_token_id=[200002, 199999], # <|return|>, <|endoftext|>
+ pad_token_id=199999, # <|endoftext|>
+ temperature=1.0,
+ top_p=1.0,
+ )
+ generation_config.save_pretrained(model_path)
+
+
+def save_sharded_model(state_dict, model_path):
+ from safetensors.torch import save_file
+
+ max_shard_size = 4800000000 # 4.8 GB
+ os.makedirs(model_path, exist_ok=True)
+ shard_size_counter = 0
+ shard_id = 0
+ shard_state_dict = {}
+ total_sharded_dict = {}
+ safetensors_index = {}
+ safetensors_index["metadata"] = {"total_size": 0}
+ safetensors_index["weight_map"] = {}
+ for key in state_dict.keys():
+ size = state_dict[key].numel() * state_dict[key].element_size()
+ if shard_size_counter + size > max_shard_size:
+ total_sharded_dict[shard_id] = shard_state_dict
+ shard_id += 1
+ shard_size_counter = 0
+ shard_state_dict = {}
+ shard_state_dict[key] = state_dict[key]
+ shard_size_counter += size
+ safetensors_index["metadata"]["total_size"] += size
+ safetensors_index["weight_map"][key] = shard_id
+ total_sharded_dict[shard_id] = shard_state_dict
+ num_shards = len(total_sharded_dict) - 1
+ for shard_id, shard_state_dict in total_sharded_dict.items():
+ save_file(shard_state_dict, os.path.join(model_path, f"model-{shard_id:05d}-of-{num_shards:05d}.safetensors"))
+ create_safetensors_index(safetensors_index, num_shards, model_path)
+
+
+def create_safetensors_index(safetensors_index, num_shards, model_path):
+ for key in safetensors_index["weight_map"].keys():
+ shard_id = safetensors_index["weight_map"][key]
+ safetensors_index["weight_map"][key] = f"model-{shard_id:05d}-of-{num_shards:05d}.safetensors"
+ with open(os.path.join(model_path, "model.safetensors.index.json"), "w") as f:
+ json.dump(safetensors_index, f)
+
+
+# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
+ characters the bpe code barfs on.
+
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+ tables between utf-8 bytes and unicode strings.
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+class GptOssConverter(TikTokenConverter):
+ def extract_vocab_merges_from_model(self, tiktoken_url: str):
+ tokenizer = tiktoken.get_encoding(tiktoken_url)
+ self.pattern = tokenizer._pat_str
+ bpe_ranks = tokenizer._mergeable_ranks
+ byte_encoder = bytes_to_unicode()
+
+ def token_bytes_to_string(b):
+ return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
+
+ merges = []
+ vocab = {}
+ for token, rank in bpe_ranks.items():
+ vocab[token_bytes_to_string(token)] = rank
+ if len(token) == 1:
+ continue
+ local = []
+ for index in range(1, len(token)):
+ piece_l, piece_r = token[:index], token[index:]
+ if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
+ local.append((piece_l, piece_r, rank))
+ local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
+ merges.extend(local)
+ merges = sorted(merges, key=lambda val: val[2], reverse=False)
+ merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
+ return vocab, merges
+
+ def __init__(
+ self,
+ vocab_file,
+ model_max_length: int,
+ chat_template: Optional[str] = None,
+ **kwargs,
+ ):
+ super().__init__(vocab_file, pattern=None)
+
+ # TODO 1st donwload the vocabfile!!!
+ tokenizer = tiktoken.get_encoding(vocab_file)
+ self.additional_special_tokens = {}
+ # Complete list of Harmony special tokens as per o200k_harmony spec
+ special_tokens_map = {
+ "<|startoftext|>": 199998,
+ "<|endoftext|>": 199999,
+ "<|return|>": 200002,
+ "<|constrain|>": 200003,
+ "<|channel|>": 200005,
+ "<|start|>": 200006,
+ "<|end|>": 200007,
+ "<|message|>": 200008,
+ "<|call|>": 200012,
+ "<|endofprompt|>": 200018,
+ }
+
+ # Add the remaining reserved slots while skipping IDs already present above.
+ used_ids = set(special_tokens_map.values())
+ for k in range(199999, 200018):
+ if k in used_ids:
+ continue
+ special_tokens_map.setdefault(f"<|reserved_{k}|>", k)
+
+ # Keep only token strings (sorted by ID) for TikTokenConverter.
+ self.additional_special_tokens = [tok for tok, _ in sorted(special_tokens_map.items(), key=lambda x: x[1])]
+ tokenizer = self.converted()
+ if chat_template is not None:
+ kwargs["chat_template"] = chat_template
+ self.tokenizer = PreTrainedTokenizerFast(
+ tokenizer_object=tokenizer,
+ bos_token="<|startoftext|>",
+ eos_token="<|return|>" if chat_template else "<|endoftext|>",
+ pad_token="<|endoftext|>",
+ model_input_names=["input_ids", "attention_mask"],
+ model_max_length=model_max_length,
+ **kwargs,
+ )
+
+
+def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False):
+ # Updated Harmony chat template
+ chat_template = """{#-
+ In addition to the normal inputs of `messages` and `tools`, this template also accepts the
+ following kwargs:
+ - "builtin_tools": A list, can contain "browser" and/or "python".
+ - "model_identity": A string that optionally describes the model identity.
+ - "reasoning_effort": A string that describes the reasoning effort, defaults to "medium".
+ #}
+
+{#- Tool Definition Rendering ============================================== #}
+{%- macro render_typescript_type(param_spec, required_params, is_nullable=false) -%}
+ {%- if param_spec.type == "array" -%}
+ {%- if param_spec['items'] -%}
+ {%- if param_spec['items']['type'] == "string" -%}
+ {{- "string[]" }}
+ {%- elif param_spec['items']['type'] == "number" -%}
+ {{- "number[]" }}
+ {%- elif param_spec['items']['type'] == "integer" -%}
+ {{- "number[]" }}
+ {%- elif param_spec['items']['type'] == "boolean" -%}
+ {{- "boolean[]" }}
+ {%- else -%}
+ {%- set inner_type = render_typescript_type(param_spec['items'], required_params) -%}
+ {%- if inner_type == "object | object" or inner_type|length > 50 -%}
+ {{- "any[]" }}
+ {%- else -%}
+ {{- inner_type + "[]" }}
+ {%- endif -%}
+ {%- endif -%}
+ {%- if param_spec.nullable -%}
+ {{- " | null" }}
+ {%- endif -%}
+ {%- else -%}
+ {{- "any[]" }}
+ {%- if param_spec.nullable -%}
+ {{- " | null" }}
+ {%- endif -%}
+ {%- endif -%}
+ {%- elif param_spec.type is defined and param_spec.type is iterable and param_spec.type is not string and param_spec.type is not mapping and param_spec.type[0] is defined -%}
+ {#- Handle array of types like ["object", "object"] from Union[dict, list] #}
+ {%- if param_spec.type | length > 1 -%}
+ {{- param_spec.type | join(" | ") }}
+ {%- else -%}
+ {{- param_spec.type[0] }}
+ {%- endif -%}
+ {%- elif param_spec.oneOf -%}
+ {#- Handle oneOf schemas - check for complex unions and fallback to any #}
+ {%- set has_object_variants = false -%}
+ {%- for variant in param_spec.oneOf -%}
+ {%- if variant.type == "object" -%}
+ {%- set has_object_variants = true -%}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- if has_object_variants and param_spec.oneOf|length > 1 -%}
+ {{- "any" }}
+ {%- else -%}
+ {%- for variant in param_spec.oneOf -%}
+ {{- render_typescript_type(variant, required_params) -}}
+ {%- if variant.description %}
+ {{- "// " + variant.description }}
+ {%- endif -%}
+ {%- if variant.default is defined %}
+ {{ "// default: " + variant.default|tojson }}
+ {%- endif -%}
+ {%- if not loop.last %}
+ {{- " | " }}
+ {% endif -%}
+ {%- endfor -%}
+ {%- endif -%}
+ {%- elif param_spec.type == "string" -%}
+ {%- if param_spec.enum -%}
+ {{- '"' + param_spec.enum|join('" | "') + '"' -}}
+ {%- else -%}
+ {{- "string" }}
+ {%- if param_spec.nullable %}
+ {{- " | null" }}
+ {%- endif -%}
+ {%- endif -%}
+ {%- elif param_spec.type == "number" -%}
+ {{- "number" }}
+ {%- elif param_spec.type == "integer" -%}
+ {{- "number" }}
+ {%- elif param_spec.type == "boolean" -%}
+ {{- "boolean" }}
+
+ {%- elif param_spec.type == "object" -%}
+ {%- if param_spec.properties -%}
+ {{- "{\n" }}
+ {%- for prop_name, prop_spec in param_spec.properties.items() -%}
+ {{- prop_name -}}
+ {%- if prop_name not in (param_spec.required or []) -%}
+ {{- "?" }}
+ {%- endif -%}
+ {{- ": " }}
+ {{ render_typescript_type(prop_spec, param_spec.required or []) }}
+ {%- if not loop.last -%}
+ {{-", " }}
+ {%- endif -%}
+ {%- endfor -%}
+ {{- "}" }}
+ {%- else -%}
+ {{- "object" }}
+ {%- endif -%}
+ {%- else -%}
+ {{- "any" }}
+ {%- endif -%}
+{%- endmacro -%}
+
+{%- macro render_tool_namespace(namespace_name, tools) -%}
+ {{- "## " + namespace_name + "\n\n" }}
+ {{- "namespace " + namespace_name + " {\n\n" }}
+ {%- for tool in tools %}
+ {%- set tool = tool.function %}
+ {{- "// " + tool.description + "\n" }}
+ {{- "type "+ tool.name + " = " }}
+ {%- if tool.parameters and tool.parameters.properties %}
+ {{- "(_: {\n" }}
+ {%- for param_name, param_spec in tool.parameters.properties.items() %}
+ {%- if param_spec.description %}
+ {{- "// " + param_spec.description + "\n" }}
+ {%- endif %}
+ {{- param_name }}
+ {%- if param_name not in (tool.parameters.required or []) -%}
+ {{- "?" }}
+ {%- endif -%}
+ {{- ": " }}
+ {{- render_typescript_type(param_spec, tool.parameters.required or []) }}
+ {%- if param_spec.default is defined -%}
+ {%- if param_spec.enum %}
+ {{- ", // default: " + param_spec.default }}
+ {%- elif param_spec.oneOf %}
+ {{- "// default: " + param_spec.default }}
+ {%- else %}
+ {{- ", // default: " + param_spec.default|tojson }}
+ {%- endif -%}
+ {%- endif -%}
+ {%- if not loop.last %}
+ {{- ",\n" }}
+ {%- else %}
+ {{- "\n" }}
+ {%- endif -%}
+ {%- endfor %}
+ {{- "}) => any;\n\n" }}
+ {%- else -%}
+ {{- "() => any;\n\n" }}
+ {%- endif -%}
+ {%- endfor %}
+ {{- "} // namespace " + namespace_name }}
+{%- endmacro -%}
+
+{%- macro render_builtin_tools(browser_tool, python_tool) -%}
+ {%- if browser_tool %}
+ {{- "## browser\n\n" }}
+ {{- "// Tool for browsing.\n" }}
+ {{- "// The `cursor` appears in brackets before each browsing display: `[{cursor}]`.\n" }}
+ {{- "// Cite information from the tool using the following format:\n" }}
+ {{- "// `【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.\n" }}
+ {{- "// Do not quote more than 10 words directly from the tool output.\n" }}
+ {{- "// sources=web (default: web)\n" }}
+ {{- "namespace browser {\n\n" }}
+ {{- "// Searches for information related to `query` and displays `topn` results.\n" }}
+ {{- "type search = (_: {\n" }}
+ {{- "query: string,\n" }}
+ {{- "topn?: number, // default: 10\n" }}
+ {{- "source?: string,\n" }}
+ {{- "}) => any;\n\n" }}
+ {{- "// Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.\n" }}
+ {{- "// Valid link ids are displayed with the formatting: `【{id}†.*】`.\n" }}
+ {{- "// If `cursor` is not provided, the most recent page is implied.\n" }}
+ {{- "// If `id` is a string, it is treated as a fully qualified URL associated with `source`.\n" }}
+ {{- "// If `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.\n" }}
+ {{- "// Use this function without `id` to scroll to a new location of an opened page.\n" }}
+ {{- "type open = (_: {\n" }}
+ {{- "id?: number | string, // default: -1\n" }}
+ {{- "cursor?: number, // default: -1\n" }}
+ {{- "loc?: number, // default: -1\n" }}
+ {{- "num_lines?: number, // default: -1\n" }}
+ {{- "view_source?: boolean, // default: false\n" }}
+ {{- "source?: string,\n" }}
+ {{- "}) => any;\n\n" }}
+ {{- "// Finds exact matches of `pattern` in the current page, or the page given by `cursor`.\n" }}
+ {{- "type find = (_: {\n" }}
+ {{- "pattern: string,\n" }}
+ {{- "cursor?: number, // default: -1\n" }}
+ {{- "}) => any;\n\n" }}
+ {{- "} // namespace browser\n\n" }}
+ {%- endif -%}
+
+ {%- if python_tool %}
+ {{- "## python\n\n" }}
+ {{- "Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).\n\n" }}
+ {{- "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is UNKNOWN. Depends on the cluster.\n\n" }}
+ {%- endif -%}
+{%- endmacro -%}
+
+{#- System Message Construction ============================================ #}
+{%- macro build_system_message() -%}
+ {%- if model_identity is not defined %}
+ {%- set model_identity = "You are ChatGPT, a large language model trained by OpenAI." %}
+ {%- endif %}
+ {{- model_identity + "\n" }}
+ {{- "Knowledge cutoff: 2024-06\n" }}
+ {{- "Current date: " + strftime_now("%Y-%m-%d") + "\n\n" }}
+ {%- if reasoning_effort is not defined %}
+ {%- set reasoning_effort = "medium" %}
+ {%- endif %}
+ {{- "Reasoning: " + reasoning_effort + "\n\n" }}
+ {%- if builtin_tools %}
+ {{- "# Tools\n\n" }}
+ {%- set available_builtin_tools = namespace(browser=false, python=false) %}
+ {%- for tool in builtin_tools %}
+ {%- if tool == "browser" %}
+ {%- set available_builtin_tools.browser = true %}
+ {%- elif tool == "python" %}
+ {%- set available_builtin_tools.python = true %}
+ {%- endif %}
+ {%- endfor %}
+ {{- render_builtin_tools(available_builtin_tools.browser, available_builtin_tools.python) }}
+ {%- endif -%}
+ {{- "# Valid channels: analysis, commentary, final. Channel must be included for every message." }}
+ {%- if tools -%}
+ {{- "\nCalls to these tools must go to the commentary channel: 'functions'." }}
+ {%- endif -%}
+{%- endmacro -%}
+
+{#- Main Template Logic ================================================= #}
+{#- Set defaults #}
+
+{#- Render system message #}
+{{- "<|start|>system<|message|>" }}
+{{- build_system_message() }}
+{{- "<|end|>" }}
+
+{#- Extract developer message #}
+{%- if messages[0].role == "developer" or messages[0].role == "system" %}
+ {%- set developer_message = messages[0].content %}
+ {%- set loop_messages = messages[1:] %}
+{%- else %}
+ {%- set developer_message = "" %}
+ {%- set loop_messages = messages %}
+{%- endif %}
+
+{#- Render developer message #}
+{%- if developer_message or tools %}
+ {{- "<|start|>developer<|message|>" }}
+ {%- if developer_message %}
+ {{- "# Instructions\n\n" }}
+ {{- developer_message }}
+ {%- endif %}
+ {%- if tools -%}
+ {{- "\n\n" }}
+ {{- "# Tools\n\n" }}
+ {{- render_tool_namespace("functions", tools) }}
+ {%- endif -%}
+ {{- "<|end|>" }}
+{%- endif %}
+
+{#- Render messages #}
+{%- set last_tool_call = namespace(name=none) %}
+{%- for message in loop_messages -%}
+ {#- At this point only assistant/user/tool messages should remain #}
+ {%- if message.role == 'assistant' -%}
+ {#- Checks to ensure the messages are being passed in the format we expect #}
+ {%- if "content" in message %}
+ {%- if "<|channel|>analysis<|message|>" in message.content or "<|channel|>final<|message|>" in message.content %}
+ {{- raise_exception("You have passed a message containing <|channel|> tags in the content field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.") }}
+ {%- endif %}
+ {%- endif %}
+ {%- if "thinking" in message %}
+ {%- if "<|channel|>analysis<|message|>" in message.thinking or "<|channel|>final<|message|>" in message.thinking %}
+ {{- raise_exception("You have passed a message containing <|channel|> tags in the thinking field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.") }}
+ {%- endif %}
+ {%- endif %}
+ {%- if "tool_calls" in message %}
+ {#- We assume max 1 tool call per message, and so we infer the tool call name #}
+ {#- in "tool" messages from the most recent assistant tool call name #}
+ {%- set tool_call = message.tool_calls[0] %}
+ {%- if tool_call.function %}
+ {%- set tool_call = tool_call.function %}
+ {%- endif %}
+ {%- if message.content and message.thinking %}
+ {{- raise_exception("Cannot pass both content and thinking in an assistant message with tool calls! Put the analysis message in one or the other, but not both.") }}
+ {%- elif message.content %}
+ {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content + "<|end|>" }}
+ {%- elif message.thinking %}
+ {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}
+ {%- endif %}
+ {{- "<|start|>assistant to=" }}
+ {{- "functions." + tool_call.name + "<|channel|>commentary " }}
+ {{- (tool_call.content_type if tool_call.content_type is defined else "json") + "<|message|>" }}
+ {{- tool_call.arguments|tojson }}
+ {{- "<|call|>" }}
+ {%- set last_tool_call.name = tool_call.name %}
+ {%- elif loop.last and not add_generation_prompt %}
+ {#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #}
+ {#- This is a situation that should only occur in training, never in inference. #}
+ {%- if "thinking" in message %}
+ {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}
+ {%- endif %}
+ {#- <|return|> indicates the end of generation, but <|end|> does not #}
+ {#- <|return|> should never be an input to the model, but we include it as the final token #}
+ {#- when training, so the model learns to emit it. #}
+ {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|return|>" }}
+ {%- else %}
+ {#- CoT is dropped during all previous turns, so we never render it for inference #}
+ {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }}
+ {%- set last_tool_call.name = none %}
+ {%- endif %}
+ {%- elif message.role == 'tool' -%}
+ {%- if last_tool_call.name is none %}
+ {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }}
+ {%- endif %}
+ {{- "<|start|>functions." + last_tool_call.name }}
+ {{- " to=assistant<|channel|>commentary<|message|>" + message.content|tojson + "<|end|>" }}
+ {%- elif message.role == 'user' -%}
+ {{- "<|start|>user<|message|>" + message.content + "<|end|>" }}
+ {%- endif -%}
+{%- endfor -%}
+
+{#- Generation prompt #}
+{%- if add_generation_prompt -%}
+<|start|>assistant
+{%- endif -%}"""
+
+ converter = GptOssConverter(
+ vocab_file=tokenizer_path,
+ model_max_length=None,
+ chat_template=chat_template if instruct else None,
+ )
+ tokenizer = converter.tokenizer
+ tokenizer.save_pretrained(save_dir)
+
+ if instruct:
+ print("Saving chat template...")
+ chat_template_path = os.path.join(save_dir, "chat_template.json")
+ with open(chat_template_path, "w") as f:
+ json.dump({"chat_template": chat_template}, f, indent=2)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--input_dir",
+ default="/fsx/mohamed/oai-hf/tests/120b",
+ help="Location of LLaMA weights, which contains tokenizer.model and model folders",
+ )
+ parser.add_argument(
+ "--output_dir",
+ default="/fsx/mohamed/oai-hf/tests/120b_converted_packed",
+ help="Location to write HF model and tokenizer",
+ )
+ parser.add_argument(
+ "--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`."
+ )
+ parser.add_argument(
+ "--special_tokens",
+ default=None,
+ type=list[str],
+ help="The list of special tokens that should be added to the ",
+ )
+
+ parser.add_argument(
+ "--instruct",
+ action="store_true",
+ help="Whether the model is an instruct model",
+ )
+
+ # Only specify this if you want to use the model with mxfp4 quantization
+ # It means the model will be unpacked, and quantized using mxfp4 during inference if all the triton requirements are satisfied (triton >= 3.4.0)
+ # Else we have a fallback to the full precision model (bfloat16)
+ # If not specified, the model will be unpacked during conversion, and will be in fp8/bfloat16 during inference
+ # Note: mxfp4 should bring an important speedup in inference time with blackwell gpus
+ parser.add_argument(
+ "--mxfp4",
+ action="store_true",
+ help="Whether to use the original model with mxfp4 quantization or default to the full precision model.",
+ )
+
+ args = parser.parse_args()
+ write_model(
+ model_path=args.output_dir,
+ input_base_path=args.input_dir,
+ safe_serialization=args.safe_serialization,
+ instruct=args.instruct,
+ mxfp4=args.mxfp4,
+ )
+
+ write_tokenizer(
+ tokenizer_path="o200k_base",
+ save_dir=args.output_dir,
+ instruct=args.instruct,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
new file mode 100644
index 000000000000..2077f7372c9d
--- /dev/null
+++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -0,0 +1,701 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/gpt_oss/modular_gpt_oss.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_gpt_oss.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations.hub_kernels import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import OutputRecorder, check_model_inputs
+from .configuration_gpt_oss import GptOssConfig
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class GptOssRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ GptOssRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return (self.weight * hidden_states).to(input_dtype) # main diff with Llama
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class GptOssExperts(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.intermediate_size = config.intermediate_size
+ self.num_experts = config.num_local_experts
+ self.hidden_size = config.hidden_size
+ self.expert_dim = self.intermediate_size
+ self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
+ self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
+ self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
+ self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
+ self.alpha = 1.702
+ self.limit = 7.0
+
+ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
+ """
+ When training is is more efficient to just loop over the experts and compute the output for each expert
+ as otherwise the memory would explode.
+
+ For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs.
+
+ Args:
+ hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
+ selected_experts (torch.Tensor): (batch_size * token_num, top_k)
+ routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
+ Returns:
+ torch.Tensor
+ """
+ batch_size = hidden_states.shape[0]
+ hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size)
+ num_experts = routing_weights.shape[1]
+ if self.training:
+ next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
+ with torch.no_grad():
+ expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
+ expert_mask = expert_mask.permute(2, 1, 0)
+ # we sum on the top_k and on the sequence lenght to get which experts
+ # are hit this time around
+ expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+ for expert_idx in expert_hitted[:]:
+ with torch.no_grad():
+ _, token_idx = torch.where(expert_mask[expert_idx[0]])
+ current_state = hidden_states[token_idx]
+ gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=self.limit)
+ up = up.clamp(min=-self.limit, max=self.limit)
+ glu = gate * torch.sigmoid(gate * self.alpha)
+ gated_output = (up + 1) * glu
+ out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
+ weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
+ next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+ next_states = next_states.view(batch_size, -1, self.hidden_size)
+ else:
+ hidden_states = hidden_states.repeat(num_experts, 1)
+ hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
+ gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=self.limit)
+ up = up.clamp(min=-self.limit, max=self.limit)
+ glu = gate * torch.sigmoid(gate * self.alpha)
+ next_states = torch.bmm(((up + 1) * glu), self.down_proj)
+ next_states = next_states + self.down_proj_bias[..., None, :]
+ next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
+ next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
+ next_states = next_states.sum(dim=0)
+ return next_states
+
+
+class GptOssTopKRouter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.top_k = config.num_experts_per_tok
+ self.num_experts = config.num_local_experts
+ self.hidden_dim = config.hidden_size
+ self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
+ self.bias = nn.Parameter(torch.empty(self.num_experts))
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+ router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts)
+ router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
+ router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
+ router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
+ return router_scores, router_indices
+
+
+@use_kernel_forward_from_hub("MegaBlocksMoeMLP")
+class GptOssMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.router = GptOssTopKRouter(config)
+ self.experts = GptOssExperts(config)
+
+ def forward(self, hidden_states):
+ router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len)
+ routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
+ return routed_out, router_scores
+
+
+class GptOssRotaryEmbedding(nn.Module):
+ def __init__(self, config: GptOssConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = freqs
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(x.dtype), sin.to(x.dtype)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def _apply_rotary_emb(
+ x: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+) -> torch.Tensor:
+ first_half, second_half = torch.chunk(x, 2, dim=-1)
+ first_ = first_half * cos - second_half * sin
+ second_ = second_half * cos + first_half * sin
+ return torch.cat((first_, second_), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = _apply_rotary_emb(q, cos, sin)
+ k_embed = _apply_rotary_emb(k, cos, sin)
+ return q_embed, k_embed
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
+ combined_logits = torch.cat([attn_weights, sinks], dim=-1)
+
+ # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16
+ # when training with bsz>1 we clamp max values.
+
+ combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
+ probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
+ scores = probs[..., :-1] # we drop the sink here
+ attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ return attn_output, attn_weights
+
+
+class GptOssAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: GptOssConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
+ self.sinks = nn.Parameter(torch.empty(config.num_attention_heads))
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_value: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ s_aux=self.sinks, # diff with Llama
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class GptOssDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: GptOssConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = GptOssAttention(config=config, layer_idx=layer_idx)
+ self.mlp = GptOssMLP(config)
+ self.input_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.attention_type = config.layer_types[layer_idx]
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor]:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class GptOssPreTrainedModel(PreTrainedModel):
+ config: GptOssConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["GptOssDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = False
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "router_logits": OutputRecorder(GptOssTopKRouter, index=0),
+ "hidden_states": GptOssDecoderLayer,
+ "attentions": GptOssAttention,
+ }
+ _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
+ _supports_flash_attention = False
+ _supports_flex_attention = False
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Parameter):
+ module.data.normal_(mean=0.0, std=std)
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, GptOssRMSNorm):
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, GptOssExperts):
+ module.gate_up_proj.data.normal_(mean=0.0, std=std)
+ module.gate_up_proj_bias.data.zero_()
+ module.down_proj.data.normal_(mean=0.0, std=std)
+ module.down_proj_bias.data.zero_()
+ elif isinstance(module, GptOssAttention):
+ module.sinks.data.normal_(mean=0.0, std=std)
+ elif isinstance(module, GptOssTopKRouter):
+ module.weight.data.normal_(mean=0.0, std=std)
+ module.bias.data.normal_(mean=0.0, std=std)
+
+
+@auto_docstring
+class GptOssModel(GptOssPreTrainedModel):
+ _no_split_modules = ["GptOssDecoderLayer"]
+
+ def __init__(self, config: GptOssConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [GptOssDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = GptOssRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[list[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ # It may already have been prepared by e.g. `generate`
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ }
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+ }
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = self.norm(hidden_states)
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+def load_balancing_loss_func(
+ gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
+ num_experts: Optional[int] = None,
+ top_k=2,
+ attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
+ r"""
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+ See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+ experts is too unbalanced.
+
+ Args:
+ gate_logits:
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+ shape [batch_size X sequence_length, num_experts].
+ num_experts:
+ Number of experts
+ top_k:
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
+ parameter.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention_mask used in forward function
+ shape [batch_size X sequence_length] if not None.
+
+ Returns:
+ The auxiliary loss.
+ """
+ if gate_logits is None or not isinstance(gate_logits, tuple):
+ return 0
+
+ if isinstance(gate_logits, tuple):
+ compute_device = gate_logits[0].device
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+ if attention_mask is None:
+ # Compute the percentage of tokens routed to each experts
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
+ else:
+ batch_size, sequence_length = attention_mask.shape
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+ expert_attention_mask = (
+ attention_mask[None, :, :, None, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+ .reshape(-1, top_k, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the percentage of tokens routed to each experts
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+ expert_attention_mask, dim=0
+ )
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+ router_per_expert_attention_mask = (
+ attention_mask[None, :, :, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+ .reshape(-1, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+ router_per_expert_attention_mask, dim=0
+ )
+
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+ return overall_loss * num_experts
+
+
+@auto_docstring
+class GptOssForCausalLM(GptOssPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = GptOssModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.router_aux_loss_coef = config.router_aux_loss_coef
+ self.num_experts = config.num_local_experts
+ self.num_experts_per_tok = config.num_experts_per_tok
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeCausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, GptOssForCausalLM
+
+ >>> model = GptOssForCausalLM.from_pretrained("mistralai/GptOss-8x7B-v0.1")
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/GptOss-8x7B-v0.1")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: MoeModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_router_logits=output_router_logits,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+
+ aux_loss = None
+ if output_router_logits:
+ aux_loss = load_balancing_loss_func(
+ outputs.router_logits,
+ self.num_experts,
+ self.num_experts_per_tok,
+ attention_mask,
+ )
+ if labels is not None:
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
+
+ return MoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=aux_loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ router_logits=outputs.router_logits,
+ )
+
+
+__all__ = ["GptOssForCausalLM", "GptOssModel", "GptOssPreTrainedModel"]
diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py
new file mode 100644
index 000000000000..9b4eb578b73b
--- /dev/null
+++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py
@@ -0,0 +1,447 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ...cache_utils import Cache, DynamicCache
+from ...integrations.hub_kernels import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_outputs import (
+ MoeModelOutputWithPast,
+)
+from ...modeling_rope_utils import dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...utils import (
+ TransformersKwargs,
+ auto_docstring,
+ logging,
+)
+from ...utils.generic import OutputRecorder, check_model_inputs
+from ..llama.modeling_llama import (
+ LlamaDecoderLayer,
+ LlamaPreTrainedModel,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+ repeat_kv,
+)
+from ..mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel
+from ..qwen2.modeling_qwen2 import Qwen2Attention
+from .configuration_gpt_oss import GptOssConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class GptOssRMSNorm(LlamaRMSNorm):
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return (self.weight * hidden_states).to(input_dtype) # main diff with Llama
+
+
+class GptOssExperts(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.intermediate_size = config.intermediate_size
+ self.num_experts = config.num_local_experts
+ self.hidden_size = config.hidden_size
+ self.expert_dim = self.intermediate_size
+ self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
+ self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
+ self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
+ self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
+ self.alpha = 1.702
+ self.limit = 7.0
+
+ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
+ """
+ When training is is more efficient to just loop over the experts and compute the output for each expert
+ as otherwise the memory would explode.
+
+ For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs.
+
+ Args:
+ hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
+ selected_experts (torch.Tensor): (batch_size * token_num, top_k)
+ routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
+ Returns:
+ torch.Tensor
+ """
+ batch_size = hidden_states.shape[0]
+ hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size)
+ num_experts = routing_weights.shape[1]
+ if self.training:
+ next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
+ with torch.no_grad():
+ expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
+ expert_mask = expert_mask.permute(2, 1, 0)
+ # we sum on the top_k and on the sequence lenght to get which experts
+ # are hit this time around
+ expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+ for expert_idx in expert_hitted[:]:
+ with torch.no_grad():
+ _, token_idx = torch.where(expert_mask[expert_idx[0]])
+ current_state = hidden_states[token_idx]
+ gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=self.limit)
+ up = up.clamp(min=-self.limit, max=self.limit)
+ glu = gate * torch.sigmoid(gate * self.alpha)
+ gated_output = (up + 1) * glu
+ out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
+ weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
+ next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+ next_states = next_states.view(batch_size, -1, self.hidden_size)
+ else:
+ hidden_states = hidden_states.repeat(num_experts, 1)
+ hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
+ gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=self.limit)
+ up = up.clamp(min=-self.limit, max=self.limit)
+ glu = gate * torch.sigmoid(gate * self.alpha)
+ next_states = torch.bmm(((up + 1) * glu), self.down_proj)
+ next_states = next_states + self.down_proj_bias[..., None, :]
+ next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
+ next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
+ next_states = next_states.sum(dim=0)
+ return next_states
+
+
+class GptOssTopKRouter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.top_k = config.num_experts_per_tok
+ self.num_experts = config.num_local_experts
+ self.hidden_dim = config.hidden_size
+ self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
+ self.bias = nn.Parameter(torch.empty(self.num_experts))
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+ router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts)
+ router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
+ router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
+ router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
+ return router_scores, router_indices
+
+
+@use_kernel_forward_from_hub("MegaBlocksMoeMLP")
+class GptOssMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.router = GptOssTopKRouter(config)
+ self.experts = GptOssExperts(config)
+
+ def forward(self, hidden_states):
+ router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len)
+ routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
+ return routed_out, router_scores
+
+
+class GptOssRotaryEmbedding(LlamaRotaryEmbedding):
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = freqs
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(x.dtype), sin.to(x.dtype)
+
+
+def _apply_rotary_emb(
+ x: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+) -> torch.Tensor:
+ first_half, second_half = torch.chunk(x, 2, dim=-1)
+ first_ = first_half * cos - second_half * sin
+ second_ = second_half * cos + first_half * sin
+ return torch.cat((first_, second_), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = _apply_rotary_emb(q, cos, sin)
+ k_embed = _apply_rotary_emb(k, cos, sin)
+ return q_embed, k_embed
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
+ combined_logits = torch.cat([attn_weights, sinks], dim=-1)
+
+ # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16
+ # when training with bsz>1 we clamp max values.
+
+ combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
+ probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
+ scores = probs[..., :-1] # we drop the sink here
+ attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ return attn_output, attn_weights
+
+
+class GptOssAttention(Qwen2Attention):
+ def __init__(self, config: GptOssConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+ self.sinks = nn.Parameter(torch.empty(config.num_attention_heads))
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_value: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ s_aux=self.sinks, # diff with Llama
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class GptOssDecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: GptOssConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.hidden_size = config.hidden_size
+ self.self_attn = GptOssAttention(config=config, layer_idx=layer_idx)
+ self.mlp = GptOssMLP(config)
+ self.input_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.attention_type = config.layer_types[layer_idx]
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor]:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+class GptOssPreTrainedModel(LlamaPreTrainedModel):
+ _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
+ _supports_sdpa = False
+ _supports_flash_attention = False
+ _supports_flex_attention = False
+ _can_record_outputs = {
+ "router_logits": OutputRecorder(GptOssTopKRouter, index=0),
+ "hidden_states": GptOssDecoderLayer,
+ "attentions": GptOssAttention,
+ }
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Parameter):
+ module.data.normal_(mean=0.0, std=std)
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, GptOssRMSNorm):
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, GptOssExperts):
+ module.gate_up_proj.data.normal_(mean=0.0, std=std)
+ module.gate_up_proj_bias.data.zero_()
+ module.down_proj.data.normal_(mean=0.0, std=std)
+ module.down_proj_bias.data.zero_()
+ elif isinstance(module, GptOssAttention):
+ module.sinks.data.normal_(mean=0.0, std=std)
+ elif isinstance(module, GptOssTopKRouter):
+ module.weight.data.normal_(mean=0.0, std=std)
+ module.bias.data.normal_(mean=0.0, std=std)
+
+
+class GptOssModel(MixtralModel):
+ _no_split_modules = ["GptOssDecoderLayer"]
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[list[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ # It may already have been prepared by e.g. `generate`
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ }
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+ }
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = self.norm(hidden_states)
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+class GptOssForCausalLM(MixtralForCausalLM):
+ pass
+
+
+__all__ = [
+ "GptOssForCausalLM",
+ "GptOssModel",
+ "GptOssPreTrainedModel",
+]
diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py
index fffe51d794bd..96fc1ca3373c 100644
--- a/src/transformers/models/granitemoe/modeling_granitemoe.py
+++ b/src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -109,8 +109,8 @@ def load_balancing_loss_func(
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
router_per_expert_attention_mask = (
attention_mask[None, :, :, None]
- .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
- .reshape(-1, num_experts)
+ .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1]))
+ .reshape(-1, routing_weights.shape[1])
.to(compute_device)
)
@@ -119,7 +119,10 @@ def load_balancing_loss_func(
router_per_expert_attention_mask, dim=0
)
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+ rank = routing_weights.shape[1] * int(routing_weights.device.index)
+ overall_loss = torch.sum(
+ tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
+ )
return overall_loss * num_experts
diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
index 408cb861a142..c727d40f448b 100644
--- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
@@ -1637,8 +1637,8 @@ def load_balancing_loss_func(
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
router_per_expert_attention_mask = (
attention_mask[None, :, :, None]
- .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
- .reshape(-1, num_experts)
+ .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1]))
+ .reshape(-1, routing_weights.shape[1])
.to(compute_device)
)
@@ -1647,7 +1647,10 @@ def load_balancing_loss_func(
router_per_expert_attention_mask, dim=0
)
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+ rank = routing_weights.shape[1] * int(routing_weights.device.index)
+ overall_loss = torch.sum(
+ tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
+ )
return overall_loss * num_experts
diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
index 21e9d13f7195..f2f5d7d6f0f1 100644
--- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
+++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
@@ -908,8 +908,8 @@ def load_balancing_loss_func(
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
router_per_expert_attention_mask = (
attention_mask[None, :, :, None]
- .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
- .reshape(-1, num_experts)
+ .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1]))
+ .reshape(-1, routing_weights.shape[1])
.to(compute_device)
)
@@ -918,7 +918,10 @@ def load_balancing_loss_func(
router_per_expert_attention_mask, dim=0
)
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+ rank = routing_weights.shape[1] * int(routing_weights.device.index)
+ overall_loss = torch.sum(
+ tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
+ )
return overall_loss * num_experts
diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index e4a376e90af1..191e82e8e852 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -138,8 +138,8 @@ def load_balancing_loss_func(
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
router_per_expert_attention_mask = (
attention_mask[None, :, :, None]
- .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
- .reshape(-1, num_experts)
+ .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1]))
+ .reshape(-1, routing_weights.shape[1])
.to(compute_device)
)
@@ -148,7 +148,10 @@ def load_balancing_loss_func(
router_per_expert_attention_mask, dim=0
)
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+ rank = routing_weights.shape[1] * int(routing_weights.device.index)
+ overall_loss = torch.sum(
+ tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
+ )
return overall_loss * num_experts
diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py
index 5156ac59742e..997885944142 100644
--- a/src/transformers/models/jetmoe/modeling_jetmoe.py
+++ b/src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -119,8 +119,8 @@ def load_balancing_loss_func(
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
router_per_expert_attention_mask = (
attention_mask[None, :, :, None]
- .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
- .reshape(-1, num_experts)
+ .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1]))
+ .reshape(-1, routing_weights.shape[1])
.to(compute_device)
)
@@ -129,7 +129,10 @@ def load_balancing_loss_func(
router_per_expert_attention_mask, dim=0
)
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+ rank = routing_weights.shape[1] * int(routing_weights.device.index)
+ overall_loss = torch.sum(
+ tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
+ )
return overall_loss * num_experts
diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py
index eacb56c064b6..c9540e33af4c 100644
--- a/src/transformers/models/olmoe/modeling_olmoe.py
+++ b/src/transformers/models/olmoe/modeling_olmoe.py
@@ -108,8 +108,8 @@ def load_balancing_loss_func(
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
router_per_expert_attention_mask = (
attention_mask[None, :, :, None]
- .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
- .reshape(-1, num_experts)
+ .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1]))
+ .reshape(-1, routing_weights.shape[1])
.to(compute_device)
)
@@ -118,7 +118,10 @@ def load_balancing_loss_func(
router_per_expert_attention_mask, dim=0
)
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+ rank = routing_weights.shape[1] * int(routing_weights.device.index)
+ overall_loss = torch.sum(
+ tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
+ )
return overall_loss * num_experts
diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py
index 2207793dcaea..a4735a04ac24 100644
--- a/src/transformers/models/phimoe/modeling_phimoe.py
+++ b/src/transformers/models/phimoe/modeling_phimoe.py
@@ -124,8 +124,8 @@ def load_balancing_loss_func(
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
router_per_expert_attention_mask = (
attention_mask[None, :, :, None]
- .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
- .reshape(-1, num_experts)
+ .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1]))
+ .reshape(-1, routing_weights.shape[1])
.to(compute_device)
)
@@ -134,7 +134,10 @@ def load_balancing_loss_func(
router_per_expert_attention_mask, dim=0
)
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+ rank = routing_weights.shape[1] * int(routing_weights.device.index)
+ overall_loss = torch.sum(
+ tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
+ )
return overall_loss * num_experts
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index 108aa9617614..a9cc23b37e22 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -128,8 +128,8 @@ def load_balancing_loss_func(
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
router_per_expert_attention_mask = (
attention_mask[None, :, :, None]
- .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
- .reshape(-1, num_experts)
+ .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1]))
+ .reshape(-1, routing_weights.shape[1])
.to(compute_device)
)
@@ -138,7 +138,10 @@ def load_balancing_loss_func(
router_per_expert_attention_mask, dim=0
)
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+ rank = routing_weights.shape[1] * int(routing_weights.device.index)
+ overall_loss = torch.sum(
+ tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
+ )
return overall_loss * num_experts
diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py
index 161951d3409f..49051f442695 100644
--- a/src/transformers/quantizers/auto.py
+++ b/src/transformers/quantizers/auto.py
@@ -31,6 +31,7 @@
GPTQConfig,
HiggsConfig,
HqqConfig,
+ Mxfp4Config,
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
@@ -54,6 +55,7 @@
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_higgs import HiggsHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
+from .quantizer_mxfp4 import Mxfp4HfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
from .quantizer_quark import QuarkHfQuantizer
from .quantizer_spqr import SpQRHfQuantizer
@@ -81,6 +83,7 @@
"spqr": SpQRHfQuantizer,
"fp8": FineGrainedFP8HfQuantizer,
"auto-round": AutoRoundQuantizer,
+ "mxfp4": Mxfp4HfQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -103,6 +106,7 @@
"spqr": SpQRConfig,
"fp8": FineGrainedFP8Config,
"auto-round": AutoRoundConfig,
+ "mxfp4": Mxfp4Config,
}
logger = logging.get_logger(__name__)
@@ -211,7 +215,8 @@ def merge_quantization_configs(
if (
isinstance(
- quantization_config, (GPTQConfig, AwqConfig, AutoRoundConfig, FbgemmFp8Config, CompressedTensorsConfig)
+ quantization_config,
+ (GPTQConfig, AwqConfig, AutoRoundConfig, FbgemmFp8Config, CompressedTensorsConfig, Mxfp4Config),
)
and quantization_config_from_args is not None
):
@@ -222,9 +227,11 @@ def merge_quantization_configs(
warning_msg += f"However, loading attributes (e.g. {list(loading_attr_dict.keys())}) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."
- if warning_msg != "":
+ if warning_msg != "" and not isinstance(quantization_config, Mxfp4Config):
warnings.warn(warning_msg)
-
+ else:
+ # in the case of mxfp4, we don't want to print the warning message, bit confusing for users
+ logger.info(warning_msg)
return quantization_config
@staticmethod
diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py
index fb53c4c0f6de..4efa60a6e48d 100644
--- a/src/transformers/quantizers/base.py
+++ b/src/transformers/quantizers/base.py
@@ -237,6 +237,20 @@ def postprocess_model(self, model: "PreTrainedModel", **kwargs):
"""
return self._process_model_after_weight_loading(model, **kwargs)
+ def remove_quantization_config(self, model):
+ """
+ Remove the quantization config from the model.
+ """
+ if hasattr(model, "hf_quantizer"):
+ del model.hf_quantizer
+ if hasattr(model.config, "quantization_config"):
+ del model.config.quantization_config
+ if hasattr(model.config, "_pre_quantization_dtype"):
+ del model.config._pre_quantization_dtype
+ if hasattr(model, "quantization_method"):
+ del model.quantization_method
+ model.is_quantized = False
+
def dequantize(self, model):
"""
Potentially dequantize the model to retrieve the original model, with some loss in accuracy / performance.
@@ -269,6 +283,12 @@ def _dequantize(self, model):
f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
)
+ def update_param_name(self, param_name: str) -> str:
+ """
+ Override this method if you want to adjust the `param_name`.
+ """
+ return param_name
+
@staticmethod
def get_modules_to_not_convert(
model: "PreTrainedModel",
diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py
new file mode 100644
index 000000000000..061ca072f029
--- /dev/null
+++ b/src/transformers/quantizers/quantizer_mxfp4.py
@@ -0,0 +1,331 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING, Any, Optional
+
+from .base import HfQuantizer
+
+
+if TYPE_CHECKING:
+ from ..modeling_utils import PreTrainedModel
+
+from ..utils import (
+ is_accelerate_available,
+ is_torch_available,
+ is_triton_available,
+ is_triton_kernels_availalble,
+ logging,
+)
+from .quantizers_utils import get_module_from_name
+
+
+if is_torch_available():
+ import torch
+
+logger = logging.get_logger(__name__)
+
+
+class Mxfp4HfQuantizer(HfQuantizer):
+ """
+ FP4 quantization using fbgemm kernels
+ """
+
+ requires_parameters_quantization = True
+ # to remove if we decide to allow quantizing weights with this method
+ requires_calibration = False
+
+ required_packages = ["accelerate"]
+
+ def __init__(self, quantization_config, **kwargs):
+ super().__init__(quantization_config, **kwargs)
+ self.quantization_config = quantization_config
+
+ def validate_environment(self, *args, **kwargs):
+ if not is_torch_available():
+ raise ImportError(
+ "Using mxfp4 quantization requires torch"
+ "Please install the latest version of torch ( pip install --upgrade torch )"
+ )
+ if not torch.cuda.is_available():
+ raise RuntimeError("Using MXFP4 quantized models requires a GPU")
+
+ if not is_accelerate_available():
+ raise ImportError("Using mxfp4 requires Accelerate: `pip install accelerate`")
+
+ if self.quantization_config.dequantize:
+ return
+
+ compute_capability = torch.cuda.get_device_capability()
+ major, minor = compute_capability
+
+ if not is_triton_available("3.4.0") or not is_triton_kernels_availalble():
+ if self.pre_quantized and not self.quantization_config.dequantize:
+ logger.warning_once(
+ "MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed, we will default to dequantizing the model to bf16"
+ )
+ self.quantization_config.dequantize = True
+ return
+ else:
+ # we can't quantize the model in this case so we raise an error
+ raise ValueError("MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed")
+
+ if major < 9:
+ raise ValueError(
+ "MXFP4 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100, or B100)"
+ )
+
+ device_map = kwargs.get("device_map", None)
+ if device_map is None:
+ logger.warning_once(
+ "You have loaded an FP4 model on CPU and have a CUDA device available, make sure to set "
+ "your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. "
+ )
+ elif device_map is not None:
+ if (
+ not self.pre_quantized
+ and isinstance(device_map, dict)
+ and ("cpu" in device_map.values() or "disk" in device_map.values())
+ ):
+ raise ValueError(
+ "You are attempting to load an FP4 model with a device_map that contains a CPU or disk device."
+ "This is not supported when the model is quantized on the fly. "
+ "Please use a quantized checkpoint or remove the CPU or disk device from the device_map."
+ )
+
+ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+ if torch_dtype is None:
+ torch_dtype = torch.bfloat16
+ logger.info(
+ "Overriding torch_dtype=%s with `torch_dtype=torch.bfloat16` due to "
+ "requirements of `fbgemm-gpu` to enable model loading in fp4. "
+ "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
+ " torch_dtype=torch.bfloat16 to remove this warning.",
+ torch_dtype,
+ )
+ return torch_dtype
+
+ def check_quantized_param(
+ self,
+ model: "PreTrainedModel",
+ param_value: "torch.Tensor",
+ param_name: str,
+ state_dict: dict[str, Any],
+ **kwargs,
+ ):
+ from ..integrations import Mxfp4GptOssExperts
+ from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts
+
+ # if we are dequantizing, the model doesn't have scales, and blocks only params like gate_up_proj and down_proj so we need to handle this case differently
+ if self.quantization_config.dequantize and ("blocks" in param_name or "scales" in param_name):
+ module, tensor_name = get_module_from_name(model, param_name[: -len("_blocks")])
+ else:
+ module, tensor_name = get_module_from_name(model, param_name)
+
+ if isinstance(module, Mxfp4GptOssExperts) or (
+ isinstance(module, GptOssExperts) and self.quantization_config.dequantize
+ ):
+ if tensor_name in ["down_proj_bias", "gate_up_proj_bias"]:
+ return False
+ return True
+ return False
+
+ def create_quantized_param(
+ self,
+ model: "PreTrainedModel",
+ param_value: "torch.Tensor",
+ param_name: str,
+ target_device: "torch.device",
+ state_dict: dict[str, Any],
+ unexpected_keys: Optional[list[str]] = None,
+ **kwargs,
+ ):
+ if is_triton_kernels_availalble() and is_triton_available("3.4.0"):
+ from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig
+
+ from ..integrations import Mxfp4GptOssExperts, dequantize, load_and_swizzle_mxfp4, quantize_to_mxfp4
+ from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts
+
+ if not self.pre_quantized:
+ module, _ = get_module_from_name(model, param_name)
+ with torch.cuda.device(target_device):
+ if isinstance(module, Mxfp4GptOssExperts):
+ if "gate_up_proj" in param_name:
+ right_pad = module.gate_up_proj_right_pad
+ bottom_pad = module.gate_up_proj_bottom_pad
+ loaded_weight = torch.nn.functional.pad(
+ param_value, (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0
+ )
+ triton_weight_tensor, weight_scale = quantize_to_mxfp4(loaded_weight)
+ module.gate_up_proj_precision_config = PrecisionConfig(
+ weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())
+ )
+ module.gate_up_proj = triton_weight_tensor
+ module.gate_up_proj_blocks = torch.nn.Parameter(
+ triton_weight_tensor.storage.data, requires_grad=False
+ )
+ elif "down_proj" in param_name:
+ right_pad = module.down_proj_right_pad
+ bottom_pad = module.down_proj_bottom_pad
+ loaded_weight = torch.nn.functional.pad(
+ param_value, (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0
+ ).to(target_device)
+ triton_weight_tensor, weight_scale = quantize_to_mxfp4(loaded_weight)
+ module.down_proj_precision_config = PrecisionConfig(
+ weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())
+ )
+ module.down_proj = triton_weight_tensor
+ module.down_proj_blocks = torch.nn.Parameter(
+ triton_weight_tensor.storage.data, requires_grad=False
+ )
+
+ # we take this path if already quantized but not in a compatible way
+ # The params going here are either gate_up_proj_blocks, or down_proj_blocks, or gate_up_proj_scales, or down_proj_scales
+ else:
+ empty_param = kwargs.get("empty_param", None)
+ casting_dtype = kwargs.get("casting_dtype", None)
+ to_contiguous = kwargs.get("to_contiguous", None)
+ rank = kwargs.get("rank", None)
+ device_mesh = kwargs.get("device_mesh", None)
+ if ("blocks" in param_name or "scales" in param_name) and self.quantization_config.dequantize:
+ # blocks and scales have the same length that's this works for both
+ module, _ = get_module_from_name(model, param_name[: -len("_blocks")])
+ else:
+ module, _ = get_module_from_name(model, param_name)
+
+ shard_kwargs = {
+ "empty_param": empty_param,
+ "casting_dtype": casting_dtype,
+ "to_contiguous": to_contiguous,
+ "rank": rank,
+ "device_mesh": device_mesh,
+ "model": model,
+ }
+
+ if isinstance(module, Mxfp4GptOssExperts) or (
+ isinstance(module, GptOssExperts) and self.quantization_config.dequantize
+ ):
+ if self.quantization_config.dequantize:
+ # dq_param_name is the name of the parameter without the blocks or scales suffix, it's used in this case since we don't switch linears
+ # so we only have the original param name
+ dq_param_name = param_name[: -len("_blocks")]
+ dequantize(module, param_name, param_value, target_device, dq_param_name, **shard_kwargs)
+ else:
+ load_and_swizzle_mxfp4(
+ module,
+ param_name,
+ param_value,
+ target_device,
+ **shard_kwargs,
+ )
+
+ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
+ # we are not really dequantizing, we are just removing everthing related to quantization here
+ if self.quantization_config.dequantize:
+ self.remove_quantization_config(model)
+ # clean cache due to triton ops
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ def update_expected_keys(self, model: "PreTrainedModel", expected_keys: list[str], checkpoint_keys: list[str]):
+ # Replace expected_keys for experts' gate_up_proj and down_proj with their _blocks and _scales variants
+ new_expected_keys = []
+ for key in expected_keys:
+ if key.endswith(".mlp.experts.gate_up_proj"):
+ base = key[: -len("gate_up_proj")]
+ new_expected_keys.append(base + "gate_up_proj_blocks")
+ new_expected_keys.append(base + "gate_up_proj_scales")
+ elif key.endswith(".mlp.experts.down_proj"):
+ base = key[: -len("down_proj")]
+ new_expected_keys.append(base + "down_proj_blocks")
+ new_expected_keys.append(base + "down_proj_scales")
+ else:
+ new_expected_keys.append(key)
+ return new_expected_keys
+
+ def _process_model_before_weight_loading(
+ self,
+ model: "PreTrainedModel",
+ keep_in_fp32_modules: Optional[list[str]] = None,
+ **kwargs,
+ ):
+ from ..integrations import replace_with_mxfp4_linear
+
+ self.modules_to_not_convert = self.get_modules_to_not_convert(
+ model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+ )
+
+ use_kernels = kwargs.get("use_kernels", False)
+ # if we are using kernels, we can't use the quantized model, since the forward pass is different and needs special handling
+ if use_kernels:
+ logger.warning_once(
+ "You are using full precision kernels, we will dequantize the model to bf16. "
+ "To use the quantized model with quantization kernels, please set use_kernels=False"
+ )
+ self.quantization_config.dequantize = True
+
+ config = model.config
+ model = replace_with_mxfp4_linear(
+ model,
+ modules_to_not_convert=self.modules_to_not_convert,
+ quantization_config=self.quantization_config,
+ config=config,
+ )
+
+ model.config.quantization_config = self.quantization_config
+
+ def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
+ from ..integrations import Mxfp4GptOssExperts
+
+ not_missing_keys = []
+ for name, module in model.named_modules():
+ if isinstance(module, Mxfp4GptOssExperts):
+ for missing in missing_keys:
+ if (
+ (name in missing or name in f"{prefix}.{missing}")
+ and not missing.endswith(".weight")
+ and not missing.endswith(".bias")
+ ):
+ not_missing_keys.append(missing)
+ return [k for k in missing_keys if k not in not_missing_keys]
+
+ def update_tp_plan(self, config):
+ if "GptOssConfig" in config.__class__.__name__:
+ if getattr(config, "base_model_tp_plan", None) is not None:
+ config.base_model_tp_plan.update(
+ {
+ "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm",
+ "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm",
+ "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm",
+ "layers.*.mlp.experts.down_proj_scales": "grouped_gemm",
+ }
+ )
+ return config
+
+ def update_param_name(self, param_name: str) -> str:
+ if self.quantization_config.dequantize:
+ if "_blocks" in param_name:
+ return param_name.replace("_blocks", "")
+ elif "_scales" in param_name:
+ return param_name.replace("_scales", "")
+ return param_name
+
+ def is_serializable(self, safe_serialization=None):
+ logger.warning_once("MXFP4 quantization is not serializable using safetensors for now")
+ return False
+
+ @property
+ def is_trainable(self) -> bool:
+ logger.warning_once(
+ "MXFP4 quantization don't support training, please consider dequantizing the model first by passing quantization_config=Mxfp4Config(dequantize=True) to .from_pretrained()"
+ )
+ return False
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index a41ea8166e3a..7eb8b3c46eb7 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -66,6 +66,7 @@
from .utils import (
ACCELERATE_MIN_VERSION,
GGUF_MIN_VERSION,
+ TRITON_MIN_VERSION,
is_accelerate_available,
is_apex_available,
is_apollo_torch_available,
@@ -168,6 +169,8 @@
is_torchcodec_available,
is_torchdynamo_available,
is_torchvision_available,
+ is_triton_available,
+ is_triton_kernels_availalble,
is_vision_available,
is_vptq_available,
strtobool,
@@ -455,6 +458,26 @@ def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION):
)(test_case)
+def require_triton(min_version: str = TRITON_MIN_VERSION):
+ """
+ Decorator marking a test that requires triton. These tests are skipped when triton isn't installed.
+ """
+
+ def decorator(test_case):
+ return unittest.skipUnless(is_triton_available(min_version), f"test requires triton version >= {min_version}")(
+ test_case
+ )
+
+ return decorator
+
+
+def require_triton_kernels(test_case):
+ """
+ Decorator marking a test that requires triton_kernels. These tests are skipped when triton_kernels isn't installed.
+ """
+ return unittest.skipUnless(is_triton_kernels_availalble(), "test requires triton_kernels")(test_case)
+
+
def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION):
"""
Decorator marking a test that requires ggguf. These tests are skipped when gguf isn't installed.
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index bb2920b40335..f53914baa5dc 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -117,6 +117,7 @@
ENV_VARS_TRUE_VALUES,
GGUF_MIN_VERSION,
TORCH_FX_REQUIRED_VERSION,
+ TRITON_MIN_VERSION,
USE_JAX,
USE_TF,
USE_TORCH,
@@ -268,6 +269,8 @@
is_torchvision_available,
is_torchvision_v2_available,
is_training_run_on_sagemaker,
+ is_triton_available,
+ is_triton_kernels_availalble,
is_uroman_available,
is_vision_available,
is_vptq_available,
diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py
index d778532e33e9..be2e0f956a57 100644
--- a/src/transformers/utils/generic.py
+++ b/src/transformers/utils/generic.py
@@ -1044,7 +1044,10 @@ def wrapped_forward(*args, **kwargs):
if not isinstance(output, tuple):
collected_outputs[key] += (output,)
elif output[index] is not None:
- collected_outputs[key] += (output[index],)
+ if key not in collected_outputs:
+ collected_outputs[key] = (output[index],)
+ else:
+ collected_outputs[key] += (output[index],)
return output
return wrapped_forward
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 9fa7c5253d9c..a858c94c6958 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -76,6 +76,11 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
package_version = importlib.metadata.version("amd-quark")
except Exception:
package_exists = False
+ elif pkg_name == "triton":
+ try:
+ package_version = importlib.metadata.version("pytorch-triton")
+ except Exception:
+ package_exists = False
else:
# For packages other than "torch", don't attempt the fallback and set as not available
package_exists = False
@@ -111,6 +116,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
VPTQ_MIN_VERSION = "0.0.4"
TORCHAO_MIN_VERSION = "0.4.0"
AUTOROUND_MIN_VERSION = "0.5.0"
+TRITON_MIN_VERSION = "1.0.0"
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
@@ -226,12 +232,13 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_tiktoken_available = _is_package_available("tiktoken")
_blobfile_available = _is_package_available("blobfile")
_liger_kernel_available = _is_package_available("liger_kernel")
-_triton_available = _is_package_available("triton")
_spqr_available = _is_package_available("spqr_quant")
_rich_available = _is_package_available("rich")
_kernels_available = _is_package_available("kernels")
_matplotlib_available = _is_package_available("matplotlib")
_mistral_common_available = _is_package_available("mistral_common")
+_triton_available, _triton_version = _is_package_available("triton", return_version=True)
+_triton_kernels_available = _is_package_available("triton_kernels")
_torch_version = "N/A"
_torch_available = False
@@ -412,6 +419,14 @@ def is_torch_deterministic():
return False
+def is_triton_available(min_version: str = TRITON_MIN_VERSION):
+ return _triton_available and version.parse(_triton_version) >= version.parse(min_version)
+
+
+def is_triton_kernels_availalble():
+ return _triton_kernels_available
+
+
def is_hadamard_available():
return _hadamard_available
@@ -1590,10 +1605,6 @@ def is_liger_kernel_available():
return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.3.0")
-def is_triton_available():
- return _triton_available
-
-
def is_rich_available():
return _rich_available
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index 0bc616c6ff9c..ea807b9c51a9 100644
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -65,6 +65,7 @@ class QuantizationMethod(str, Enum):
QUARK = "quark"
FPQUANT = "fp_quant"
AUTOROUND = "auto-round"
+ MXFP4 = "mxfp4"
class AWQLinearVersion(str, Enum):
@@ -2048,3 +2049,31 @@ def __init__(
self.json_export_config = JsonExporterConfig()
self.quant_method = QuantizationMethod.QUARK
+
+
+@dataclass
+class Mxfp4Config(QuantizationConfigMixin):
+ """
+ This is a wrapper class about all possible attributes and features that you can play with a model that has been
+ loaded using mxfp4 quantization.
+
+ Args:
+ modules_to_not_convert (`list`, *optional*, default to `None`):
+ The list of modules to not quantize, useful for quantizing models that explicitly require to have
+ some modules left in their original precision.
+ """
+
+ def __init__(
+ self,
+ modules_to_not_convert: Optional[list] = None,
+ dequantize: bool = False,
+ **kwargs,
+ ):
+ self.quant_method = QuantizationMethod.MXFP4
+ self.modules_to_not_convert = modules_to_not_convert
+ self.dequantize = dequantize
+
+ def get_loading_attributes(self):
+ return {
+ "dequantize": self.dequantize,
+ }
diff --git a/tests/fixtures/gpt_oss/integration_tests.json b/tests/fixtures/gpt_oss/integration_tests.json
new file mode 100644
index 000000000000..99b19b0ee7e0
--- /dev/null
+++ b/tests/fixtures/gpt_oss/integration_tests.json
@@ -0,0 +1,346 @@
+[
+ {
+ "quantized": true,
+ "model": "120b",
+ "kernels": false,
+ "attn_impl": "ft-hf-o-c/vllm-flash-attn3",
+ "mode": "eval",
+ "outputs": [
+ ".....Roses are red, violets are blue, I am a language model, and I can help you too!\n\nSure! Here",
+ "How are you? Tell me the name of the president of the United\n\nHello! As of my last update in November 2023, the President of the"
+ ]
+ },
+ {
+ "quantized": true,
+ "model": "120b",
+ "kernels": false,
+ "attn_impl": "ft-hf-o-c/vllm-flash-attn3",
+ "mode": "train",
+ "outputs": [
+ ".....Roses are red, violets are blue, I am a language model, and I can help you too!\n\nSure! Here",
+ "How are you? Tell me the name of the president of the United\n\nHello! As of my last update in November 2023, the President of the"
+ ]
+ },
+ {
+ "quantized": true,
+ "model": "120b",
+ "kernels": true,
+ "attn_impl": "ft-hf-o-c/vllm-flash-attn3",
+ "mode": "eval",
+ "outputs": [
+ "Did not work"
+ ]
+ },
+ {
+ "quantized": true,
+ "model": "120b",
+ "kernels": true,
+ "attn_impl": "ft-hf-o-c/vllm-flash-attn3",
+ "mode": "train",
+ "outputs": [
+ "Did not work"
+ ]
+ },
+ {
+ "quantized": true,
+ "model": "120b",
+ "kernels": false,
+ "attn_impl": "eager",
+ "mode": "eval",
+ "outputs": [
+ ".....Roses are red, violets are blue, I am a language model, and I can help you too!\n\nSure! Here",
+ "How are you? Tell me the name of the president of the United\n\nHello! As of my last update in November 2023, the President of the"
+ ]
+ },
+ {
+ "quantized": true,
+ "model": "120b",
+ "kernels": false,
+ "attn_impl": "eager",
+ "mode": "train",
+ "outputs": [
+ ".....Roses are red, violets are blue, I am a language model, and I can help you too!\n\nSure! Here",
+ "How are you? Tell me the name of the president of the United\n\nHello! As of my last update in November 2023, the President of the"
+ ]
+ },
+ {
+ "quantized": true,
+ "model": "120b",
+ "kernels": true,
+ "attn_impl": "eager",
+ "mode": "eval",
+ "outputs": [
+ "Did not work"
+ ]
+ },
+ {
+ "quantized": true,
+ "model": "120b",
+ "kernels": true,
+ "attn_impl": "eager",
+ "mode": "train",
+ "outputs": [
+ "Did not work"
+ ]
+ },
+ {
+ "quantized": true,
+ "model": "20b",
+ "kernels": false,
+ "attn_impl": "ft-hf-o-c/vllm-flash-attn3",
+ "mode": "eval",
+ "outputs": [
+ ".....Roses are red, violets are blue, I love you, and I love you too.\n\nIt sounds like you're looking for",
+ "How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for"
+ ]
+ },
+ {
+ "quantized": true,
+ "model": "20b",
+ "kernels": false,
+ "attn_impl": "ft-hf-o-c/vllm-flash-attn3",
+ "mode": "train",
+ "outputs": [
+ ".....Roses are red, violets are blue, I love you, and I love you too.\n\nIt sounds like you're looking for",
+ "How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for"
+ ]
+ },
+ {
+ "quantized": true,
+ "model": "20b",
+ "kernels": true,
+ "attn_impl": "ft-hf-o-c/vllm-flash-attn3",
+ "mode": "eval",
+ "outputs": [
+ "Did not work"
+ ]
+ },
+ {
+ "quantized": true,
+ "model": "20b",
+ "kernels": true,
+ "attn_impl": "ft-hf-o-c/vllm-flash-attn3",
+ "mode": "train",
+ "outputs": [
+ "Did not work"
+ ]
+ },
+ {
+ "quantized": true,
+ "model": "20b",
+ "kernels": false,
+ "attn_impl": "eager",
+ "mode": "eval",
+ "outputs": [
+ ".....Roses are red, violets are blue, I love you, and I love you too.\n\nIt sounds like you're expressing a",
+ "How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for"
+ ]
+ },
+ {
+ "quantized": true,
+ "model": "20b",
+ "kernels": false,
+ "attn_impl": "eager",
+ "mode": "train",
+ "outputs": [
+ ".....Roses are red, violets are blue, I love you, and I love you too.\n\nIt sounds like you're expressing a",
+ "How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for"
+ ]
+ },
+ {
+ "quantized": true,
+ "model": "20b",
+ "kernels": true,
+ "attn_impl": "eager",
+ "mode": "eval",
+ "outputs": [
+ "Did not work"
+ ]
+ },
+ {
+ "quantized": true,
+ "model": "20b",
+ "kernels": true,
+ "attn_impl": "eager",
+ "mode": "train",
+ "outputs": [
+ "Did not work"
+ ]
+ },
+ {
+ "quantized": false,
+ "model": "120b",
+ "kernels": false,
+ "attn_impl": "ft-hf-o-c/vllm-flash-attn3",
+ "mode": "eval",
+ "outputs": [
+ ".....Roses are red, violets are blue,\nI am a language model, not a human being.\n```\n\nThis poem is a",
+ "How are you? Tell me the name of the president of the United Kingdom?\n\nThe United Kingdom does not have a president. The head of state is the"
+ ]
+ },
+ {
+ "quantized": false,
+ "model": "120b",
+ "kernels": false,
+ "attn_impl": "ft-hf-o-c/vllm-flash-attn3",
+ "mode": "train",
+ "outputs": [
+ ".....Roses are red, violets are blue, I am a language model trained by OpenAI.\n\nI am a large language model",
+ "How are you? Tell me the name of the president of the United\n\nHello! I'm an AI language model, so I don't have feelings, but I'm here"
+ ]
+ },
+ {
+ "quantized": false,
+ "model": "120b",
+ "kernels": true,
+ "attn_impl": "ft-hf-o-c/vllm-flash-attn3",
+ "mode": "eval",
+ "outputs": [
+ ".....Roses are red, violets are blue,\nI am a language model, not a human being.\n```\n\nThis poem is a",
+ "How are you? Tell me the name of the president of the United Kingdom?\n\nThe United Kingdom does not have a president. The head of state is the"
+ ]
+ },
+ {
+ "quantized": false,
+ "model": "120b",
+ "kernels": true,
+ "attn_impl": "ft-hf-o-c/vllm-flash-attn3",
+ "mode": "train",
+ "outputs": [
+ ".....Roses are red, violets are blue, I am a language model trained by OpenAI.\n\nI am a large language model",
+ "How are you? Tell me the name of the president of the United\n\nHello! I'm an AI language model, so I don't have feelings, but I'm here"
+ ]
+ },
+ {
+ "quantized": false,
+ "model": "120b",
+ "kernels": false,
+ "attn_impl": "eager",
+ "mode": "eval",
+ "outputs": [
+ ".....Roses are red, violets are blue,\nI am a language model, not a human being.\n```\n\nThis poem is a",
+ "How are you? Tell me the name of the president of the United States?\n\nAs an AI language model, I do not have personal feelings or emotions,"
+ ]
+ },
+ {
+ "quantized": false,
+ "model": "120b",
+ "kernels": false,
+ "attn_impl": "eager",
+ "mode": "train",
+ "outputs": [
+ ".....Roses are red, violets are blue, I am a language model, and I can help you with your request.\n\nSure",
+ "How are you? Tell me the name of the president of the United\n\nHello! I'm an AI language model, so I don't have feelings, but I'm here"
+ ]
+ },
+ {
+ "quantized": false,
+ "model": "120b",
+ "kernels": true,
+ "attn_impl": "eager",
+ "mode": "eval",
+ "outputs": [
+ ".....Roses are red, violets are blue,\nI am a language model, not a human being.\n```\n\nThis poem is a",
+ "How are you? Tell me the name of the president of the United States?\n\nAs an AI language model, I do not have personal feelings or emotions,"
+ ]
+ },
+ {
+ "quantized": false,
+ "model": "120b",
+ "kernels": true,
+ "attn_impl": "eager",
+ "mode": "train",
+ "outputs": [
+ ".....Roses are red, violets are blue, I am a language model, and I can help you with your request.\n\nSure",
+ "How are you? Tell me the name of the president of the United\n\nHello! I'm an AI language model, so I don't have feelings, but I'm here"
+ ]
+ },
+ {
+ "quantized": false,
+ "model": "20b",
+ "kernels": false,
+ "attn_impl": "ft-hf-o-c/vllm-flash-attn3",
+ "mode": "eval",
+ "outputs": [
+ ".....Roses are red, violets are blue, I love you, and I love you too!\n\nRoses are red, vio",
+ "How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for"
+ ]
+ },
+ {
+ "quantized": false,
+ "model": "20b",
+ "kernels": false,
+ "attn_impl": "ft-hf-o-c/vllm-flash-attn3",
+ "mode": "train",
+ "outputs": [
+ ".....Roses are red, violets are blue\" (makes sense). But the phrase \"the answer is 3\" is not a",
+ "How are you? Tell me the name of the president of the United States.\" The answer to that is \"Joe Biden.\" The user is asking for the name"
+ ]
+ },
+ {
+ "quantized": false,
+ "model": "20b",
+ "kernels": true,
+ "attn_impl": "ft-hf-o-c/vllm-flash-attn3",
+ "mode": "eval",
+ "outputs": [
+ ".....Roses are red, violets are blue, I love you, and I love you too!\n\nRoses are red, vio",
+ "How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for"
+ ]
+ },
+ {
+ "quantized": false,
+ "model": "20b",
+ "kernels": true,
+ "attn_impl": "ft-hf-o-c/vllm-flash-attn3",
+ "mode": "train",
+ "outputs": [
+ ".....Roses are red, violets are blue\" (makes sense). But the phrase \"the answer is 3\" is not a",
+ "How are you? Tell me the name of the president of the United States.\" The answer to that is \"Joe Biden.\" The user is asking for the name"
+ ]
+ },
+ {
+ "quantized": false,
+ "model": "20b",
+ "kernels": false,
+ "attn_impl": "eager",
+ "mode": "eval",
+ "outputs": [
+ ".....Roses are red, violets are blue, I love you, and I love you too!\n\nRoses are red, vio",
+ "How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for"
+ ]
+ },
+ {
+ "quantized": false,
+ "model": "20b",
+ "kernels": false,
+ "attn_impl": "eager",
+ "mode": "train",
+ "outputs": [
+ ".....Roses are red, violets are blue.\" -> from which we can derive a rule: if we have a red object that is",
+ "How are you? Tell me the name of the president of the United States.\n\nI am an AI language model and I do not have a personal life or"
+ ]
+ },
+ {
+ "quantized": false,
+ "model": "20b",
+ "kernels": true,
+ "attn_impl": "eager",
+ "mode": "eval",
+ "outputs": [
+ ".....Roses are red, violets are blue, I love you, and I love you too!\n\nRoses are red, vio",
+ "How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for"
+ ]
+ },
+ {
+ "quantized": false,
+ "model": "20b",
+ "kernels": true,
+ "attn_impl": "eager",
+ "mode": "train",
+ "outputs": [
+ ".....Roses are red, violets are blue.\" -> from which we can derive a rule: if we have a red object that is",
+ "How are you? Tell me the name of the president of the United States.\n\nI am an AI language model and I do not have a personal life or"
+ ]
+ }
+]
diff --git a/tests/models/gpt_oss/__init__.py b/tests/models/gpt_oss/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py
new file mode 100644
index 000000000000..ab978b9c9bc5
--- /dev/null
+++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py
@@ -0,0 +1,523 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch GptOss model."""
+
+import inspect
+import json
+import os
+import subprocess
+import tempfile
+import unittest
+from pathlib import Path
+
+import pytest
+from parameterized import parameterized
+
+from transformers import (
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ GptOssConfig,
+ is_torch_available,
+)
+from transformers.testing_utils import (
+ cleanup,
+ require_read_token,
+ require_torch,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
+from ...test_configuration_common import ConfigTester
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ GptOssForCausalLM,
+ GptOssModel,
+ )
+
+ NUM_GPUS = torch.cuda.device_count()
+
+
+class GptOssModelTester(CausalLMModelTester):
+ if is_torch_available():
+ config_class = GptOssConfig
+ base_model_class = GptOssModel
+ causal_lm_class = GptOssForCausalLM
+
+ pipeline_model_mapping = (
+ {
+ "feature-extraction": GptOssModel,
+ "text-generation": GptOssForCausalLM,
+ }
+ if is_torch_available()
+ else {}
+ )
+
+
+@require_torch
+class GptOssModelTest(CausalLMModelTest, unittest.TestCase):
+ all_model_classes = (GptOssModel, GptOssForCausalLM) if is_torch_available() else ()
+ pipeline_model_mapping = (
+ {
+ "feature-extraction": GptOssModel,
+ "text-generation": GptOssForCausalLM,
+ }
+ if is_torch_available()
+ else {}
+ )
+
+ test_headmasking = False
+ test_pruning = False
+ _is_stateful = True
+ model_split_percents = [0.5, 0.6]
+ model_tester_class = GptOssModelTester
+
+ def setUp(self):
+ self.model_tester = GptOssModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=GptOssConfig, hidden_size=37)
+
+ @unittest.skip("Failing because of unique cache (HybridCache)")
+ def test_model_outputs_equivalence(self, **kwargs):
+ pass
+
+ @unittest.skip("GptOss's forcefully disables sdpa due to Sink")
+ def test_sdpa_can_dispatch_non_composite_models(self):
+ pass
+
+ @unittest.skip("GptOss's eager attn/sdpa attn outputs are expected to be different")
+ def test_eager_matches_sdpa_generate(self):
+ pass
+
+ @parameterized.expand([("random",), ("same",)])
+ @pytest.mark.generate
+ @unittest.skip("GptOss has HybridCache which is not compatible with assisted decoding")
+ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
+ pass
+
+ @unittest.skip("GptOss has HybridCache which is not compatible with assisted decoding")
+ def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip("GptOss has HybridCache which is not compatible with assisted decoding")
+ def test_assisted_decoding_sample(self):
+ pass
+
+ @unittest.skip("GptOss has HybridCache which is not compatible with dola decoding")
+ def test_dola_decoding_sample(self):
+ pass
+
+ @unittest.skip("GptOss has HybridCache and doesn't support continue from past kv")
+ def test_generate_continue_from_past_key_values(self):
+ pass
+
+ @unittest.skip("GptOss has HybridCache and doesn't support contrastive generation")
+ def test_contrastive_generate(self):
+ pass
+
+ @unittest.skip("GptOss has HybridCache and doesn't support contrastive generation")
+ def test_contrastive_generate_dict_outputs_use_cache(self):
+ pass
+
+ @unittest.skip("GptOss has HybridCache and doesn't support contrastive generation")
+ def test_contrastive_generate_low_memory(self):
+ pass
+
+ @unittest.skip("GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
+ def test_generate_with_static_cache(self):
+ pass
+
+ @unittest.skip("GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
+ def test_generate_from_inputs_embeds_with_static_cache(self):
+ pass
+
+ @unittest.skip("GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
+ def test_generate_continue_from_inputs_embeds(self):
+ pass
+
+ @unittest.skip(
+ reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`"
+ " as in Dynamic Cache doesn't work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting"
+ )
+ def test_multi_gpu_data_parallel_forward(self):
+ pass
+
+ @unittest.skip("GptOss has HybridCache which auto-compiles. Compile and FA2 don't work together.")
+ def test_eager_matches_fa2_generate(self):
+ pass
+
+ @unittest.skip("GptOss eager/FA2 attention outputs are expected to be different")
+ def test_flash_attn_2_equivalence(self):
+ pass
+
+ @unittest.skip("Most probably because of the MOE, the moe and router does not ignore padding tokens")
+ def test_eager_padding_matches_padding_free_with_position_ids(self):
+ pass
+
+ @unittest.skip("GptOss does not support flex officially")
+ def test_flex_attention_with_grads(self):
+ pass
+
+
+RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/gpt_oss/integration_tests.json"
+
+
+# ------------------------
+# Worker function for distributed torchrun
+# ------------------------
+def distributed_worker(quantized, model_size, kernels, attn_impl, mode):
+ """This is the function that will be executed by torchrun workers."""
+ import os
+
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers.testing_utils import torch_device
+
+ input_text = [
+ "Roses are red, violets",
+ "How are you? Tell me the name of the president of",
+ ]
+
+ # Convert args
+ quantized = quantized.lower() == "true"
+ kernels = kernels.lower() == "true"
+
+ # Distributed model loading
+ model_id = f"/fsx/vb/new-oai/gpt-oss-{model_size}-trfs"
+ model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ torch_dtype="auto",
+ tp_plan="auto", # distributed inference
+ use_kernels=kernels,
+ ).to(torch_device)
+ model.set_attn_implementation(attn_impl)
+ tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
+
+ # Inference
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(torch_device)
+ output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_texts = tokenizer.batch_decode(output, skip_special_tokens=False)
+
+ # Only rank 0 writes results
+ if int(os.environ.get("RANK", "0")) == 0:
+ result_entry = {
+ "quantized": quantized,
+ "model": model_size,
+ "kernels": kernels,
+ "attn_impl": attn_impl,
+ "mode": mode,
+ "outputs": output_texts,
+ }
+
+ if os.path.exists(RESULTS_PATH):
+ with open(RESULTS_PATH, "r") as f:
+ results = json.load(f)
+ else:
+ results = []
+ results.append(result_entry)
+
+ with open(RESULTS_PATH, "w") as f:
+ json.dump(results, f, indent=2)
+
+
+@slow
+@require_torch_accelerator
+class GptOssIntegrationTest(unittest.TestCase):
+ input_text = [
+ "Roses are red, violets",
+ "How are you? Tell me the name of the president of",
+ ]
+
+ def setUp(self):
+ cleanup(torch_device, gc_collect=True)
+
+ def tearDown(self):
+ cleanup(torch_device, gc_collect=True)
+
+ # ------------------------
+ # Non-distributed inference
+ # ------------------------
+ @staticmethod
+ def load_and_forward(model_id, attn_implementation, input_text, **pretrained_kwargs):
+ model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ attn_implementation=attn_implementation,
+ **pretrained_kwargs,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
+
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(model.device)
+ output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
+ return output_text
+
+ # ------------------------
+ # Distributed inference using inspect
+ # ------------------------
+ @staticmethod
+ def run_distributed_test(quantized, model, kernels, attn_impl, mode):
+ """Launch torchrun using a temporary worker file generated from inspect.getsource()."""
+ import textwrap
+
+ # Extract worker function source dynamically
+ worker_src = inspect.getsource(distributed_worker)
+
+ # Create a temp file that calls the worker
+ script_code = f"""
+import sys
+import json
+
+RESULTS_PATH = "{RESULTS_PATH}"
+
+{worker_src}
+
+if __name__ == "__main__":
+ distributed_worker("{quantized}", "{model}", "{kernels}", "{attn_impl}", "{mode}")
+"""
+ # Dedent for proper formatting
+ script_code = textwrap.dedent(script_code)
+
+ # Write to temp file
+ with tempfile.NamedTemporaryFile("w", suffix="_worker.py", delete=False) as tmp:
+ tmp.write(script_code)
+ tmp_path = tmp.name
+
+ # Launch torchrun
+ cmd = [
+ "torchrun",
+ f"--nproc_per_node={NUM_GPUS}",
+ tmp_path,
+ ]
+ subprocess.run(cmd, check=True)
+
+ # Cleanup
+ os.remove(tmp_path)
+
+ # ------------------------
+ # Shared parameterization
+ # ------------------------
+ PARAMETERS = [
+ (False, "120b", False, "eager", "eval"),
+ (False, "120b", False, "eager", "train"),
+ (False, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"),
+ (False, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "train"),
+ (False, "120b", True, "eager", "eval"),
+ (False, "120b", True, "eager", "train"),
+ (False, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"),
+ (False, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "train"),
+ (True, "120b", False, "eager", "eval"),
+ (True, "120b", False, "eager", "train"),
+ (True, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"),
+ (True, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "train"),
+ (True, "120b", True, "eager", "eval"),
+ (True, "120b", True, "eager", "train"),
+ (True, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"),
+ (True, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "train"),
+ (False, "20b", False, "eager", "eval"),
+ (False, "20b", False, "eager", "train"),
+ (False, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"),
+ (False, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "train"),
+ (False, "20b", True, "eager", "eval"),
+ (False, "20b", True, "eager", "train"),
+ (False, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"),
+ (False, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "train"),
+ (True, "20b", False, "eager", "eval"),
+ (True, "20b", False, "eager", "train"),
+ (True, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"),
+ (True, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "train"),
+ (True, "20b", True, "eager", "eval"),
+ (True, "20b", True, "eager", "train"),
+ (True, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"),
+ (True, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "train"),
+ ]
+
+ # ------------------------
+ # Non-distributed test
+ # ------------------------
+ @parameterized.expand(PARAMETERS)
+ @require_read_token
+ def test_model_outputs(self, quantized, model, kernels, attn_impl, mode):
+ model_id = f"/fsx/vb/new-oai/gpt-oss-{model}-trfs"
+ output_texts = self.load_and_forward(
+ model_id,
+ attn_impl,
+ self.input_text,
+ use_kernels=kernels,
+ )
+
+ result_entry = {
+ "quantized": quantized,
+ "model": model,
+ "kernels": kernels,
+ "attn_impl": attn_impl,
+ "mode": mode,
+ "outputs": output_texts,
+ }
+
+ if os.path.exists(RESULTS_PATH):
+ with open(RESULTS_PATH, "r") as f:
+ results = json.load(f)
+ else:
+ results = []
+ results.append(result_entry)
+ with open(RESULTS_PATH, "w") as f:
+ json.dump(results, f, indent=2)
+
+ self.assertIsInstance(output_texts, list)
+ self.assertTrue(all(isinstance(x, str) for x in output_texts))
+
+ # ------------------------
+ # Distributed test
+ # ------------------------
+ @parameterized.expand(PARAMETERS)
+ @require_read_token
+ def test_model_outputs_distributed(self, quantized, model, kernels, attn_impl, mode):
+ self.run_distributed_test(quantized, model, kernels, attn_impl, mode)
+
+ def test_model_matches_original_20b(self):
+ input_text = "Roses are red, violets"
+
+ original_output = "Roses are red, violets are blue, I love you, and I love you too."
+ original_logprobs = torch.tensor(
+ [
+ -0.037353515625,
+ -0.08154296875,
+ -1.21875,
+ -1.953125,
+ -2.234375,
+ -0.96875,
+ -1.546875,
+ -1.640625,
+ -0.93359375,
+ -1.609375,
+ -1.625,
+ -0.85546875,
+ -1.7265625,
+ -0.7421875,
+ -2.078125,
+ -0.006561279296875,
+ -0.10498046875,
+ -0.1767578125,
+ -0.1240234375,
+ -0.099609375,
+ ]
+ )
+
+ model_id = "/fsx/vb/new-oai/gpt-oss-20b-trfs"
+
+ model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ attn_implementation="eager",
+ )
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ tokens = tokenizer(input_text)["input_ids"]
+
+ num_generated_tokens = 0
+ with torch.no_grad():
+ for i in range(12):
+ tensors = torch.as_tensor(tokens, dtype=torch.int32, device=model.device).unsqueeze(0)
+ logits = model(tensors).logits[0]
+
+ predicted_token = torch.argmax(logits[-1, :], dim=-1).item()
+ logprobs = torch.log_softmax(logits[-1, :], dim=-1)
+ selected_logprobs = logprobs[predicted_token]
+
+ tokens.append(predicted_token)
+ num_generated_tokens += 1
+ decoded_token = tokenizer.decode([predicted_token])
+ logprob_differences = selected_logprobs - original_logprobs[i]
+
+ print(
+ f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}, logprob differences: {logprob_differences}"
+ )
+ torch.testing.assert_close(
+ selected_logprobs.cpu().to(original_logprobs.dtype), original_logprobs[i], atol=1e-1, rtol=1e-1
+ )
+
+ decoded_string = tokenizer.decode(tokens)
+ self.assertTrue(original_output.startswith(decoded_string))
+
+ def test_model_matches_original_120b(self):
+ input_text = "Roses are red, violets"
+
+ original_output = """Roses are red, violets are blue,
+I am a language model, not a human being"""
+ original_logprobs = torch.tensor(
+ [
+ -0.90234375,
+ -0.66015625,
+ -1.546875,
+ -2.703125,
+ -2.078125,
+ -1.21875,
+ -2.484375,
+ -0.031982421875,
+ -0.84765625,
+ -1.890625,
+ -0.1923828125,
+ -2.046875,
+ -1.65625,
+ -1.3515625,
+ -1.1640625,
+ -0.3671875,
+ -1.9921875,
+ -1.5390625,
+ -1.46875,
+ -0.85546875,
+ ]
+ )
+
+ model_id = "/fsx/vb/new-oai/gpt-oss-120b-trfs"
+
+ model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ attn_implementation="eager",
+ )
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ tokens = tokenizer(input_text)["input_ids"]
+
+ num_generated_tokens = 0
+ with torch.no_grad():
+ for i in range(12):
+ tensors = torch.as_tensor(tokens, dtype=torch.int32, device=model.device).unsqueeze(0)
+ logits = model(tensors).logits[0]
+
+ predicted_token = torch.argmax(logits[-1, :], dim=-1).item()
+ logprobs = torch.log_softmax(logits[-1, :], dim=-1)
+ selected_logprobs = logprobs[predicted_token]
+
+ tokens.append(predicted_token)
+ num_generated_tokens += 1
+ decoded_token = tokenizer.decode([predicted_token])
+ logprob_differences = selected_logprobs - original_logprobs[i]
+
+ print(
+ f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}, logprob differences: {logprob_differences}"
+ )
+ torch.testing.assert_close(
+ selected_logprobs.cpu().to(original_logprobs.dtype), original_logprobs[i], atol=1e-1, rtol=1e-1
+ )
+
+ decoded_string = tokenizer.decode(tokens)
+ self.assertTrue(original_output.startswith(decoded_string))
diff --git a/tests/quantization/mxfp4/__init__.py b/tests/quantization/mxfp4/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/quantization/mxfp4/test_mxfp4.py b/tests/quantization/mxfp4/test_mxfp4.py
new file mode 100644
index 000000000000..2194c2d3219e
--- /dev/null
+++ b/tests/quantization/mxfp4/test_mxfp4.py
@@ -0,0 +1,420 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+from unittest.mock import patch
+
+from transformers import AutoTokenizer, GptOssForCausalLM, Mxfp4Config
+from transformers.testing_utils import (
+ require_torch,
+ require_torch_gpu,
+ require_torch_large_gpu,
+ require_triton,
+ require_triton_kernels,
+ slow,
+)
+from transformers.utils import (
+ is_torch_available,
+)
+
+
+if is_torch_available():
+ import torch
+
+
+class Mxfp4ConfigTest(unittest.TestCase):
+ def test_basic_config_creation(self):
+ """Test basic configuration creation with default values"""
+ config = Mxfp4Config()
+ self.assertEqual(config.quant_method.value, "mxfp4")
+ self.assertIsNone(config.modules_to_not_convert)
+ self.assertFalse(config.dequantize)
+
+ def test_config_with_modules_to_not_convert(self):
+ """Test configuration with modules to not convert"""
+ modules = ["model.layers.*.self_attn", "lm_head"]
+ config = Mxfp4Config(modules_to_not_convert=modules)
+ self.assertEqual(config.modules_to_not_convert, modules)
+
+ def test_config_with_dequantize(self):
+ """Test configuration with dequantize enabled"""
+ config = Mxfp4Config(dequantize=True)
+ self.assertTrue(config.dequantize)
+
+ def test_get_loading_attributes(self):
+ """Test get_loading_attributes method"""
+ config = Mxfp4Config(dequantize=True)
+ attrs = config.get_loading_attributes()
+ self.assertEqual(attrs, {"dequantize": True})
+
+ def test_to_dict(self):
+ """Test configuration serialization to dict"""
+ config = Mxfp4Config(modules_to_not_convert=["lm_head"], dequantize=True)
+ config_dict = config.to_dict()
+ self.assertEqual(config_dict["quant_method"], "mxfp4")
+ self.assertEqual(config_dict["modules_to_not_convert"], ["lm_head"])
+ self.assertTrue(config_dict["dequantize"])
+
+ def test_from_dict(self):
+ """Test configuration creation from dict"""
+ config_dict = {"quant_method": "mxfp4", "modules_to_not_convert": ["lm_head"], "dequantize": True}
+ config = Mxfp4Config.from_dict(config_dict)
+ self.assertEqual(config.modules_to_not_convert, ["lm_head"])
+ self.assertTrue(config.dequantize)
+
+
+class Mxfp4QuantizerTest(unittest.TestCase):
+ """Test the Mxfp4HfQuantizer class"""
+
+ def setUp(self):
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ def test_quantizer_validation_no_torch(self):
+ """Test quantizer validation when torch is not available"""
+ with patch("transformers.quantizers.quantizer_mxfp4.is_torch_available", return_value=False):
+ from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+
+ config = Mxfp4Config()
+ quantizer = Mxfp4HfQuantizer(config)
+
+ with self.assertRaises(ImportError):
+ quantizer.validate_environment()
+
+ def test_quantizer_validation_no_cuda(self):
+ """Test quantizer validation when CUDA is not available"""
+ with patch("torch.cuda.is_available", return_value=False):
+ from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+
+ config = Mxfp4Config()
+ quantizer = Mxfp4HfQuantizer(config)
+
+ with self.assertRaises(RuntimeError):
+ quantizer.validate_environment()
+
+ def test_quantizer_validation_low_compute_capability(self):
+ """Test quantizer validation with low compute capability"""
+ with patch("torch.cuda.get_device_capability", return_value=(8, 0)):
+ from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+
+ config = Mxfp4Config()
+ quantizer = Mxfp4HfQuantizer(config)
+
+ with self.assertRaises(ValueError):
+ quantizer.validate_environment()
+
+ def test_quantizer_validation_low_compute_capability_with_dequantize(self):
+ """Test quantizer validation with low compute capability but dequantize enabled"""
+ with patch("torch.cuda.get_device_capability", return_value=(8, 0)):
+ from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+
+ config = Mxfp4Config(dequantize=True)
+ quantizer = Mxfp4HfQuantizer(config)
+
+ # Should not raise error with dequantize=True
+ try:
+ quantizer.validate_environment()
+ except ValueError as e:
+ if "compute capability" in str(e):
+ self.fail("Should not raise compute capability error when dequantize=True")
+
+ def test_quantizer_validation_missing_triton(self):
+ """Test quantizer validation when triton is not available"""
+ with (
+ patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False),
+ patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False),
+ ):
+ from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+
+ config = Mxfp4Config()
+ quantizer = Mxfp4HfQuantizer(config)
+ quantizer.pre_quantized = False
+ with self.assertRaises(ValueError):
+ quantizer.validate_environment()
+
+ def test_quantizer_validation_missing_triton_pre_quantized_no_dequantize(self):
+ """Test quantizer validation when triton is not available but model is pre-quantized and dequantize is False"""
+ with (
+ patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False),
+ patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False),
+ ):
+ from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+
+ config = Mxfp4Config()
+ quantizer = Mxfp4HfQuantizer(config)
+ quantizer.pre_quantized = True
+
+ # Should automatically set dequantize=True and warn
+ quantizer.validate_environment()
+ self.assertTrue(quantizer.quantization_config.dequantize)
+
+ def test_update_torch_dtype(self):
+ """Test torch dtype updating"""
+ from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+
+ config = Mxfp4Config()
+ quantizer = Mxfp4HfQuantizer(config)
+
+ # Should default to bfloat16
+ result_dtype = quantizer.update_torch_dtype(None)
+ self.assertEqual(result_dtype, torch.bfloat16)
+
+ # Should preserve existing dtype
+ result_dtype = quantizer.update_torch_dtype(torch.float32)
+ self.assertEqual(result_dtype, torch.float32)
+
+ def test_update_expected_keys(self):
+ """Test expected keys updating for quantized models"""
+ from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+
+ config = Mxfp4Config()
+ quantizer = Mxfp4HfQuantizer(config)
+
+ expected_keys = [
+ "model.layers.0.mlp.experts.gate_up_proj",
+ "model.layers.0.mlp.experts.down_proj",
+ "model.embed_tokens.weight",
+ ]
+
+ updated_keys = quantizer.update_expected_keys(None, expected_keys, [])
+
+ expected_updated = [
+ "model.layers.0.mlp.experts.gate_up_proj_blocks",
+ "model.layers.0.mlp.experts.gate_up_proj_scales",
+ "model.layers.0.mlp.experts.down_proj_blocks",
+ "model.layers.0.mlp.experts.down_proj_scales",
+ "model.embed_tokens.weight",
+ ]
+
+ self.assertEqual(set(updated_keys), set(expected_updated))
+
+ def test_update_param_name_dequantize(self):
+ """Test parameter name updating when dequantizing"""
+ from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+
+ config = Mxfp4Config(dequantize=True)
+ quantizer = Mxfp4HfQuantizer(config)
+
+ # Should remove _blocks suffix
+ param_name = "model.layers.0.mlp.experts.gate_up_proj_blocks"
+ updated_name = quantizer.update_param_name(param_name)
+ self.assertEqual(updated_name, "model.layers.0.mlp.experts.gate_up_proj")
+
+ # Should remove _scales suffix
+ param_name = "model.layers.0.mlp.experts.down_proj_scales"
+ updated_name = quantizer.update_param_name(param_name)
+ self.assertEqual(updated_name, "model.layers.0.mlp.experts.down_proj")
+
+ # Should not change other names
+ param_name = "model.embed_tokens.weight"
+ updated_name = quantizer.update_param_name(param_name)
+ self.assertEqual(updated_name, "model.embed_tokens.weight")
+
+ def test_update_param_name_no_dequantize(self):
+ """Test parameter name updating when not dequantizing"""
+ from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+
+ config = Mxfp4Config(dequantize=False)
+ quantizer = Mxfp4HfQuantizer(config)
+
+ param_name = "model.layers.0.mlp.experts.gate_up_proj_blocks"
+ updated_name = quantizer.update_param_name(param_name)
+ self.assertEqual(updated_name, param_name)
+
+ def test_is_serializable(self):
+ """Test serialization capability"""
+ from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+
+ config = Mxfp4Config()
+ quantizer = Mxfp4HfQuantizer(config)
+
+ # MXFP4 is not serializable with safetensors
+ self.assertFalse(quantizer.is_serializable())
+
+ def test_is_trainable(self):
+ """Test trainability"""
+ from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+
+ config = Mxfp4Config()
+ quantizer = Mxfp4HfQuantizer(config)
+
+ # MXFP4 is not trainable
+ self.assertFalse(quantizer.is_trainable)
+
+
+class Mxfp4IntegrationTest(unittest.TestCase):
+ """Test mxfp4 integration functions"""
+
+ def test_should_convert_module(self):
+ """Test module conversion decision logic"""
+ from transformers.integrations.mxfp4 import should_convert_module
+
+ # Should convert by default
+ self.assertTrue(should_convert_module(["model", "layers", "0", "mlp"], []))
+
+ # Should not convert if in exclusion list
+ patterns = ["model.layers.*.self_attn", "lm_head"]
+ self.assertFalse(should_convert_module(["model", "layers", "0", "self_attn"], patterns))
+ self.assertFalse(should_convert_module(["lm_head"], patterns))
+
+ # Should convert if not in exclusion list
+ self.assertTrue(should_convert_module(["model", "layers", "0", "mlp", "experts"], patterns))
+
+ @require_torch
+ def test_convert_moe_packed_tensors(self):
+ """Test unpacking of quantized tensors"""
+ from transformers.integrations.mxfp4 import convert_moe_packed_tensors
+
+ # Create dummy packed tensors
+ blocks = torch.randint(0, 255, (2, 4, 8), dtype=torch.uint8)
+ scales = torch.randint(100, 150, (2, 4), dtype=torch.uint8)
+
+ result = convert_moe_packed_tensors(blocks, scales, dtype=torch.bfloat16)
+
+ # Check output shape - should be [2, 4, 16] (8 * 2 for unpacking)
+ self.assertEqual(result.shape, (2, 4 * 16))
+ self.assertEqual(result.dtype, torch.bfloat16)
+
+ @require_triton(min_version="3.4.0")
+ @require_triton_kernels
+ @require_torch_gpu
+ @require_torch
+ def test_quantize_to_mxfp4(self):
+ """Test quantization function"""
+ from transformers.integrations.mxfp4 import quantize_to_mxfp4
+
+ # Create dummy weight tensor
+ w = torch.randn(32, 64, 128, dtype=torch.bfloat16, device=torch.device("cuda"))
+
+ quantized_w, flex_data, mx_ctx = quantize_to_mxfp4(w, None, None)
+
+ # Check that shapes are reasonable
+ self.assertEqual(quantized_w.dtype, torch.uint8)
+ self.assertIsNotNone(flex_data)
+ self.assertIsNotNone(mx_ctx)
+
+
+@require_torch
+@require_torch_large_gpu
+@slow
+class Mxfp4ModelTest(unittest.TestCase):
+ """Test mxfp4 with actual models (requires specific model and hardware)"""
+
+ # These should be paths to real OpenAI MoE models for proper testing
+ model_name_packed = "/fsx/mohamed/oai-hf/tests/20b_converted_packed" # TODO: Use real packed quantized model
+
+ input_text = "Once upon a time"
+
+ # Expected outputs for generation tests
+ EXPECTED_OUTPUTS = set()
+ EXPECTED_OUTPUTS.add("Once upon a time, in a small village, there lived a young")
+
+ def setUp(self):
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ def check_inference_correctness_quantized(self, model, tokenizer):
+ # Check that inference pass works on the model
+ encoded_input = tokenizer(self.input_text, return_tensors="pt").to(model.device)
+
+ # Set pad token if not set
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ with torch.no_grad():
+ output_sequences = model.generate(
+ **encoded_input,
+ max_new_tokens=10,
+ do_sample=False,
+ pad_token_id=tokenizer.eos_token_id,
+ use_cache=False,
+ )
+
+ generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
+
+ self.assertIn(generated_text, self.EXPECTED_OUTPUTS)
+
+ def test_gpt_oss_model_loading_quantized_with_device_map(self):
+ """Test loading OpenAI MoE model with mxfp4 quantization and device_map"""
+
+ quantization_config = Mxfp4Config(dequantize=False)
+
+ # Test that config is properly set up
+ self.assertFalse(quantization_config.dequantize)
+
+ model = GptOssForCausalLM.from_pretrained(
+ self.model_name_packed,
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ )
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed)
+ self.check_inference_correctness_quantized(model, tokenizer)
+
+ def test_gpt_oss_model_loading_dequantized_with_device_map(self):
+ """Test loading OpenAI MoE model with mxfp4 dequantization and device_map"""
+
+ quantization_config = Mxfp4Config(dequantize=True)
+
+ # Test that config is properly set up
+ self.assertTrue(quantization_config.dequantize)
+
+ model = GptOssForCausalLM.from_pretrained(
+ self.model_name_packed,
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ )
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed)
+ self.check_inference_correctness_quantized(model, tokenizer)
+
+ def test_model_device_map_validation(self):
+ """Test device map validation"""
+ from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+
+ config = Mxfp4Config()
+ quantizer = Mxfp4HfQuantizer(config)
+ quantizer.pre_quantized = False
+
+ # Test with CPU in device map (should raise error for non-pre-quantized)
+ with self.assertRaises(ValueError):
+ quantizer.validate_environment(device_map={"": "cpu"})
+
+ def test_memory_footprint_comparison(self):
+ """Test memory footprint differences between quantized and unquantized models"""
+
+ # Expected: quantized < dequantized < unquantized memory usage
+ quantization_config = Mxfp4Config(dequantize=True)
+ quantized_model = GptOssForCausalLM.from_pretrained(
+ self.model_name_packed,
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ )
+ dequantized_model = GptOssForCausalLM.from_pretrained(
+ self.model_name_packed,
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ quantization_config=quantization_config,
+ )
+ quantized_mem = quantized_model.get_memory_footprint()
+ dequantized_mem = dequantized_model.get_memory_footprint()
+ self.assertLess(quantized_mem, dequantized_mem)
diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py
index a9ca4f1d56f2..23ba44958dcb 100644
--- a/utils/check_docstrings.py
+++ b/utils/check_docstrings.py
@@ -79,6 +79,7 @@
# docstrings instead. If formatting should be ignored for the docstring, you can put a comment # no-format on the
# line before the docstring.
OBJECTS_TO_IGNORE = [
+ "Mxfp4Config",
"Exaone4Config",
"SmolLM3Config",
"Gemma3nVisionConfig",