Skip to content

Conversation

@ElizaWszola
Copy link
Contributor

@ElizaWszola ElizaWszola commented Apr 2, 2025

Implement BF16 and FP16 weight support in CUTLASS MoE kernels. Tested with

llm = LLM("mistralai/Mixtral-8x7B-Instruct-v0.1",
          tensor_parallel_size=2,
)

and

llm = LLM("mistralai/Mixtral-8x7B-Instruct-v0.1",
          tensor_parallel_size=2,
          dtype=torch.float16,
)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
@github-actions
Copy link

github-actions bot commented Apr 2, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
@tlrmchlsmth tlrmchlsmth changed the title [WIP][Kernel] Enable BF16 weights in CUTLASS MoE [WIP][Kernel] Enable FP16 and BF16 CUTLASS MoE kernels Apr 2, 2025
Copy link
Member

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

General comment that using the name fp16 for operations that handle both fp16 and bf16 is confusing and we should either pick a more general name (16bit?), or better: append fp8 to the names of ops that handle fp8 and remove fp16 from names altogether

Signed-off-by: ElizaWszola <ewszola@redhat.com>
@mergify mergify bot added the ci/build label Apr 3, 2025
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
@ElizaWszola ElizaWszola marked this pull request as ready for review April 3, 2025 14:21
@ElizaWszola ElizaWszola changed the title [WIP][Kernel] Enable FP16 and BF16 CUTLASS MoE kernels [Kernel] Enable FP16 and BF16 CUTLASS MoE kernels Apr 3, 2025
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)

# TODO half()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the TODO? Resolve before landing?

@mergify
Copy link

mergify bot commented Apr 4, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ElizaWszola.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 4, 2025
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the 16-bit configs are the same as the fp8 configs -- these need to be re-tuned for the fp16/bf16 case

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
@mergify mergify bot removed the needs-rebase label Apr 4, 2025
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Comment on lines 31 to 63
def run_8_bit(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
w2_q: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, ab_strides1: torch.Tensor,
c_strides1: torch.Tensor, ab_strides2: torch.Tensor,
c_strides2: torch.Tensor):
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
return cutlass_moe_fp8(a,
w1_q,
w2_q,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
a1_scale=a_scale)


@pytest.mark.parametrize("m", [2, 64, 224])
@pytest.mark.parametrize("n", [1024, 3072])
@pytest.mark.parametrize("k", [1024, 1536])
return cutlass_moe(a,
w1_q,
w2_q,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale)


def run_16_bit(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
ab_strides2: torch.Tensor, c_strides2: torch.Tensor):
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
return cutlass_moe(a, w1, w2, topk_weights, topk_ids, ab_strides1,
c_strides1, ab_strides2, c_strides2)

Copy link
Contributor

@bnellnm bnellnm Apr 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Could these be combined with the scales made optional and defaulted to None for the fp16 case? I don't have strong feelings about this though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, probably

Comment on lines 313 to 315
print(triton_output)
print(cutlass_output)
print("*")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Do we need this prints?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tolerances in the tests are a bit high, so I use these prints to examine manually how off the values are if I'm close to the treshold

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the prints are fine for debugging but I don't think we should push with them enabled.

Comment on lines 371 to 373
print(triton_output)
print(cutlass_output)
print("*")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
varun sundar rabindranath added 6 commits April 12, 2025 02:58
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Copy link
Member

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@varun-sundar-rabindranath do you have e2e benchmark results that we could share before landing this?

Copy link
Member

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see the new FP16/BF16 configs still aren't in -- Please LMK when this is ready!

@varun-sundar-rabindranath
Copy link
Contributor

varun-sundar-rabindranath commented Apr 12, 2025

@tlrmchlsmth - I have the changes here neuralmagic#57 waiting to be merged on the neuralmagic:cutlass-moe-bf16-weights branch. I am still getting the e2e and microbenchmarks.

varun sundar rabindranath and others added 5 commits April 13, 2025 03:28
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
@varun-sundar-rabindranath
Copy link
Contributor

Factoring out expert_map support into a separate PR #16861

Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
@mergify
Copy link

mergify bot commented Apr 27, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ElizaWszola.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 27, 2025
Signed-off-by: ElizaWszola <ewszola@redhat.com>
@mergify mergify bot removed the needs-rebase label Apr 29, 2025
Signed-off-by: ElizaWszola <ewszola@redhat.com>
@mergify
Copy link

mergify bot commented May 2, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ElizaWszola.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 2, 2025
@mergify mergify bot added the performance Performance-related issues label Jun 23, 2025
@github-actions
Copy link

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale Over 90 days of inactivity label Sep 22, 2025
@github-actions
Copy link

This pull request has been automatically closed due to inactivity. Please feel free to reopen if you intend to continue working on it. Thank you!

@github-actions github-actions bot closed this Oct 22, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build needs-rebase performance Performance-related issues stale Over 90 days of inactivity

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants