
Conversation

@varun-sundar-rabindranath

@varun-sundar-rabindranath varun-sundar-rabindranath commented Apr 12, 2025

Add new FP16 configs and Support expert_map for EP

E2E benchmark numbers: link

Micro benchmarks: link - there are some bald spots I am looking into.
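
For readers unfamiliar with expert_map, a rough sketch of the idea follows (the helper below is hypothetical, not the actual vLLM code). Under expert parallelism (EP), each rank owns a contiguous slice of the experts, and the map translates global expert ids to local ids, with -1 marking experts that live on other ranks and are therefore "invalid" locally.

import torch

def build_expert_map(global_num_experts: int, ep_size: int, ep_rank: int) -> torch.Tensor:
    # Hypothetical helper: global expert ids owned by this rank map to local
    # ids [0, local_num_experts); every other global id maps to -1 ("invalid").
    local_num_experts = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
    start = ep_rank * local_num_experts
    expert_map[start:start + local_num_experts] = torch.arange(
        local_num_experts, dtype=torch.int32)
    return expert_map

# e.g. 8 experts over 2 EP ranks; rank 1 owns global experts 4..7:
# tensor([-1, -1, -1, -1,  0,  1,  2,  3], dtype=torch.int32)
print(build_expert_map(8, 2, 1))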

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@varun-sundar-rabindranath
Author

@ElizaWszola @dsikka can you please take a look at the expert_map support part of the PR? Thanks 🙌


Mostly refactor existing tests and add EP tests.


@ElizaWszola I see this function being called below in this file in fused_moe::__init__() - I updated this code to not raise any errors, as the expected behavior seems to be to fall back to the triton impl. PTAL! Thanks!
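
For illustration only, a minimal sketch of the fallback pattern described above (the function and attribute names here are hypothetical, not the actual vLLM API): the capability check reports False instead of raising, and the caller drops down to the triton path.

def cutlass_moe_supported(hidden_size: int, intermediate_size: int) -> bool:
    # Hypothetical check: report unsupported shapes instead of raising, so the
    # caller can silently fall back to the triton implementation.
    return hidden_size % 16 == 0 and intermediate_size % 16 == 0

# Hypothetical caller, e.g. in a fused-MoE layer's __init__:
# self.use_cutlass = cutlass_moe_supported(hidden_size, intermediate_size)
# if not self.use_cutlass:
#     ...  # use the triton fused_moe kernel instead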


good update!


@dsikka Can you please take a look at the compressed_tensors changes? Thanks!

Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
@varun-sundar-rabindranath force-pushed the varun/expert-maps-fp16-configs branch from f1b859b to 1168828 on April 12, 2025 03:03
varun sundar rabindranath added 4 commits April 12, 2025 03:13
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>

 c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
-c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
+c2 = torch.zeros((m * topk, k), device=device, dtype=out_dtype)

Is this needed for correctness? empty should be faster than zeros


OK I see this now

      // c2 is initialized to zeros, therefore by setting the output_permutation
      // to num_tokens, we are guaranteed to fill the moe outputs to zero
      // for "invalid" topk_ids.

Comment on lines 155 to 160
a_map = torch.zeros((local_topk_ids.numel()),
                    dtype=torch.int32,
                    device=device)
c_map = torch.zeros((local_topk_ids.numel()),
                    dtype=torch.int32,
                    device=device)

What about these two, why do they need to be zeros?


a_map has to be zeros because we don't fill the indices related to "invalid" topk ids. c_map can actually be empty, as we fill all the indices in get_cutlass_moe_mm_data. I'll make the change.
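
A toy sketch of the a_map point (shapes and names are made up for illustration, and torch.long is used for indexing instead of the kernel's int32): positions whose topk id is invalid are never assigned a source index, so zero-initialization keeps those entries pointing at a row that is at least in range.

import torch

a = torch.randn(6, 4)                                # replicated activations, 3 tokens x topk = 2
local_topk_ids = torch.tensor([0, -1, 2, -1, 1, 0])  # -1: expert not on this EP rank

a_map = torch.zeros(local_topk_ids.numel(), dtype=torch.long)
valid = local_topk_ids >= 0
a_map[valid] = torch.nonzero(valid).squeeze(1)       # stand-in for the real permutation

# Invalid slots still read index 0, which is in range, so the gather never
# consumes uninitialized indices.
a_permuted = a[a_map]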



varun sundar rabindranath added 2 commits April 13, 2025 03:28
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Comment on lines 181 to 182
uint32_t const n = out_tensors.size(1);
uint32_t const k = a_tensors.size(1);


These vars look unused now; it's OK to remove them.

@ElizaWszola

ElizaWszola commented Apr 14, 2025

LGTM! I added a minor comment.

@ElizaWszola

Merged through command line, I think it's safe to close now

@varun-sundar-rabindranath
Author

Merged through command line.
