From b0d8ea1050fa3acfe2fa544c5a885f25da3274cb Mon Sep 17 00:00:00 2001
From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Date: Mon, 29 Sep 2025 09:25:11 +0000
Subject: [PATCH 1/2] Fix dummy load format for key models.

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
---
 .../_torch/models/modeling_deepseekv3.py      | 24 +--------
 .../_torch/models/modeling_gpt_oss.py         |  1 +
 tensorrt_llm/_torch/models/modeling_llama.py  |  5 +-
 .../_torch/models/modeling_qwen3_moe.py       |  6 +--
 tensorrt_llm/_torch/models/modeling_utils.py  |  1 +
 .../_torch/modules/fused_moe/quantization.py  | 10 ++--
 tensorrt_llm/_torch/modules/linear.py         | 19 +++++++
 .../_torch/pyexecutor/model_loader.py         |  3 +-
 .../defs/accuracy/accuracy_core.py            |  6 ++-
 .../defs/accuracy/test_llm_api_pytorch.py     | 49 +++++++++++++++++++
 .../test_lists/test-db/l0_b200.yml            |  2 +
 .../test_lists/test-db/l0_h100.yml            |  4 ++
 .../_torch/modeling/test_modeling_gpt_oss.py  |  8 ++-
 .../_torch/modeling/test_modeling_llama.py    |  1 +
 14 files changed, 98 insertions(+), 41 deletions(-)

diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index c49064bfe38..2c66850d418 100755
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -40,14 +40,12 @@
 from transformers import PretrainedConfig

 from tensorrt_llm._ipc_utils import can_access_peer
-from tensorrt_llm._utils import get_sm_version, is_sm_100f
+from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.functional import PositionEmbeddingType
 from tensorrt_llm.llmapi.utils import enable_llm_debug
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig
 from tensorrt_llm.quantization.mode import QuantAlgo
-from tensorrt_llm.quantization.utils.fp8_utils import (
-    resmooth_to_fp8_e8m0, transform_sf_into_required_layout)

 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
@@ -1528,26 +1526,6 @@ def load_weights(self, weights: Dict):
         weight_loader.load_weights(weights)

     def post_load_weights(self):
-        all_named_modules = dict(self.model.named_modules())
-        for name, module in tqdm(all_named_modules.items(),
-                                 desc="Post loading weights"):
-            if len(module._parameters) <= 0 or name.startswith("draft_model"):
-                continue
-            else:
-                if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
-                ) and is_sm_100f() and hasattr(module, "weight_scale"):
-                    weight, weight_scale = resmooth_to_fp8_e8m0(
-                        module.weight, module.weight_scale)
-                    transfromed_scale = transform_sf_into_required_layout(
-                        weight_scale,
-                        mn=weight.shape[0],
-                        k=weight.shape[1],
-                        recipe=(1, 128, 128),
-                        is_sfa=False)
-                    module.weight = nn.Parameter(weight, requires_grad=False)
-                    module.weight_scale = nn.Parameter(transfromed_scale,
-                                                       requires_grad=False)
-
         for idx, layer in enumerate(
                 self.model.layers[:self.config.num_hidden_layers]):
             if idx == self.config.num_hidden_layers - 1:
diff --git a/tensorrt_llm/_torch/models/modeling_gpt_oss.py b/tensorrt_llm/_torch/models/modeling_gpt_oss.py
index 27b621ad27c..3079c15d53d 100644
--- a/tensorrt_llm/_torch/models/modeling_gpt_oss.py
+++ b/tensorrt_llm/_torch/models/modeling_gpt_oss.py
@@ -602,6 +602,7 @@ def load_weights(self, weights: Dict):
         else:
             self.load_hf_weights(weights)

+    def post_load_weights(self):
         for idx, layer in enumerate(
                 self.model.block[:self.config.num_hidden_layers]):
             if idx == 0:
diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py
index 5dc193074b6..88c327999e8 100644
--- a/tensorrt_llm/_torch/models/modeling_llama.py
+++ b/tensorrt_llm/_torch/models/modeling_llama.py
@@ -979,9 +979,7 @@ def __init__(
     ):
         super().__init__(LlamaModel(model_config), model_config)

-    def load_weights(self, weights: Dict):
-        super().load_weights(weights)
-
+    def post_load_weights(self):
         for idx, layer in enumerate(
                 self.model.layers[:self.config.num_hidden_layers]):
             if idx == self.config.num_hidden_layers - 1:
@@ -1307,6 +1305,7 @@ def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
         if had_mm_encoder:
             self.mm_encoder = saved_mm_encoder

+    def post_load_weights(self):
         for idx, layer in enumerate(
                 self.model.layers[:self.config.num_hidden_layers]):
             if idx == self.config.num_hidden_layers - 1:
diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py
index 4619def65f2..80d1478833c 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py
@@ -6,8 +6,6 @@
 from transformers import Qwen3MoeConfig

 from tensorrt_llm._ipc_utils import can_access_peer
-from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
-    BaseWeightMapper

 from ..attention_backend import AttentionMetadata
 from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
@@ -390,9 +388,7 @@ def __init__(
         )
         self.preload_weight_modules = self.model.preload_weight_modules

-    def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper):
-        super().load_weights(weights, weight_mapper)
-
+    def post_load_weights(self):
         for idx, layer in enumerate(
                 self.model.layers[:self.config.num_hidden_layers]):
             if idx == self.config.num_hidden_layers - 1:
diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py
index c30ffcb89c1..bdc8d819611 100755
--- a/tensorrt_llm/_torch/models/modeling_utils.py
+++ b/tensorrt_llm/_torch/models/modeling_utils.py
@@ -143,6 +143,7 @@ def remove_weights(
     for mod in iter_modules(module, ignore_modules):
         mod._parameters.clear()
         mod._buffers.clear()
+        mod._weights_removed = True


 def skip_forward(
diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
index 689edfae35f..7ccab06549d 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -6,8 +6,8 @@
 import torch.nn.functional as F
 from torch import nn

-import tensorrt_llm.logger as trtllm_logger
 from tensorrt_llm._utils import get_sm_version, is_sm_100f
+from tensorrt_llm.logger import logger
 from tensorrt_llm.quantization.functional import \
     preprocess_weights_for_mixed_gemm
 from tensorrt_llm.quantization.utils.fp4_utils import (
@@ -271,8 +271,6 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                               module.w2_bias.data if module.bias else None)

         self.load_quant_scales(module, weights)
-        # Re-setup quant scales after loading weights as the tensors may have been modified.
-        self.setup_quant_scales(module)

         if self.need_load_shared_weights(module):
             local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids(
@@ -323,7 +321,8 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                 module.initial_global_assignments)

     def post_load_weights(self, module: torch.nn.Module):
-        pass
+        # Re-setup quant scales after loading weights as the tensors may have been modified.
+        self.setup_quant_scales(module)

     def load_quant_scales(self, module: torch.nn.Module, weights: List[Dict]):
         pass
@@ -722,7 +721,7 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                 if int(name.split(".")[0]) not in expert_ids:
                     continue
                 weight_name = name.replace("weight_scale_inv", "weight")
-                trtllm_logger.logger.debug(f"Resmoothing {weight_name}")
+                logger.debug(f"Resmoothing {weight_name}")
                 weight = weights[weight_name][:]
                 scale = weights[name][:]
                 weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
@@ -730,6 +729,7 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict],
         super().load_weights(module, weights, weight_loading_mode)

     def post_load_weights(self, module: torch.nn.Module):
+        super().post_load_weights(module)
         if is_sm_100f():
             transfromed_w3_w1_scale = transform_sf_into_required_layout(
                 module.quant_scales[0],
diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py
index 0a5d4b5886b..9891436ceb4 100644
--- a/tensorrt_llm/_torch/modules/linear.py
+++ b/tensorrt_llm/_torch/modules/linear.py
@@ -20,6 +20,8 @@
 from tensorrt_llm.quantization.functional import \
     preprocess_weights_for_mixed_gemm
 from tensorrt_llm.quantization.mode import QuantAlgo
+from tensorrt_llm.quantization.utils.fp8_utils import (
+    resmooth_to_fp8_e8m0, transform_sf_into_required_layout)

 from ..._utils import is_sm_100f
 from ...models.modeling_utils import QuantConfig
@@ -715,6 +717,23 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_scale)

+    def post_load_weights(self, module: Linear):
+        super().post_load_weights(module)
+        if is_sm_100f():
+            weight, weight_scale = resmooth_to_fp8_e8m0(module.weight,
+                                                        module.weight_scale)
+            transfromed_scale = transform_sf_into_required_layout(
+                weight_scale,
+                mn=weight.shape[0],
+                k=weight.shape[1],
+                recipe=(1, 128, 128),
+                is_sfa=False)
+            module.weight = nn.Parameter(weight, requires_grad=False)
+            module.weight_scale = nn.Parameter(
+                transfromed_scale,
+                requires_grad=False,
+            )
+

 class NVFP4LinearMethod(LinearMethodBase):

diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py
index e3d9cfc5410..2909b29cac4 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_loader.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py
@@ -267,7 +267,8 @@ def init_meta_tensor(t: torch.Tensor):
                 f"No load support for load format: {load_format}")

         for module in model.modules():
-            if hasattr(module, 'post_load_weights'):
+            if hasattr(module, 'post_load_weights') and not getattr(
+                    module, '_weights_removed', False):
                 module.post_load_weights()

         if isinstance(moe_load_balancer, MoeLoadBalancer):
diff --git a/tests/integration/defs/accuracy/accuracy_core.py b/tests/integration/defs/accuracy/accuracy_core.py
index 546fa274e9f..9907692a9ef 100644
--- a/tests/integration/defs/accuracy/accuracy_core.py
+++ b/tests/integration/defs/accuracy/accuracy_core.py
@@ -186,7 +186,8 @@ def evaluate(self,
                  extra_acc_spec: Optional[str] = None,
                  extra_evaluator_kwargs: Optional[dict] = None,
                  sampling_params: Optional[SamplingParams] = None,
-                 streaming: bool = False):
+                 streaming: bool = False,
+                 is_integration_test: bool = False):
         assert self.EVALUATOR_CLS is not None

         if llm.args.speculative_config is None:
@@ -199,7 +200,8 @@ def evaluate(self,
             raise ValueError(
                 f"Not recognized speculative_config: {llm.args.speculative_config}."
             )
-        is_integration_test = os.getenv('INTEGRATION_TEST', '0') == '1'
+        is_integration_test = is_integration_test or os.getenv(
+            'INTEGRATION_TEST', '0') == '1'

         if is_integration_test:
             logger.info(
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 8c609c7ca89..baac4a141bf 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -87,6 +87,13 @@ def test_chunked_prefill(self, attn_backend):
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)

+    @pytest.mark.skip_less_device_memory(32000)
+    def test_dummy_load_format(self):
+        llm = LLM(self.MODEL_PATH, load_format="dummy")
+        with llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm, is_integration_test=True)
+
     @pytest.mark.skip_less_device_memory(32000)
     @parametrize_with_ids("torch_compile", [False, True])
     @parametrize_with_ids("attn_backend", ["TRTLLM", "FLASHINFER"])
@@ -1894,6 +1901,18 @@ def test_guided_decoding_4gpus(self, backend: str, mtp_nextn: int, mocker):
             task = JsonModeEval(self.MODEL_NAME)
             task.evaluate(llm)

+    @skip_pre_hopper
+    def test_dummy_load_format(self):
+        llm = LLM(
+            f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
+            load_format="dummy",
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
+        )
+        with llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm, is_integration_test=True)
+

 @pytest.mark.timeout(7200)
 @pytest.mark.skip_less_device_memory(80000)
@@ -2641,6 +2660,16 @@ def test_fp8_block_scales(self, tp_size, pp_size, ep_size, attention_dp,
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)

+    @skip_pre_hopper
+    def test_dummy_load_format(self):
+        llm = LLM(
+            f"{llm_models_root()}/Qwen3/Qwen3-8B-FP8",
+            load_format="dummy",
+        )
+        with llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm, is_integration_test=True)
+
     @pytest.mark.parametrize(
         "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler,is_cached",
         [(1, 1, 1, False, True, True, True),
@@ -2759,6 +2788,16 @@ def test_fp8_block_scales(self, tp_size, pp_size, ep_size, attention_dp,
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)

+    @skip_pre_hopper
+    def test_dummy_load_format(self):
+        llm = LLM(
+            f"{llm_models_root()}/Qwen3/Qwen3-30B-A3B-FP8",
+            load_format="dummy",
+        )
+        with llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm, is_integration_test=True)
+
     @skip_pre_hopper
     @parametrize_with_ids("torch_compile", [False, True])
     @pytest.mark.parametrize(
@@ -3295,6 +3334,16 @@ def test_w4_1gpu(self, kv_cache_dtype, moe_backend, cuda_graph,
             task.evaluate(llm,
                           extra_evaluator_kwargs=self.extra_evaluator_kwargs)

+    def test_dummy_load_format(self):
+        llm = LLM(
+            f"{llm_models_root()}/gpt_oss/gpt-oss-20b",
+            load_format="dummy",
+        )
+        with llm:
+            model_name = "GPT-OSS/MXFP4"
+            task = GSM8K(model_name)
+            task.evaluate(llm, is_integration_test=True)
+
     @pytest.mark.skip_less_device(4)
     @pytest.mark.parametrize(
         "kv_cache_dtype",
diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml
index b435b61e423..899166666de 100644
--- a/tests/integration/test_lists/test-db/l0_b200.yml
+++ b/tests/integration/test_lists/test-db/l0_b200.yml
@@ -34,10 +34,12 @@ l0_b200:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=nvfp4-kv_cache_reuse=True-fp8kv=False-overlap_scheduler=True]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=nvfp4-kv_cache_reuse=True-fp8kv=True-overlap_scheduler=True]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_dummy_load_format
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass-auto]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm-auto]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm-fp8]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-triton-auto]
+  - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_dummy_load_format
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] # Cover nvbugs 5461712 and 5505402
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[mxfp8-latency]
diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml
index 586161fa15b..d9c738550dd 100644
--- a/tests/integration/test_lists/test-db/l0_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_h100.yml
@@ -42,6 +42,7 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=TRTLLM] TIMEOUT (90)
+  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_dummy_load_format
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar-eagle3_one_model=True]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar-eagle3_one_model=False]
@@ -59,9 +60,12 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=none-kv_cache_reuse=True-fp8kv=False-overlap_scheduler=True]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=none-kv_cache_reuse=False-fp8kv=False-overlap_scheduler=True]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=fp8-kv_cache_reuse=True-fp8kv=True-overlap_scheduler=True]
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_dummy_load_format
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_dummy_load_format
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=True]
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_dummy_load_format
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=False-eagle3_one_model=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=True-eagle3_one_model=True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=False-eagle3_one_model=True]
diff --git a/tests/unittest/_torch/modeling/test_modeling_gpt_oss.py b/tests/unittest/_torch/modeling/test_modeling_gpt_oss.py
index 43de8f81b5b..0e2b9745931 100644
--- a/tests/unittest/_torch/modeling/test_modeling_gpt_oss.py
+++ b/tests/unittest/_torch/modeling/test_modeling_gpt_oss.py
@@ -3,6 +3,8 @@
 import shutil

 import pytest
+from transformers import AutoTokenizer
+from utils.llm_data import llm_models_root

 from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
@@ -51,8 +53,6 @@ def test_gpt_oss_trtllmgen(moe_backend):
     if moe_backend == "TRITON" and not IS_TRITON_KERNELS_AVAILABLE:
         pytest.skip("Triton kernels are not available")

-    pytest.skip("https://nvbugspro.nvidia.com/bug/5441721")
-
     prompts = [
         "How are you?",
         "Hello, my name is",
@@ -73,7 +73,11 @@ def test_gpt_oss_trtllmgen(moe_backend):

     dump_config_json(tmp_model_dir)

+    tokenizer = AutoTokenizer.from_pretrained(
+        f"{llm_models_root()}/gpt_oss/gpt-oss-20b")
+
     llm = LLM(model=tmp_model_dir,
+              tokenizer=tokenizer,
               tensor_parallel_size=1,
               enable_chunked_prefill=False,
               **pytorch_config,
diff --git a/tests/unittest/_torch/modeling/test_modeling_llama.py b/tests/unittest/_torch/modeling/test_modeling_llama.py
index 8de665741d8..b2a2bd0a181 100644
--- a/tests/unittest/_torch/modeling/test_modeling_llama.py
+++ b/tests/unittest/_torch/modeling/test_modeling_llama.py
@@ -226,6 +226,7 @@ def test_llama_allclose_to_hf(self, scenario: Scenario) -> None:

         llama = LlamaForCausalLM(model_config).to(dtype).to(device)
         llama.load_weights(hf_llama.state_dict())
+        llama.post_load_weights()

         num_blocks = 1
         tokens_per_block = 128

From 34a55e1366c642eb256cac5fc168bede7f364e37 Mon Sep 17 00:00:00 2001
From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Date: Tue, 30 Sep 2025 08:16:14 +0000
Subject: [PATCH 2/2] Fix rebase error.

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
---
 tensorrt_llm/_torch/models/modeling_qwen3.py | 25 --------------------
 tensorrt_llm/_torch/modules/linear.py        |  3 ++-
 2 files changed, 2 insertions(+), 26 deletions(-)

diff --git a/tensorrt_llm/_torch/models/modeling_qwen3.py b/tensorrt_llm/_torch/models/modeling_qwen3.py
index d93423a45c4..ce860ecd9b3 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen3.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen3.py
@@ -2,13 +2,9 @@

 import torch
 from torch import nn
-from tqdm import tqdm
 from transformers import Qwen3Config

-from tensorrt_llm._utils import is_sm_100f
 from tensorrt_llm.functional import PositionEmbeddingType
-from tensorrt_llm.quantization.utils.fp8_utils import (
-    resmooth_to_fp8_e8m0, transform_sf_into_required_layout)

 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
@@ -217,24 +213,3 @@ def __init__(
             Qwen3Model(model_config),
             model_config,
         )
-
-    def post_load_weights(self):
-        all_named_modules = dict(self.model.named_modules())
-        for name, module in tqdm(all_named_modules.items(),
-                                 desc="Post loading weights"):
-            if len(module._parameters) <= 0 or name.startswith("draft_model"):
-                continue
-            else:
-                if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
-                ) and is_sm_100f() and hasattr(module, "weight_scale"):
-                    weight, weight_scale = resmooth_to_fp8_e8m0(
-                        module.weight, module.weight_scale)
-                    transfromed_scale = transform_sf_into_required_layout(
-                        weight_scale,
-                        mn=weight.shape[0],
-                        k=weight.shape[1],
-                        recipe=(1, 128, 128),
-                        is_sfa=False)
-                    module.weight = nn.Parameter(weight, requires_grad=False)
-                    module.weight_scale = nn.Parameter(transfromed_scale,
-                                                       requires_grad=False)
diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py
index 9891436ceb4..ebea00b4707 100644
--- a/tensorrt_llm/_torch/modules/linear.py
+++ b/tensorrt_llm/_torch/modules/linear.py
@@ -719,7 +719,8 @@ def load_weights_fused_gate_up_linear(self, module: Linear,

     def post_load_weights(self, module: Linear):
         super().post_load_weights(module)
-        if is_sm_100f():
+        if is_sm_100f() and not (module.use_cute_dsl_blockscaling_mm
+                                 or module.disable_deep_gemm):
             weight, weight_scale = resmooth_to_fp8_e8m0(module.weight,
                                                         module.weight_scale)
             transfromed_scale = transform_sf_into_required_layout(