
Conversation

@zpqiu

@zpqiu zpqiu commented Sep 5, 2025

Purpose

Fix a bug in Attention.forward where attn_metadata.enable_kv_scales_calculation was read unconditionally. On the warmup, eager, and graph-capture paths the metadata can be None, which caused:

  • torch.compile Dynamo data-dependent branching error (when enforce_eager=False)
  • AttributeError: 'NoneType' object has no attribute 'enable_kv_scales_calculation' (when enforce_eager=True)

This change treats missing/invalid metadata as disabled (False) for KV scale calculation, avoiding data-dependent control flow and None access. When metadata is properly constructed and injected via FlashAttention metadata, KV scale calculation remains enabled.
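In outline, the guarded read looks roughly like the sketch below (a minimal illustration using the attribute and helper names quoted later in this thread; the branch bodies are elided and the import path is assumed, so treat it as a sketch rather than the exact diff):

from vllm.forward_context import get_forward_context  # import path assumed

attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None and getattr(
        attn_metadata, "enable_kv_scales_calculation", False):
    # Metadata is present and explicitly enables KV scale calculation.
    ...  # compute/update the k/v scales here
else:
    # Warmup, eager, and graph-capture paths can pass None (or metadata
    # without the flag); treating that as "disabled" avoids data-dependent
    # control flow under torch.compile and the NoneType attribute error.
    pass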

Resolves: #21640

Test Plan

Manual verification with FP8 KV cache and calculate_kv_scales=True:

  1. Default (torch.compile) path — no Dynamo error
import vllm

engine = vllm.LLM(
    model="Qwen/Qwen2-0.5B",          # any compatible model
    tensor_parallel_size=1,
    kv_cache_dtype="fp8_e4m3",
    calculate_kv_scales=True,         # triggers the guarded branch
)
out = engine.generate("Hello, world!")
print(out[0].outputs[0].text)
  2. Eager path — no AttributeError
import vllm

engine = vllm.LLM(
    model="Qwen/Qwen2-0.5B",
    tensor_parallel_size=1,
    kv_cache_dtype="fp8_e4m3",
    calculate_kv_scales=True,
    enforce_eager=True,               # exercise the eager (no torch.compile) path
)
out = engine.generate("Hello, world!")
print(out[0].outputs[0].text)

Test Result

Before this change:

  • torch.compile path: Dynamo raised a data-dependent branching error at attn_metadata.enable_kv_scales_calculation.
  • eager path: AttributeError on NoneType for attn_metadata.

After this change:

  • Both paths run successfully and produce normal generations (no exceptions).
  • Behavior is unchanged when metadata is present; KV scale computation proceeds as before.


Signed-off-by: alexchiu <qiuzhaopeng@foxmail.com>
@zpqiu zpqiu marked this pull request as ready for review September 5, 2025 03:14
@github-actions

github-actions bot commented Sep 5, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small, essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Collaborator

@ProExpertProg ProExpertProg left a comment


Cc @LucasWilkinson for another set of eyes

@ProExpertProg ProExpertProg added the ready label (ONLY add when PR is ready to merge/full CI is needed) Sep 8, 2025
@elvischenv
Contributor

There is also another fix in #23912. attn_metadata can also be a dict, per this code:

forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
    attn_metadata = attn_metadata[self.layer_name]
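For illustration, a hedged sketch of how that dict unwrapping could compose with the None/getattr guard this PR proposes (names are taken from the snippets in this thread; this is not the exact code of either PR):

forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
    # Per-layer metadata: select this layer's entry by name.
    attn_metadata = attn_metadata[self.layer_name]
if attn_metadata is not None and getattr(
        attn_metadata, "enable_kv_scales_calculation", False):
    ...  # KV scale calculation path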

 if self.calculate_kv_scales:
     attn_metadata = get_forward_context().attn_metadata
-    if attn_metadata.enable_kv_scales_calculation:
+    if (attn_metadata is not None and getattr(
Collaborator


@zpqiu one question: if enable_kv_scales_calculation=True but it is not set during compilation, wouldn't the attention metadata possibly be None during the profile_run (which also triggers compilation), so the graph gets compiled without this branch, meaning it never runs even if calculation is enabled later?

Collaborator


For piecewise CUDA graphs, the dummy run runs without attention metadata: https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu_model_runner.py#L2535

So the initial compile run has no attention metadata. This means that, yes, the graph will get compiled without this check and it will be wrong later on.

Collaborator

@zou3519 zou3519 Sep 8, 2025


My understanding is that the kv scales get computed on the first real input only and are then used in subsequent inputs.

To actually fix this, I think what we need is that the first real input should run without torch.compile and CUDAGraphs. All subsequent inputs should run with torch.compile and CUDAGraphs.

Then we need to actually make sure the torch.compile'd graph includes the kv scales.
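For illustration only, a framework-agnostic sketch of that idea with hypothetical names (this is not vLLM code): run the first real batch eagerly so the scales can be computed, then switch to the compiled path for everything after it.

import torch


class ScaleAwareRunner:
    """Hypothetical sketch: eager first pass for scale calculation, compiled path afterwards."""

    def __init__(self, model: torch.nn.Module):
        self._eager_model = model
        self._compiled_model = torch.compile(model)
        self._scales_computed = False  # hypothetical flag

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        if not self._scales_computed:
            # First real input: run eagerly so data-dependent scale
            # computation is free to branch and record the scales.
            out = self._eager_model(x)
            self._scales_computed = True
            return out
        # Subsequent inputs: use the compiled (and, in vLLM, CUDA-graph)
        # path, which can rely on the already-computed scales.
        return self._compiled_model(x)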

Author


> For piecewise CUDA graphs, the dummy run runs without attention metadata: https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu_model_runner.py#L2535
>
> So the initial compile run has no attention metadata. This means that, yes, the graph will get compiled without this check and it will be wrong later on.

Thanks for pointing this out; you're right. I printed the Q/K/V scale values in the forward() function of vllm/v1/attention/backends/flash_attn.py, and they are all the default 1.0, which suggests the dynamic scale computation didn't take effect.
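A hypothetical version of such a debug print (the _q_scale/_k_scale/_v_scale attribute names on the layer object are assumptions, not confirmed by this thread):

# Hypothetical: print the per-layer scales inside the backend's forward().
print(f"q_scale={float(layer._q_scale)}, "
      f"k_scale={float(layer._k_scale)}, "
      f"v_scale={float(layer._v_scale)}")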

Author


> My understanding is that the kv scales get computed on the first real input only and are then used in subsequent inputs.
>
> To actually fix this, I think what we need is that the first real input should run without torch.compile and CUDAGraphs. All subsequent inputs should run with torch.compile and CUDAGraphs.
>
> Then we need to actually make sure the torch.compile'd graph includes the kv scales.

Got it—I’ll try that approach. I’ll first sort out the profiling run logic.

@ProExpertProg ProExpertProg self-requested a review September 8, 2025 17:14
Collaborator

@ProExpertProg ProExpertProg left a comment


Just want to clarify the question before merging

Collaborator

@zou3519 zou3519 left a comment


I don't think this works, please see comment

@heheda12345
Collaborator

Another PR related to KV scales. Can anyone help review it? #23906

@ProExpertProg
Collaborator

See #21640 for info and to stay updated on the fix


Labels

ready (ONLY add when PR is ready to merge/full CI is needed)


Development

Successfully merging this pull request may close these issues.

[Bug]: calculate_kv_scales leads to dynamo compilation issue; enforce_eager=True leads to another issue

5 participants