allow calc_kv_scales #23906
Summary:
1. `self.calc_kv_scales()` should be invoked without checking `attn_metadata`.
2. `attn_metadata` is available only in cudagraph full graph mode. If the user does not use that mode, checking `attn_metadata.enable_kv_scales_calculation` raises an error.

This diff should fix the above problem. However, torch.compile cannot be used when `calculate_kv_scales=True`: it complains about the use of `.item()` in `def calc_kv_scales()`.

Differential Revision: D81300417
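As a rough illustration of the failure described in the summary, here is a minimal pure-Python sketch. `AttnMetadata`, `should_calc_unguarded`, and `should_calc_guarded` are hypothetical stand-ins, not vLLM code, and the guarded variant is only one possible way to handle missing metadata (it is not the change made in this PR, which removes the check entirely).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class AttnMetadata:
    """Hypothetical stand-in for the real attention metadata object."""
    enable_kv_scales_calculation: bool = True


def should_calc_unguarded(attn_metadata: Optional[AttnMetadata]) -> bool:
    # Mirrors the shape of the original check: attribute access on None
    # raises AttributeError when no metadata is available.
    return attn_metadata.enable_kv_scales_calculation


def should_calc_guarded(attn_metadata: Optional[AttnMetadata]) -> bool:
    # Assumption for illustration: with no metadata, fall back to
    # calculating scales; otherwise honor the flag.
    if attn_metadata is None:
        return True
    return attn_metadata.enable_kv_scales_calculation
```

Whether missing metadata should mean "calculate" or "skip" is a design decision for the maintainers; the sketch only shows that the `None` case has to be handled explicitly.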
This pull request was exported from Phabricator. Differential Revision: D81300417
Code Review
This pull request fixes a crash that occurs when `calculate_kv_scales` is enabled but the code is not running in cudagraph full graph mode. The fix removes a dependency on `attn_metadata`, which can be `None` in this scenario. While this correctly addresses the crash, it introduces a potential new issue: `calc_kv_scales` is called unconditionally, but it will fail if the `key` or `value` tensors are `None`. The existence of checks for `key is not None` and `value is not None` later in the `forward` method suggests this is a valid possibility. I've added a review comment to guard the call to `calc_kv_scales` to prevent this potential crash.
    -     attn_metadata = get_forward_context().attn_metadata
    -     if attn_metadata.enable_kv_scales_calculation:
    -         self.calc_kv_scales(query, key, value)
    +     self.calc_kv_scales(query, key, value)
The call to `self.calc_kv_scales` here could lead to a `TypeError` if `key` or `value` is `None`, as `torch.abs(None)` would be executed. Later in this method (lines 260-263), there are checks for `key is not None` and `value is not None`, which implies they can indeed be `None`. To prevent a potential crash, it's crucial to ensure `key` and `value` are not `None` before calling `calc_kv_scales`.
Suggested change:

    -     self.calc_kv_scales(query, key, value)
    +     if key is not None and value is not None:
    +         self.calc_kv_scales(query, key, value)
@heheda12345 Do you mind reviewing this, as I saw you touched it in #12536? Thanks!

@mgoin Can you help take a look?
I don't think we want to always enable scale calculation if not necessary.
See #21640 for info and to stay updated on the fix.
Summary:
When running gpt-oss, I found a bug when enabling `calculate_kv_scales`: `attn_metadata` is available only in cudagraph full graph mode. If a user does not use that mode, checking `attn_metadata.enable_kv_scales_calculation` raises an error (NoneType).

This PR should fix the above problem. However, torch.compile cannot be used when `calculate_kv_scales=True`: it complains about the use of `.item()` in `def calc_kv_scales()`.

Differential Revision: D81300417