24 changes: 1 addition & 23 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -40,14 +40,12 @@
from transformers import PretrainedConfig

from tensorrt_llm._ipc_utils import can_access_peer
from tensorrt_llm._utils import get_sm_version, is_sm_100f
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.llmapi.utils import enable_llm_debug
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo
from tensorrt_llm.quantization.utils.fp8_utils import (
resmooth_to_fp8_e8m0, transform_sf_into_required_layout)

from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
@@ -1528,26 +1526,6 @@ def load_weights(self, weights: Dict):
weight_loader.load_weights(weights)

def post_load_weights(self):
all_named_modules = dict(self.model.named_modules())
for name, module in tqdm(all_named_modules.items(),
desc="Post loading weights"):
if len(module._parameters) <= 0 or name.startswith("draft_model"):
continue
else:
if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
) and is_sm_100f() and hasattr(module, "weight_scale"):
weight, weight_scale = resmooth_to_fp8_e8m0(
module.weight, module.weight_scale)
transfromed_scale = transform_sf_into_required_layout(
weight_scale,
mn=weight.shape[0],
k=weight.shape[1],
recipe=(1, 128, 128),
is_sfa=False)
module.weight = nn.Parameter(weight, requires_grad=False)
module.weight_scale = nn.Parameter(transfromed_scale,
requires_grad=False)

for idx, layer in enumerate(
self.model.layers[:self.config.num_hidden_layers]):
if idx == self.config.num_hidden_layers - 1:
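Note: the loop deleted here (and its copy in modeling_qwen3.py below) is the SM100 e8m0 resmoothing of FP8 block-scaled weights; with this change the same transform is expected to run once inside the linear method's post_load_weights (see the tensorrt_llm/_torch/modules/linear.py hunk further down). For reference, a minimal standalone sketch of that transform; the helper names, keyword arguments, and the (1, 128, 128) recipe are taken from this diff, while the tensor shapes, the CUDA device, and the 128x128 block-scale granularity are illustrative assumptions.

import torch
from torch import nn

from tensorrt_llm.quantization.utils.fp8_utils import (
    resmooth_to_fp8_e8m0, transform_sf_into_required_layout)

# Illustrative FP8 block-scaled weight: [N, K] with one scale per 128x128 block.
N, K = 512, 1024
weight = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
weight_scale = torch.rand(N // 128, K // 128, device="cuda")

# Re-smooth the scales to e8m0 and pack them into the layout the kernels expect.
weight, weight_scale = resmooth_to_fp8_e8m0(weight, weight_scale)
weight_scale = transform_sf_into_required_layout(
    weight_scale,
    mn=weight.shape[0],
    k=weight.shape[1],
    recipe=(1, 128, 128),
    is_sfa=False)

# Store the results back as frozen parameters, as the deleted loop did.
module = nn.Linear(K, N, bias=False)
module.weight = nn.Parameter(weight, requires_grad=False)
module.weight_scale = nn.Parameter(weight_scale, requires_grad=False)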
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/models/modeling_gpt_oss.py
@@ -602,6 +602,7 @@ def load_weights(self, weights: Dict):
else:
self.load_hf_weights(weights)

def post_load_weights(self):
for idx, layer in enumerate(
self.model.block[:self.config.num_hidden_layers]):
if idx == 0:
5 changes: 2 additions & 3 deletions tensorrt_llm/_torch/models/modeling_llama.py
@@ -979,9 +979,7 @@
):
super().__init__(LlamaModel(model_config), model_config)

def load_weights(self, weights: Dict):
super().load_weights(weights)

def post_load_weights(self):
for idx, layer in enumerate(
self.model.layers[:self.config.num_hidden_layers]):
if idx == self.config.num_hidden_layers - 1:
@@ -1307,6 +1305,7 @@ def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
if had_mm_encoder:
self.mm_encoder = saved_mm_encoder

def post_load_weights(self):
for idx, layer in enumerate(
self.model.layers[:self.config.num_hidden_layers]):
if idx == self.config.num_hidden_layers - 1:
25 changes: 0 additions & 25 deletions tensorrt_llm/_torch/models/modeling_qwen3.py
@@ -2,13 +2,9 @@

import torch
from torch import nn
from tqdm import tqdm
from transformers import Qwen3Config

from tensorrt_llm._utils import is_sm_100f
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.quantization.utils.fp8_utils import (
resmooth_to_fp8_e8m0, transform_sf_into_required_layout)

from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
@@ -217,24 +213,3 @@ def __init__(
Qwen3Model(model_config),
model_config,
)

def post_load_weights(self):
all_named_modules = dict(self.model.named_modules())
for name, module in tqdm(all_named_modules.items(),
desc="Post loading weights"):
if len(module._parameters) <= 0 or name.startswith("draft_model"):
continue
else:
if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
) and is_sm_100f() and hasattr(module, "weight_scale"):
weight, weight_scale = resmooth_to_fp8_e8m0(
module.weight, module.weight_scale)
transfromed_scale = transform_sf_into_required_layout(
weight_scale,
mn=weight.shape[0],
k=weight.shape[1],
recipe=(1, 128, 128),
is_sfa=False)
module.weight = nn.Parameter(weight, requires_grad=False)
module.weight_scale = nn.Parameter(transfromed_scale,
requires_grad=False)
6 changes: 1 addition & 5 deletions tensorrt_llm/_torch/models/modeling_qwen3_moe.py
@@ -6,8 +6,6 @@
from transformers import Qwen3MoeConfig

from tensorrt_llm._ipc_utils import can_access_peer
from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
BaseWeightMapper

from ..attention_backend import AttentionMetadata
from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
@@ -390,9 +388,7 @@
)
self.preload_weight_modules = self.model.preload_weight_modules

def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper):
super().load_weights(weights, weight_mapper)

def post_load_weights(self):
for idx, layer in enumerate(
self.model.layers[:self.config.num_hidden_layers]):
if idx == self.config.num_hidden_layers - 1:
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/models/modeling_utils.py
@@ -143,6 +143,7 @@ def remove_weights(
for mod in iter_modules(module, ignore_modules):
mod._parameters.clear()
mod._buffers.clear()
mod._weights_removed = True


def skip_forward(
10 changes: 5 additions & 5 deletions tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -6,8 +6,8 @@
import torch.nn.functional as F
from torch import nn

import tensorrt_llm.logger as trtllm_logger
from tensorrt_llm._utils import get_sm_version, is_sm_100f
from tensorrt_llm.logger import logger
from tensorrt_llm.quantization.functional import \
preprocess_weights_for_mixed_gemm
from tensorrt_llm.quantization.utils.fp4_utils import (
@@ -271,8 +271,6 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict],
module.w2_bias.data if module.bias else None)

self.load_quant_scales(module, weights)
# Re-setup quant scales after loading weights as the tensors may have been modified.
self.setup_quant_scales(module)

if self.need_load_shared_weights(module):
local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids(
@@ -323,7 +321,8 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict],
module.initial_global_assignments)

def post_load_weights(self, module: torch.nn.Module):
pass
# Re-setup quant scales after loading weights as the tensors may have been modified.
self.setup_quant_scales(module)

def load_quant_scales(self, module: torch.nn.Module, weights: List[Dict]):
pass
@@ -722,14 +721,15 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict],
if int(name.split(".")[0]) not in expert_ids:
continue
weight_name = name.replace("weight_scale_inv", "weight")
trtllm_logger.logger.debug(f"Resmoothing {weight_name}")
logger.debug(f"Resmoothing {weight_name}")
weight = weights[weight_name][:]
scale = weights[name][:]
weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
weight, scale)
super().load_weights(module, weights, weight_loading_mode)

def post_load_weights(self, module: torch.nn.Module):
super().post_load_weights(module)
if is_sm_100f():
transfromed_w3_w1_scale = transform_sf_into_required_layout(
module.quant_scales[0],
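Note: moving setup_quant_scales out of load_weights and into the base post_load_weights means the cached quant_scales tuple is rebuilt only after every weight tensor is final (including shared and load-balanced experts), and the SM100 override above can then re-layout those scales. A rough sketch of the ordering this relies on; only load_weights and post_load_weights come from this diff, the driver function and its signature are assumptions.

def finalize_moe_weights(module, weights, weight_loading_mode, quant_method):
    # 1. Copy raw expert weights and per-expert scales into the module tensors.
    quant_method.load_weights(module, weights, weight_loading_mode)
    # 2. Rebuild the cached quant_scales once the tensors are final; subclasses
    #    (e.g. the SM100 path above) may additionally transform the scale layout.
    quant_method.post_load_weights(module)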
20 changes: 20 additions & 0 deletions tensorrt_llm/_torch/modules/linear.py
@@ -20,6 +20,8 @@
from tensorrt_llm.quantization.functional import \
preprocess_weights_for_mixed_gemm
from tensorrt_llm.quantization.mode import QuantAlgo
from tensorrt_llm.quantization.utils.fp8_utils import (
resmooth_to_fp8_e8m0, transform_sf_into_required_layout)

from ..._utils import is_sm_100f
from ...models.modeling_utils import QuantConfig
@@ -715,6 +717,24 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
copy_weight(module.weight, fused_weight)
copy_weight(module.weight_scale, fused_scale)

def post_load_weights(self, module: Linear):
super().post_load_weights(module)
if is_sm_100f() and not (module.use_cute_dsl_blockscaling_mm
or module.disable_deep_gemm):
weight, weight_scale = resmooth_to_fp8_e8m0(module.weight,
module.weight_scale)
transfromed_scale = transform_sf_into_required_layout(
weight_scale,
mn=weight.shape[0],
k=weight.shape[1],
recipe=(1, 128, 128),
is_sfa=False)
module.weight = nn.Parameter(weight, requires_grad=False)
module.weight_scale = nn.Parameter(
transfromed_scale,
requires_grad=False,
)


class NVFP4LinearMethod(LinearMethodBase):

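As a shape sanity check for the new hook, assuming the usual 128x128 weight-scale blocks implied by recipe=(1, 128, 128); the example dimensions are made up and the packed output layout is internal to transform_sf_into_required_layout.

import math

N, K = 7168, 2048                 # illustrative [out_features, in_features]
scale_rows = math.ceil(N / 128)   # 56
scale_cols = math.ceil(K / 128)   # 16
# weight_scale is expected to arrive as roughly a [56, 16] tensor;
# transform_sf_into_required_layout then re-packs it for the SM100 kernels.
print(scale_rows, scale_cols)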
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/model_loader.py
@@ -267,7 +267,8 @@ def init_meta_tensor(t: torch.Tensor):
f"No load support for load format: {load_format}")

for module in model.modules():
if hasattr(module, 'post_load_weights'):
if hasattr(module, 'post_load_weights') and not getattr(
module, '_weights_removed', False):
module.post_load_weights()

if isinstance(moe_load_balancer, MoeLoadBalancer):
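Together with the _weights_removed flag added in modeling_utils.py above, this guard keeps the loader from calling post_load_weights on modules whose parameters were deliberately cleared. A minimal self-contained sketch of the interaction; only the flag name and the guard condition come from this diff, the toy module and driver are assumptions for illustration.

import torch
from torch import nn

class ToyLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(4, 4))

    def post_load_weights(self):
        # Would fail or waste work if the weights had already been removed.
        assert self.weight.numel() > 0

model = nn.ModuleDict({"kept": ToyLinear(), "removed": ToyLinear()})

# What remove_weights() does for pruned modules after this change:
model["removed"]._parameters.clear()
model["removed"]._buffers.clear()
model["removed"]._weights_removed = True

# What the loader loop now does: skip modules whose weights were removed.
for module in model.modules():
    if hasattr(module, 'post_load_weights') and not getattr(
            module, '_weights_removed', False):
        module.post_load_weights()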
6 changes: 4 additions & 2 deletions tests/integration/defs/accuracy/accuracy_core.py
@@ -186,7 +186,8 @@ def evaluate(self,
extra_acc_spec: Optional[str] = None,
extra_evaluator_kwargs: Optional[dict] = None,
sampling_params: Optional[SamplingParams] = None,
streaming: bool = False):
streaming: bool = False,
is_integration_test: bool = False):
assert self.EVALUATOR_CLS is not None

if llm.args.speculative_config is None:
@@ -199,7 +200,8 @@
raise ValueError(
f"Not recognized speculative_config: {llm.args.speculative_config}."
)
is_integration_test = os.getenv('INTEGRATION_TEST', '0') == '1'
is_integration_test = is_integration_test or os.getenv(
'INTEGRATION_TEST', '0') == '1'

if is_integration_test:
logger.info(
49 changes: 49 additions & 0 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -87,6 +87,13 @@ def test_chunked_prefill(self, attn_backend):
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip_less_device_memory(32000)
def test_dummy_load_format(self):
llm = LLM(self.MODEL_PATH, load_format="dummy")
with llm:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm, is_integration_test=True)

@pytest.mark.skip_less_device_memory(32000)
@parametrize_with_ids("torch_compile", [False, True])
@parametrize_with_ids("attn_backend", ["TRTLLM", "FLASHINFER"])
@@ -1894,6 +1901,18 @@ def test_guided_decoding_4gpus(self, backend: str, mtp_nextn: int, mocker):
task = JsonModeEval(self.MODEL_NAME)
task.evaluate(llm)

@skip_pre_hopper
def test_dummy_load_format(self):
llm = LLM(
f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
load_format="dummy",
moe_config=MoeConfig(
backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
)
with llm:
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm, is_integration_test=True)


@pytest.mark.timeout(7200)
@pytest.mark.skip_less_device_memory(80000)
@@ -2641,6 +2660,16 @@ def test_fp8_block_scales(self, tp_size, pp_size, ep_size, attention_dp,
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)

@skip_pre_hopper
def test_dummy_load_format(self):
llm = LLM(
f"{llm_models_root()}/Qwen3/Qwen3-8B-FP8",
load_format="dummy",
)
with llm:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm, is_integration_test=True)

@pytest.mark.parametrize(
"tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler,is_cached",
[(1, 1, 1, False, True, True, True),
@@ -2759,6 +2788,16 @@ def test_fp8_block_scales(self, tp_size, pp_size, ep_size, attention_dp,
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@skip_pre_hopper
def test_dummy_load_format(self):
llm = LLM(
f"{llm_models_root()}/Qwen3/Qwen3-30B-A3B-FP8",
load_format="dummy",
)
with llm:
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm, is_integration_test=True)

@skip_pre_hopper
@parametrize_with_ids("torch_compile", [False, True])
@pytest.mark.parametrize(
@@ -3295,6 +3334,16 @@ def test_w4_1gpu(self, kv_cache_dtype, moe_backend, cuda_graph,
task.evaluate(llm,
extra_evaluator_kwargs=self.extra_evaluator_kwargs)

def test_dummy_load_format(self):
llm = LLM(
f"{llm_models_root()}/gpt_oss/gpt-oss-20b",
load_format="dummy",
)
with llm:
model_name = "GPT-OSS/MXFP4"
task = GSM8K(model_name)
task.evaluate(llm, is_integration_test=True)

@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize(
"kv_cache_dtype",
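Note on the new test_dummy_load_format cases: load_format="dummy" is read here as filling the model with dummy weights instead of loading the real checkpoint, so accuracy numbers are not meaningful (hence is_integration_test=True), but the load and post_load_weights plumbing touched by this change still executes end to end. A minimal sketch of that usage, with the model path as a placeholder.

from tensorrt_llm import LLM

llm = LLM("/path/to/any/model", load_format="dummy")  # placeholder path
with llm:
    outputs = llm.generate(["Hello, world"])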
2 changes: 2 additions & 0 deletions tests/integration/test_lists/test-db/l0_b200.yml
@@ -34,10 +34,12 @@ l0_b200:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=nvfp4-kv_cache_reuse=True-fp8kv=False-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=nvfp4-kv_cache_reuse=True-fp8kv=True-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_dummy_load_format
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass-auto]
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm-auto]
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm-fp8]
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-triton-auto]
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_dummy_load_format
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] # Cover nvbugs 5461712 and 5505402
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[mxfp8-latency]
4 changes: 4 additions & 0 deletions tests/integration/test_lists/test-db/l0_h100.yml
@@ -42,6 +42,7 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=TRTLLM] TIMEOUT (90)
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_dummy_load_format
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar-eagle3_one_model=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar-eagle3_one_model=False]
@@ -59,9 +60,12 @@
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=none-kv_cache_reuse=True-fp8kv=False-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=none-kv_cache_reuse=False-fp8kv=False-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=fp8-kv_cache_reuse=True-fp8kv=True-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_dummy_load_format
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_dummy_load_format
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_dummy_load_format
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=False-eagle3_one_model=False]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=True-eagle3_one_model=True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=False-eagle3_one_model=True]
8 changes: 6 additions & 2 deletions tests/unittest/_torch/modeling/test_modeling_gpt_oss.py
@@ -3,6 +3,8 @@
import shutil

import pytest
from transformers import AutoTokenizer
from utils.llm_data import llm_models_root

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
@@ -51,8 +53,6 @@ def test_gpt_oss_trtllmgen(moe_backend):
if moe_backend == "TRITON" and not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")

pytest.skip("https://nvbugspro.nvidia.com/bug/5441721")

prompts = [
"How are you?",
"Hello, my name is",
@@ -73,7 +73,11 @@

dump_config_json(tmp_model_dir)

tokenizer = AutoTokenizer.from_pretrained(
f"{llm_models_root()}/gpt_oss/gpt-oss-20b")

llm = LLM(model=tmp_model_dir,
tokenizer=tokenizer,
tensor_parallel_size=1,
enable_chunked_prefill=False,
**pytorch_config,