[Misc] parametrize 'dtype' in test_flash_mla #22641
Conversation
gemini-code-assist[bot] left a comment
Code Review
This pull request successfully parametrizes the dtype and device for the test_flash_mla test, removing the hardcoded values and making the test more flexible as intended. My review includes one suggestion to improve the logic for determining the list of CUDA devices to test on, making it more robust, especially for environments with no CUDA devices.
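For illustration, a minimal sketch of the kind of robust device-list logic this review suggests, assuming pytest and torch (`CUDA_DEVICES` and the test name are invented for this example, not the PR's actual code):

```python
import pytest
import torch

# Enumerate the CUDA devices actually present; the list is empty on machines without CUDA.
CUDA_DEVICES = [f"cuda:{i}" for i in range(torch.cuda.device_count())]

@pytest.mark.skipif(len(CUDA_DEVICES) == 0, reason="no CUDA devices available")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_device_list_sketch(device):
    # Allocating on the parametrized device verifies it is actually usable.
    assert torch.zeros(1, device=device).is_cuda
```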
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; instead, only a small and essential subset of CI tests runs to quickly catch errors. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge. 🚀
NickLucche left a comment
Sweeping over GPUs on the same host doesn't make much sense. Let's put a list with just `["cuda"]` for now; we can extend to other platforms in due time when supported. Also, please remove the print.
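As a sketch of that suggestion (illustrative names, not the PR's code):

```python
import pytest
import torch

# A single-entry device list for now; other platforms can be appended once supported.
DEVICES = ["cuda"]

@pytest.mark.parametrize("device", DEVICES)
def test_single_device_sketch(device):
    # Constructing a torch.device validates the string without needing a GPU present.
    assert torch.device(device).type == "cuda"
```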
Signed-off-by: RUTHLESS-BOT <wujiafeng@cmbchina.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: RUTHLESS-BOT <wujiafeng@cmbchina.com>
Signed-off-by: RUTHLESS-BOT <wujiafeng@cmbchina.com>
Force-pushed from 32b690a to c21d5c4
yewentao256 left a comment
Looks good to me, thanks for the work! But I think `device` won't need to be parametrized.
@yewentao256 sorry to disturb you, I unexpectedly triggered a re-review request to you... Any need to move the `device` param back to a constant? It also seems clean and applicable now LOL.
```diff
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("device", ["cuda:0"])
 @torch.inference_mode()
 def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
-                   varlen):
-    # TODO: parametrize using pytest
-    dtype = torch.bfloat16
-    device = torch.device("cuda:0")
+                   varlen, dtype, device):
```
I mean leave `device = torch.device("cuda:0")` the same and just parametrize `dtype`. And possibly add more dtypes.
got it :)
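A minimal sketch, assuming pytest and torch, of where this exchange lands: `dtype` parametrized, `device` left as a constant (the test name and body here are illustrative, not the PR's exact code):

```python
import pytest
import torch

@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@pytest.mark.parametrize("dtype", [torch.bfloat16])  # more dtypes could be added here
@torch.inference_mode()
def test_dtype_param_sketch(dtype):
    device = torch.device("cuda:0")  # kept as a constant, per review
    x = torch.zeros(8, dtype=dtype, device=device)
    assert x.dtype == dtype
```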
Signed-off-by: RUTHLESS-BOT <wujiafeng@cmbchina.com>
@NickLucche got it :)
yewentao256 left a comment
Looks good to me, thanks for the work!
NickLucche left a comment
Thanks for contributing!
Signed-off-by: RUTHLESS-BOT <wujiafeng@cmbchina.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Diego-Castan <diego.castan@ibm.com>
Signed-off-by: RUTHLESS-BOT <wujiafeng@cmbchina.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: RUTHLESS-BOT <wujiafeng@cmbchina.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Xiao Yu <xiao.yu@amd.com>
Essential Elements of an Effective PR Description Checklist
(Optional) The necessary documentation update, such as updating `supported_models.md` and `examples` for a new model.

Purpose
Complete the TODO in `tests/kernels/attention/test_flashmla.py` to make the `test_flash_mla` test more flexible by parametrizing `dtype`.

Test Plan
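Assuming a CUDA machine with the FlashMLA kernels available, a local run of the touched test could look like this sketch (invoked via `pytest.main` to stay in Python):

```python
import pytest

# "-k" filters to the parametrized test; "-v" prints one line per dtype combination.
pytest.main(["tests/kernels/attention/test_flashmla.py", "-k", "test_flash_mla", "-v"])
```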
Test Result
(Optional) Documentation Update