
Conversation

@xuebwang-amd commented Sep 4, 2025

Purpose

This PR adds support for inference of layerwise mixed-precision quantized models, extending beyond models quantized with a single scheme such as MXFP4 or FP8 (i.e., standard PTQ models).

Here, the layerwise mixed-precision configuration for a given model is searched for, and the model is then quantized by amd-quark. Specifically, this PR focuses on the mixed scheme {MXFP4, FP8}, where FP8 denotes the FP8 per-tensor scheme.

With a mixed-precision quantized model, one can achieve a better balance between accuracy and hardware metrics.
To demonstrate the benefits of mixed precision in this PR, we show model accuracies on several commonly used tasks, using only the Quark emulation kernel for MXFP4 and the Triton kernel for FP8.
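For reference, loading one of these mixed-precision checkpoints in vLLM looks like loading any other quantized model. A minimal sketch, assuming the checkpoints listed in the test plan below are accessible on the Hugging Face Hub and that vLLM picks up the Quark quantization config from the checkpoint's config.json:

    from vllm import LLM, SamplingParams

    # Model name taken from the test plan below. Depending on the setup,
    # kv-cache-related flags (e.g. kv_cache_dtype) may also be needed for
    # the KVFP8 part of the checkpoint.
    llm = LLM(model="amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8")
    outputs = llm.generate(
        ["The key benefit of mixed-precision quantization is"],
        SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)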

Test Plan

Test on

  1. amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
  2. amd/Mixtral-8x7B-Instruct-v0.1-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
  3. amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8

Test Result

(screenshot of accuracy results)

List of TODO items

  • Layerwise mixed-precision quantization scheme of {MXFP4, FP8} (exactly what this PR aims for)
  • extend model coverage
  • benchmark hardware metrics
  • further support unquantized Linear and/or MoE layers in the mixed-precision scheme, i.e., {MXFP4, FP8, BF16/FP16}
  • further support the MXFP6 scheme for mixed-precision quantization, i.e., {MXFP4, MXFP6, FP8, BF16/FP16}

Contributor @gemini-code-assist bot left a comment

Code Review

This pull request extends Quark to support mixed-precision models, specifically for the {MXFP4, FP8} scheme. The changes involve updating the quantization configuration logic to handle mixed-precision setups and adding new tests to validate model accuracies. My review identified two high-severity issues. First, in the new test file, environment variables are not handled safely, which could lead to test state leakage; I've recommended using pytest's monkeypatch fixture for robust cleanup. Second, in the Quark configuration logic, a fragile substring check is used for matching layer names, which could result in applying incorrect quantization schemes; I've suggested a more robust pattern-matching approach to ensure correctness. Addressing these issues will improve the reliability and correctness of the new mixed-precision quantization feature.

Comment on lines 79 to 95
def test_mixed_precision_model_accuracies(config: EvaluationConfig, task: str):
    os.environ["VLLM_QUARK_EMU_MEM_OPT"] = "1"

    results = lm_eval.simple_evaluate(model="vllm",
                                      model_args=config.get_model_args(),
                                      tasks=task,
                                      batch_size="auto")

    rtol = 0.05

    EXPECTED_VALUE = config.excepted_value
    measured_value = results["results"][task]["acc,none"]
    assert (measured_value - rtol < EXPECTED_VALUE
            and measured_value + rtol > EXPECTED_VALUE
            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"

    del os.environ["VLLM_QUARK_EMU_MEM_OPT"]
Contributor (severity: high)

Setting and deleting an environment variable directly using os.environ can lead to state leakage between tests if an exception occurs before the del statement. This can cause subsequent tests to fail or behave unexpectedly. It's safer to use pytest's monkeypatch fixture, which is already used in this file, to manage environment variables. monkeypatch ensures that the environment is restored to its original state after the test function completes, regardless of whether it passes or fails.

Suggested change

    # Before:
    def test_mixed_precision_model_accuracies(config: EvaluationConfig, task: str):
        os.environ["VLLM_QUARK_EMU_MEM_OPT"] = "1"
        results = lm_eval.simple_evaluate(model="vllm",
                                          model_args=config.get_model_args(),
                                          tasks=task,
                                          batch_size="auto")
        rtol = 0.05
        EXPECTED_VALUE = config.excepted_value
        measured_value = results["results"][task]["acc,none"]
        assert (measured_value - rtol < EXPECTED_VALUE
                and measured_value + rtol > EXPECTED_VALUE
                ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
        del os.environ["VLLM_QUARK_EMU_MEM_OPT"]

    # After:
    def test_mixed_precision_model_accuracies(config: EvaluationConfig, task: str, monkeypatch):
        monkeypatch.setenv("VLLM_QUARK_EMU_MEM_OPT", "1")
        results = lm_eval.simple_evaluate(model="vllm",
                                          model_args=config.get_model_args(),
                                          tasks=task,
                                          batch_size="auto")
        rtol = 0.05
        EXPECTED_VALUE = config.excepted_value
        measured_value = results["results"][task]["acc,none"]
        assert (measured_value - rtol < EXPECTED_VALUE
                and measured_value + rtol > EXPECTED_VALUE
                ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"

Author @xuebwang-amd, Sep 5, 2025

Updated, along with some refactoring.

Comment on lines 289 to 294
    layer_quant_configs = list()
    for name_pattern in layer_quant_config:
        if fnmatch.fnmatch(layer_name, name_pattern):
            return layer_quant_config[name_pattern]
        if layer_name in name_pattern:
            layer_quant_configs.append(
                layer_quant_config[name_pattern])
            return layer_quant_configs[0]
Contributor (severity: high)

The matching logic if layer_name in name_pattern: is likely incorrect and can lead to bugs. It performs a substring check, which can cause unintended matches (e.g., attention.q_proj would incorrectly match a pattern for cross_attention.q_proj). This is likely not the intended behavior for matching layer configurations and could lead to applying the wrong quantization scheme to a layer.

Given that fnmatch is used elsewhere in this file for pattern matching, it seems the intention is to support glob patterns. If name_pattern can be a comma-separated list of patterns, the logic should be updated to split the string and apply fnmatch to each part. This ensures accurate matching of layer configurations and prevents applying the wrong quantization scheme.

The current implementation also unnecessarily creates a list layer_quant_configs to immediately return its first element. This can be simplified by returning directly upon finding a match.

Suggested change

    # Before:
    layer_quant_configs = list()
    for name_pattern in layer_quant_config:
        if fnmatch.fnmatch(layer_name, name_pattern):
            return layer_quant_config[name_pattern]
        if layer_name in name_pattern:
            layer_quant_configs.append(
                layer_quant_config[name_pattern])
            return layer_quant_configs[0]

    # After:
    for name_pattern in layer_quant_config:
        patterns = [p.strip() for p in name_pattern.split(',')]
        for p in patterns:
            if fnmatch.fnmatch(layer_name, p):
                return layer_quant_config[name_pattern]

Author @xuebwang-amd, Sep 5, 2025

This code snippet suggested by gemini-code-assist is problematic: a name_pattern looks like model.layers.0.block_sparse_moe.experts.0.w1, for example, so name_pattern.split(',') doesn't make sense and the subsequent fnmatch.fnmatch is also irrelevant.
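A quick illustration of the point, using the pattern string from the example above (str.split and fnmatch behave as in standard Python):

    import fnmatch

    name_pattern = "model.layers.0.block_sparse_moe.experts.0.w1"
    # Splitting on ',' yields the same single pattern, so the suggested loop adds nothing.
    assert name_pattern.split(',') == [name_pattern]
    # And without wildcard characters, fnmatch degenerates to an exact string comparison.
    assert fnmatch.fnmatch(name_pattern, name_pattern)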

@mergify bot added the documentation label (Improvements or additions to documentation) on Sep 4, 2025
Contributor @BowenBao left a comment

Thanks, great start!

dict[str, Any], self.quant_config.get("layer_quant_config"))
layer_quant_configs = list()
for name_pattern in layer_quant_config:
if fnmatch.fnmatch(layer_name, name_pattern):
Contributor

Is this change necessary? Also, layer_quant_configs seems unused: it appends the first matched config and immediately returns it.

Author

Updated, as also suggested in #24239 (comment).

) -> tuple[torch.Tensor, None]:
    assert block_shape is None
    if not current_platform.supports_mx():
        VLLM_QUARK_EMU_MEM_OPT = (os.environ.get("VLLM_QUARK_EMU_MEM_OPT",
Contributor

In general, for env flags it is better to add them to vllm/vllm/envs.py with comments on their effect.

Can you keep this change local? In particular we want to move away from simulation to triton kernels as we move forward. cc @fxmarty-amd

Author

Totally agree on that.
The reason VLLM_QUARK_EMU_MEM_OPT is not added to vllm/vllm/envs.py is that it's better to keep it as a local and temporary environment variable, just to make things work at the moment. After non-emulation kernels such as Triton or AITER implementations are integrated, we can remove it entirely.

Contributor

@xuebwang-amd this variable that I added previously has been removed per @mgoin's request, in order to avoid adding a new unnecessary env variable to vLLM, especially given that we have a decently fast MXFP4 dequantization kernel.

Please avoid adding this environment variable, keep it local for testing if needed.

Author

I appreciate your previous effort on this emulation approach; it has played a role beyond local testing, and the functionality continues with what I'm doing here.
Specifically, with VLLM_QUARK_EMU_MEM_OPT=1 enabled, the flow indeed goes to the mx.qdq_mxfp4 defined at https://github.com/vllm-project/vllm/blob/8de261b04a0a0e916d3d25d528d0f2ddeede2a6b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py#L94C5-L94C25.

The real motivation for this environment variable is to route execution to the emulation path regardless of the platform's MX support, because the non-emulation kernels haven't been integrated into the flow yet.

Therefore, the solution here is to remove the if-else statement:

    if not current_platform.supports_mx():
        A = quant_dequant_mxfp4(A)
    else:
        raise NotImplementedError()

and always use A = quant_dequant_mxfp4(A).
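A minimal sketch of the resulting emulated path (the function body follows the snippets quoted in this thread; the import path is an assumption and the actual signature in vLLM may differ):

    import torch
    # Assumed import path, based on the mxfp4_utils.py file linked above;
    # adjust if quant_dequant_mxfp4 lives elsewhere.
    from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
        quant_dequant_mxfp4)


    def _mxfp4_quantize(A: torch.Tensor,
                        block_shape=None) -> tuple[torch.Tensor, None]:
        # Always take the emulated MXFP4 quantize-dequantize path, regardless of
        # current_platform.supports_mx(), until real kernels are integrated.
        assert block_shape is None
        return quant_dequant_mxfp4(A), None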

layer_quant_set = set(layer_quant_names)

if not kv_cache_set.issubset(layer_quant_set):
if not (kv_cache_set.issubset(layer_quant_set) or \
Contributor

Could you explain what the goal is for these changes around the kv cache?

For AMP models, are kv caches still uniformly quantized the same way across all layers?

Author

Yes, currently mixed precision is not applied along the KV-cache dimension; the kv cache is quantized the same way across all KV layers.
The changes here aim to correctly verify whether kv-cache patterns such as {'*v_proj', '*k_proj'} can be matched, in other words, found in at least one of the layer_quant_set entries (i.e., layer names).
This is essential in AMP scenarios where layer_quant_names are specified one by one, rather than via fuzzy wildcard matching.
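A minimal sketch of that check, assuming (per this thread) that kv-cache patterns are fnmatch-style globs and that layer_quant_set holds fully enumerated layer names; the example names are hypothetical:

    import fnmatch

    kv_cache_set = {"*k_proj", "*v_proj"}
    layer_quant_set = {
        "model.layers.0.self_attn.q_proj",
        "model.layers.0.self_attn.k_proj",
        "model.layers.0.self_attn.v_proj",
    }

    # Accept the config if every kv-cache pattern is either present literally
    # or glob-matches at least one enumerated layer name.
    kv_cache_ok = all(
        pattern in layer_quant_set
        or any(fnmatch.fnmatch(name, pattern) for name in layer_quant_set)
        for pattern in kv_cache_set
    )
    assert kv_cache_ok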

Comment on lines 20 to 25
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """
    This module relies on V0 internals, so set VLLM_USE_V1=0.
    """
    monkeypatch.setenv('VLLM_USE_V1', '0')
Contributor

Let's avoid using v0

Author

For test purposes, especially accuracy tests, using V0 is safe. Even for the hardware-metric tests later, using V0 is still safer while remaining valuable for demonstrations.

Contributor

vllm v0 is deprecated: #18571

Author

V1 is reported to have issues, as you can see. Since mixed-precision quantization does not depend on the V0/V1 engine, it's safe to use V0.

Author

use_v0_only has been removed, as the V0 backend was deprecated very recently (#25351). Thanks @fxmarty-amd

Comment on lines 32 to 37
try:
    huggingface_hub.list_repo_refs(
        "amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8")
    HF_HUB_AMD_ORG_ACCESS = True
except huggingface_hub.errors.RepositoryNotFoundError:
    HF_HUB_AMD_ORG_ACCESS = False
Contributor

Let's use public models

Author

These models are in the process of being published.

Contributor

Do you have an ETA for when we can expect these models to be published?

Author

Colleagues at AMD are speeding up the process; hopefully they can make it happen some time next week.

Contributor

@xuebwang-amd I meant that for unit testing you can probably use small models just for integration test purposes (as e.g. in

@pytest.mark.parametrize('model_case', [
    ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
    ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
    ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1)
])
) - but having private models is okay for a while I guess.

Author

@fxmarty-amd your motivation here is to reduce CI time cost, which is good. We can consider picking one public model for the CI test. @gshtras @SageMoore

Author

Do you have an ETA for when we can expect these models to be published?

They are now published.

reason="Read access to huggingface.co/amd is required for this test.")
def test_mixed_precision_model_accuracies(model_name: str,
accuracy_numbers: dict, monkeypatch):
monkeypatch.setenv("VLLM_QUARK_EMU_MEM_OPT", "1")
Contributor

This environment variable has no effect - it has been removed from vllm.

Author

Then we need to remove the if-else statement in _mxfp4_quantize, as commented above in #24239 (comment).

) -> tuple[torch.Tensor, None]:
    assert block_shape is None
    if not current_platform.supports_mx():
        VLLM_QUARK_EMU_MEM_OPT = (os.environ.get("VLLM_QUARK_EMU_MEM_OPT",
Contributor

@xuebwang-amd this variable that I added previously has been removed per @mgoin's request, in order to avoid adding a new unnecessary env variable to vLLM, especially given that we have a decently fast MXFP4 dequantization kernel.

Please avoid adding this environment variable, keep it local for testing if needed.

Comment on lines +275 to +305
As examples, we provide some ready-to-use mixed-precision quantized models to show the usage in vLLM and the accuracy benefits. They are:

- amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
- amd/Mixtral-8x7B-Instruct-v0.1-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
- amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
Contributor

Make these public + add link

Author

They're going to be published.

@fxmarty-amd (Contributor)

Test Plan

Test on

1. [amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8](https://huggingface.co/amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8)

2. [amd/Mixtral-8x7B-Instruct-v0.1-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8](https://huggingface.co/amd/Mixtral-8x7B-Instruct-v0.1-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8)

3. [amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8](https://huggingface.co/amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8)

Can you provide:

  • Which proportion of layers are in FP8/MXFP4
  • Comparison against MXFP4 alone?

@xuebwang-amd (Author) commented Sep 15, 2025

Can you provide:

  • Which proportion of layers are in FP8/MXFP4
  • Comparison against MXFP4 alone?

One can check the detailed layerwise MXFP4/FP8 configuration in config.json, specifically under the key quantization_config:layer_quant_config.
Accuracies and hardware metrics are being measured not only for plain MXFP4 but also for plain FP8; updates are ongoing.
Note that these numbers are not and will not be guaranteed to be final, optimized values. They're for demonstration purposes and could be further improved.
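To get the layer proportions asked about above, a rough sketch that counts schemes from config.json. Only the keys quantization_config and layer_quant_config come from this thread; the nested "weight"/"dtype" lookups are assumptions and may need adjusting to the actual Quark schema:

    import json
    from collections import Counter

    # Hypothetical sketch: tally how many layers use each weight scheme in a
    # Quark checkpoint's config.json.
    with open("config.json") as f:
        cfg = json.load(f)

    layer_quant_config = cfg["quantization_config"]["layer_quant_config"]
    counts = Counter(
        (entry.get("weight") or {}).get("dtype", "unknown")
        for entry in layer_quant_config.values()
    )
    print(counts)  # illustrative output only, e.g. Counter({'mxfp4': 180, 'fp8': 44})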

@xuebwang-amd xuebwang-amd changed the title [Feature][Quantization] extend Quark to support mixed-precision quantized model [ROCm][Quantization] extend AMD Quark to support mixed-precision quantized model Sep 17, 2025
@mergify bot added the rocm label (Related to AMD ROCm) on Sep 17, 2025
    # Before:
    if not current_platform.supports_mx():
        A = quant_dequant_mxfp4(A)
    else:
        raise NotImplementedError()

    # After:
    A = quant_dequant_mxfp4(A)
Collaborator

Does this mean that before the PR MI350 would get an exception, and now this method is being called unconditionally?

Contributor

Yes! It was an oversight in a previous PR; we should be able to run simulation on CDNA4 until kernels are integrated.

Contributor

see #22355

Author

This PR focuses on the accuracy benefits of mixed precision via emulated QDQ at this moment, so the if-else on the platform is removed.
Demonstrating the hardware-metric benefits with real kernels integrated is the next step.

Contributor @fxmarty-amd left a comment

@xuebwang-amd if I understand correctly, this PR is now mostly about adding documentation right?

Comment on lines 289 to 328
for name_pattern, config in layer_quant_config.items():
    if layer_name in name_pattern:
Contributor

Do we make sure somewhere that e.g. q_proj from the checkpoint/Transformers gets correctly mapped to qkv_proj in vllm (https://github.com/ROCm/vllm/blob/eb9d4de9eb7649bdf36b2d0e4832fcaab8465153/vllm/model_executor/models/llama.py#L150) prior to doing the check layer_name in name_pattern?

Author

Good question.
The Quark model/config is largely decoupled from vLLM's model implementation. q_proj, k_proj, and v_proj are merged in vLLM, while they are kept separate in the Quark quantized model and configs. In Quark's AMP, q_proj, k_proj, and v_proj are required to have the same bitwidth, i.e., the same quantization scheme, so the two representations stay aligned.
Therefore, Quark's layerwise quant config matching is applied to q_proj, k_proj, and v_proj individually.
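A small illustration of that alignment (a hypothetical sketch; the dict entries, layer names, and helper below are illustrative, not the actual vLLM/Quark code):

    # A fused vLLM layer maps back to the separate Quark projection names,
    # whose per-layer schemes must agree.
    layer_quant_config = {
        "model.layers.0.self_attn.q_proj": {"weight": "mxfp4"},
        "model.layers.0.self_attn.k_proj": {"weight": "mxfp4"},
        "model.layers.0.self_attn.v_proj": {"weight": "mxfp4"},
    }

    def scheme_for(layer_name: str):
        # Same substring-containment rule discussed in this thread.
        for name_pattern, config in layer_quant_config.items():
            if layer_name in name_pattern:
                return config
        return None

    fused = "model.layers.0.self_attn.qkv_proj"
    shards = [fused.replace("qkv_proj", p) for p in ("q_proj", "k_proj", "v_proj")]
    schemes = [scheme_for(name) for name in shards]
    assert len({str(s) for s in schemes}) == 1, "q/k/v must share one scheme"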

layer_quant_set = set(layer_quant_names)

if not kv_cache_set.issubset(layer_quant_set):
if not (kv_cache_set.issubset(layer_quant_set) or \
Contributor

nit: I'd avoid writing the nested loops inline for readability purposes.

Comment on lines -282 to -316
for name_pattern in layer_quant_config:
    if fnmatch.fnmatch(layer_name, name_pattern):
Contributor

I am confused. What is this PR changing here?

Author

gemini-code-assist had similar questions. Please see my comments above, e.g., #24239 (comment).

Contributor

Thanks for referencing our previous discussion, @xuebwang-amd. I'd like to clarify the change in behavior introduced by replacing fnmatch.fnmatch(layer_name, name_pattern) with layer_name in name_pattern.

Key Difference:

  • fnmatch.fnmatch(layer_name, name_pattern): This function performs shell-style wildcard matching. For example, if name_pattern is "model.layers.*.attn.q_proj", it would match "model.layers.0.attn.q_proj", "model.layers.1.attn.q_proj", etc.
  • layer_name in name_pattern: This performs a simple substring check. For "model.layers.0.attn.q_proj" in "model.layers.*.attn.q_proj", it would return False because "model.layers.0.attn.q_proj" is not a substring of "model.layers.*.attn.q_proj".

This change fundamentally alters how layer names are matched against the layer_quant_config. If the intention is to support wildcard patterns for layer quantization (which is a common practice for applying the same quantization scheme to multiple layers), then the current change would likely break existing configurations or limit the flexibility of defining quantization patterns.

Unless there's a specific reason to remove wildcard matching, I recommend reverting to fnmatch.fnmatch to maintain the expected pattern matching behavior. The iteration over layer_quant_config.items() is correct for accessing the config value, but the matching logic should be adjusted.

Suggested change

    # Before:
    for name_pattern in layer_quant_config:
        if fnmatch.fnmatch(layer_name, name_pattern):

    # After:
    for name_pattern, config in layer_quant_config.items():
        if fnmatch.fnmatch(layer_name, name_pattern):
            return config

Contributor

The Gemini code bot is not useful here.

@xuebwang-amd I don't understand why this PR introduces handling different from e.g.
https://github.com/amd/Quark/blob/release/0.9/quark/torch/export/main_export/quant_config_parser.py#L67-L70
and e.g.
https://github.com/amd/Quark/blob/release/0.9/quark/torch/quantization/model_transformation.py#L80-L84

Why would the handling in vLLM be different from what we have in Quark, e.g. when reloading models through the Transformers library? I don't think that is a good thing. Maybe existing models rely on fnmatch.fnmatch and things would break now.

Author

There have been lots of discussions about this in the PR.
To emphasize: this is for AMP, in which layers are specified one by one, so matching layer_name against the enumerated name_pattern keys of layer_quant_config works in a strict way, while fnmatch.fnmatch doesn't fit here.

Author @xuebwang-amd, Oct 17, 2025

I can fully understand your concern here. Please find my explanations above, e.g.:

#24239 (comment)
#24239 (comment)
#24239 (comment)

To ensure there is no breakage of or conflict with existing PTQ model matching, I added a non-mixed-precision (PTQ, public) model as a reference to demonstrate pipeline compatibility in tests/quantization/test_mixed_precision.py: https://github.com/xuebwang-amd/vllm/blob/db3cc7eba1609370e34b35f51c7a5fa3111bb868/tests/quantization/test_mixed_precision.py#L45

Conclusion: there are no conflicts or breakages when using the precise substring-containment matching rule.

Author

We can support both glob-style wildcard patterns and precise substring containment for layer_quant_config matching. @BowenBao

Contributor

I apologize for coming to this discussion late, but I also have some concerns here. It looks like you would like to add substring matching to this check, so that layer_name will match layer_name_0, layer_name_1, etc. Before your change, the code would only do substring matching when the * character was appended to the end of the substring, so you would have to have layer_name*. The concern I have is that you are turning on substring matching by default, meaning that layer_name_1 will match layer_name_12 even if that's not the caller's intention. Would it make more sense to leave the code as is and just append the * to the quant config?

I'm not familiar with how quantization configs are specified, but this does seem like it's introducing a footgun?

Author @xuebwang-amd, Oct 29, 2025

Thank you @SageMoore, that's a good suggestion and should work as well.
However, our current approach can be more efficient; let me break it down into two aspects (see the illustrative sketch after this list):

  • Simple substring matching is usually more efficient than fnmatch, which involves pattern matching. That's why substring matching is the default. Note this is for a single match check.
  • From a model-level perspective, the single-match cost described above aggregates across an Auto Mixed Precision (AMP) model whose layers are explicitly enumerated.
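An illustrative sketch of the substring-containment rule and the caveat raised above (the layer names are hypothetical):

    def matches(layer_name: str, name_pattern: str) -> bool:
        # The containment rule discussed in this thread: the queried layer name
        # must appear verbatim inside the enumerated config key.
        return layer_name in name_pattern

    # Exact, fully enumerated names match as expected.
    assert matches("model.layers.1.self_attn.q_proj",
                   "model.layers.1.self_attn.q_proj")
    # A trailing suffix keeps a bare index from over-matching...
    assert not matches("model.layers.1.self_attn.q_proj",
                       "model.layers.12.self_attn.q_proj")
    # ...but a name ending in a bare index can still over-match
    # (the footgun concern raised above).
    assert matches("model.experts.1", "model.experts.12")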

Contributor

Thanks @SageMoore. Regarding the concern over incorrect matches, I think it is fine: it's only dangerous when the name ends with a layer index with no following characters, while the name patterns we expect in this case look more like xxx.layers.1.yyy or xxx.experts.12.zzz.
@xuebwang-amd feel free to add a couple of examples here/in a comment for illustration.

@xuebwang-amd (Author)

@xuebwang-amd if I understand correctly, this PR is now mostly about adding documentation right?

Not the case. This PR aims to support Quark layerwise mixed-precision quantization and demonstrate the resulting accuracy gains.
Documentation is one essential part, since this is a new feature alongside the existing single-scheme quantization, though enabling the feature is nearly free.

@xuebwang-amd (Author)

Hello @SageMoore, @gshtras,
The models shown in the PR have been published. They are:

  • amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
  • amd/Mixtral-8x7B-Instruct-v0.1-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
  • amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8

They're publicly available for CI tests now.

cc @BowenBao, @fxmarty-amd

Comment on lines 23 to 27
    huggingface_hub.list_repo_refs(
        "amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8")
    HF_HUB_AMD_ORG_ACCESS = True
except huggingface_hub.errors.RepositoryNotFoundError:
    HF_HUB_AMD_ORG_ACCESS = False
Collaborator

Since the models are now public, could we remove this part?

Author

Thanks, it's removed.

@gshtras added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Sep 30, 2025
@mergify bot commented Oct 8, 2025

Documentation preview: https://vllm--24239.org.readthedocs.build/en/24239/

Contributor @BowenBao left a comment

Looks good from the Quark side.

@xuebwang-amd (Author)

CI test https://buildkite.com/vllm/ci/builds/36689/steps/canvas?jid=019a2d89-7117-4eba-a593-770cdfaa5212 failed:
=========================== short test summary info ============================
  | [2025-10-29T02:43:28Z] FAILED quantization/test_fp8.py::test_kv_cache_model_load_and_run[False-nm-testing/Qwen2-1.5B-Instruct-FP8-K-V] - OSError: nm-testing/Qwen2-1.5B-Instruct-FP8-K-V is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
Cannot find nm-testing/Qwen2-1.5B-Instruct-FP8-K-V on HF.

@xuebwang-amd (Author)

CI test failed:
=========================== short test summary info ============================
  | [2025-10-29T03:49:59Z] FAILED evals/gsm8k/test_gsm8k_correctness.py::test_gsm8k_correctness_param[Qwen1.5-MoE-W4A16-CT-tp1] - AssertionError: Accuracy too low: 0.072 < 0.450 - 0.080
  | [2025-10-29T03:49:59Z] assert 0.07202426080363912 >= (0.45 - 0.08)

@DarkLight1337 (Member)

Should be fixed on main, let me rebase
