[Kernel] Enable FP16 and BF16 CUTLASS MoE kernels #15932
Conversation
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
General comment that using the name fp16 for operations that handle both fp16 and bf16 is confusing and we should either pick a more general name (16bit?), or better: append fp8 to the names of ops that handle fp8 and remove fp16 from names altogether
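A purely illustrative sketch of the naming proposal: keep the unsuffixed name for the 16-bit path (fp16 and bf16) and reserve an explicit `_fp8` suffix for the quantized path. The function names below are hypothetical examples of the convention, not ops defined in this PR.

```python
import torch


def cutlass_moe_mm(out: torch.Tensor, a: torch.Tensor, w: torch.Tensor):
    """Unsuffixed op: handles both fp16 and bf16 weights."""
    assert w.dtype in (torch.float16, torch.bfloat16)
    ...


def cutlass_moe_mm_fp8(out: torch.Tensor, a: torch.Tensor, w: torch.Tensor,
                       a_scale: torch.Tensor, w_scale: torch.Tensor):
    """The fp8 op keeps an explicit suffix and carries quantization scales."""
    assert w.dtype == torch.float8_e4m3fn
    ...
```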
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
```python
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    super().process_weights_after_loading(layer)

    # TODO half()
```
What is the TODO? Resolve before landing?
This pull request has merge conflicts that must be resolved before it can be merged.
It looks like the 16-bit configs are the same as the fp8 configs -- these need to be re-tuned for the fp16/bf16 case
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
tests/kernels/test_cutlass_moe.py
```python
def run_8_bit(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
              w2_q: torch.Tensor, w1_scale: torch.Tensor,
              w2_scale: torch.Tensor, topk_weights: torch.Tensor,
              topk_ids: torch.Tensor, ab_strides1: torch.Tensor,
              c_strides1: torch.Tensor, ab_strides2: torch.Tensor,
              c_strides2: torch.Tensor):
    with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(
                pipeline_parallel_size=1))):
        return cutlass_moe_fp8(a,
                               w1_q,
                               w2_q,
                               w1_scale,
                               w2_scale,
                               topk_weights,
                               topk_ids,
                               ab_strides1,
                               c_strides1,
                               ab_strides2,
                               c_strides2,
                               a1_scale=a_scale)

# ... (non-contiguous diff hunk) ...
@pytest.mark.parametrize("m", [2, 64, 224])
@pytest.mark.parametrize("n", [1024, 3072])
@pytest.mark.parametrize("k", [1024, 1536])

# ... (non-contiguous diff hunk) ...
        return cutlass_moe(a,
                           w1_q,
                           w2_q,
                           topk_weights,
                           topk_ids,
                           ab_strides1,
                           c_strides1,
                           ab_strides2,
                           c_strides2,
                           w1_scale=w1_scale,
                           w2_scale=w2_scale,
                           a1_scale=a_scale)

# ... (non-contiguous diff hunk) ...
def run_16_bit(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
               topk_weights: torch.Tensor, topk_ids: torch.Tensor,
               ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
               ab_strides2: torch.Tensor, c_strides2: torch.Tensor):
    with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(
                pipeline_parallel_size=1))):
        return cutlass_moe(a, w1, w2, topk_weights, topk_ids, ab_strides1,
                           c_strides1, ab_strides2, c_strides2)
```
nit: Could these be combined with the scales made optional and defaulted to None for the fp16 case? I don't have strong feelings about this though.
Yeah, probably
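A minimal sketch of the suggestion above: one test helper with the quantization scales made optional, so the fp8 and fp16/bf16 cases share a single entry point. The helper name `run_moe` is hypothetical, and `cutlass_moe` / `cutlass_moe_fp8` are assumed to be imported exactly as in the existing test file; this is an illustration, not the code that landed.

```python
from typing import Optional

import torch

from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config


def run_moe(a: torch.Tensor,
            w1: torch.Tensor,
            w2: torch.Tensor,
            topk_weights: torch.Tensor,
            topk_ids: torch.Tensor,
            ab_strides1: torch.Tensor,
            c_strides1: torch.Tensor,
            ab_strides2: torch.Tensor,
            c_strides2: torch.Tensor,
            w1_scale: Optional[torch.Tensor] = None,
            w2_scale: Optional[torch.Tensor] = None,
            a1_scale: Optional[torch.Tensor] = None):
    with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(
                pipeline_parallel_size=1))):
        if w1_scale is None and w2_scale is None:
            # 16-bit (fp16/bf16) path: no quantization scales needed.
            return cutlass_moe(a, w1, w2, topk_weights, topk_ids, ab_strides1,
                               c_strides1, ab_strides2, c_strides2)
        # fp8 path: forward the weight scales and optional activation scale.
        return cutlass_moe_fp8(a, w1, w2, w1_scale, w2_scale, topk_weights,
                               topk_ids, ab_strides1, c_strides1, ab_strides2,
                               c_strides2, a1_scale=a1_scale)
```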
tests/kernels/test_cutlass_moe.py
```python
print(triton_output)
print(cutlass_output)
print("*")
```
nit: Do we need these prints?
The tolerances in the tests are a bit high, so I use these prints to examine manually how far off the values are when I'm close to the threshold.
I think the prints are fine for debugging but I don't think we should push with them enabled.
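One hedged alternative, not part of this PR: wrap the comparison in a small helper that prints only the worst-case deviation and lets the assertion carry the tolerance. `check_close` is a hypothetical name and the default `atol`/`rtol` values are placeholders, not the tolerances used in the test.

```python
import torch


def check_close(cutlass_output: torch.Tensor,
                triton_output: torch.Tensor,
                atol: float = 5e-2,
                rtol: float = 1e-2) -> None:
    # Report only the worst-case deviation instead of dumping full tensors.
    max_abs_diff = (cutlass_output.float() -
                    triton_output.float()).abs().max().item()
    print(f"max abs diff vs. Triton reference: {max_abs_diff:.3e}")
    # The assertion carries the (placeholder) tolerances.
    torch.testing.assert_close(cutlass_output, triton_output,
                               atol=atol, rtol=rtol)
```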
tests/kernels/test_cutlass_moe.py
```python
print(triton_output)
print(cutlass_output)
print("*")
```
ditto
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Force-pushed from dda21cc to 5ab56cb.
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
@varun-sundar-rabindranath do you have e2e benchmark results that we could share before landing this?
Oh I see the new FP16/BF16 configs still aren't in -- Please LMK when this is ready!
@tlrmchlsmth - I have the changes here neuralmagic#57 waiting to be merged on the …
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com> Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Factoring out expert_map support into a separate PR #16861
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
This pull request has merge conflicts that must be resolved before it can be merged.
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!
This pull request has been automatically closed due to inactivity. Please feel free to reopen if you intend to continue working on it. Thank you!
Implement BF16 and FP16 weight support in CUTLASS MoE kernels. Tested with … and …