Conversation


@pavanimajety pavanimajety commented Jan 28, 2025

This commit adds GEMMs for the NVFP4 datatype and quantization kernels to convert to NVFP4.

Co-authored-by: kahmadian@nvidia.com
Co-authored-by: kaixih@nvidia.com
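
For readers unfamiliar with the format, below is a minimal CPU sketch (not this PR's CUDA kernel) of rounding a float to FP4 E2M1, the 4-bit element type NVFP4 uses. The value table and nibble layout follow the E2M1 definition; tie-breaking is simplified to round toward the smaller magnitude rather than round-to-nearest-even.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Positive values representable by FP4 E2M1 (1 sign, 2 exponent, 1 mantissa bit);
// the table index is exactly the 3-bit magnitude code.
static const float kE2M1[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

// Round a (pre-scaled) float to the nearest E2M1 code: sign in bit 3,
// magnitude code in bits 0-2. Ties round toward the smaller magnitude here.
uint8_t quantize_e2m1(float x) {
  uint8_t sign = x < 0.0f ? 0x8 : 0x0;
  float a = std::fabs(x);
  uint8_t best = 0;
  for (uint8_t i = 1; i < 8; ++i)
    if (std::fabs(a - kE2M1[i]) < std::fabs(a - kE2M1[best])) best = i;
  return sign | best;
}

int main() {
  // Two FP4 values pack into one byte.
  uint8_t packed =
      static_cast<uint8_t>(quantize_e2m1(1.7f) | (quantize_e2m1(-3.9f) << 4));
  std::printf("packed byte: 0x%02x\n", packed);  // 1.7 -> 1.5 (0x3), -3.9 -> -4 (0xe)
}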

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@pavanimajety (Collaborator, Author)

Needs correction when m < 128.

@robertgshaw2-redhat (Collaborator)

Exciting!!!!

@pavanimajety pavanimajety force-pushed the blackwell-fp4-gemms-ckpt branch 2 times, most recently from c0445c0 to af8205f on January 28, 2025 at 22:41
@robertgshaw2-redhat robertgshaw2-redhat commented Jan 29, 2025

  • please call this cutlass_scaled_fp4_mm for naming consistency
  • please update the argument names to be consistent with cutlass_scaled_mm wherever possible

@robertgshaw2-redhat robertgshaw2-redhat commented Jan 29, 2025

workspace_bytes is unused?

Collaborator

Probably better to have this in the C++?

@robertgshaw2-redhat robertgshaw2-redhat commented Jan 29, 2025

  • This should be called scaled_fp4_quant
  • This should be next to scaled_fp8_quant below
  • I think we should create output_sf in this function (rather than have it be an argument). This will make the integration code more consistent with the scaled_fp8_quant code (see the sketch after this list).
  • The args should be called input and scale, to be consistent with scaled_fp8_quant.
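
A minimal sketch of what that suggested shape could look like. This renders the reviewer's suggestion as code, not the PR's actual implementation; the block size of 16 and the packed-uint8 layouts are assumptions based on the NVFP4 format.

#include <tuple>
#include <torch/torch.h>

// Hypothetical wrapper: allocate both outputs inside the function, mirroring
// scaled_fp8_quant, and take only (input, scale) as arguments.
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant(
    torch::Tensor const& input,   // (m, n) high-precision input
    torch::Tensor const& scale) { // global scale
  int64_t m = input.size(0), n = input.size(1);
  // Two 4-bit values pack into one byte, so the FP4 data is (m, n / 2) bytes.
  auto output = torch::empty(
      {m, n / 2}, torch::dtype(torch::kUInt8).device(input.device()));
  // One scale factor per 16-element block, stored as raw bytes here for
  // simplicity (the real code would use an FP8 E4M3 dtype).
  auto output_sf = torch::empty(
      {m, n / 16}, torch::dtype(torch::kUInt8).device(input.device()));
  // ... the CUDA quantization kernel would fill output and output_sf here ...
  return {output, output_sf};
}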

Collaborator

move this next to cutlass_scaled_mm

Collaborator

move this next to scaled_fp8_quant

@robertgshaw2-redhat (Collaborator)

Nice PR! Left some comments on the integration code.

I will leave it to others to review the kernel.

Collaborator

what does the sf postfix stand for?

@pavanimajety (Collaborator, Author)

block scaling factor
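
To make that concrete, here is a minimal CPU sketch of how a per-block scaling factor could be computed, assuming NVFP4's 16-element blocks and E2M1's maximum value of 6.0; the actual kernel's rounding and memory layout may differ.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// For each 16-element block, pick the scale that maps the block's largest
// magnitude onto FP4's largest representable value; elements are later
// divided by their block's sf before being rounded to FP4.
std::vector<float> compute_block_sf(const std::vector<float>& x) {
  constexpr std::size_t kBlock = 16;   // NVFP4 block size
  constexpr float kFp4Max = 6.0f;      // largest finite E2M1 value
  std::vector<float> sf(x.size() / kBlock);
  for (std::size_t b = 0; b < sf.size(); ++b) {
    float amax = 0.0f;
    for (std::size_t i = 0; i < kBlock; ++i)
      amax = std::max(amax, std::fabs(x[b * kBlock + i]));
    sf[b] = amax / kFp4Max;
  }
  return sf;
}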

This commit adds gemms for NVFP4 datatype and quantization kernels
to convert to NVFP4

Co-authored-by: kahmadian@nvidia.com
Co-authored-by: kaixih@nvidia.com

Signed-off-by: Pavani Majety <pmajety@nvidia.com>

Correct usage of scaled_fp4_quant to use rounded m / n

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
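
For context on the "rounded m / n" fix above, a hedged sketch of the kind of rounding involved, assuming the block-scale buffer pads rows to a multiple of 128 and the per-row scale count (n / 16) to a multiple of 4; these constants follow the swizzled scale layout and are an assumption here, not taken from this thread.

#include <cstdint>

// Hypothetical helper: round x up to the nearest multiple of m.
constexpr int64_t round_up(int64_t x, int64_t m) { return (x + m - 1) / m * m; }

// Assumed element count of the block-scale buffer for an (m, n) input with
// 16-element scaling blocks: rows padded to 128, scale columns padded to 4.
int64_t block_scale_elems(int64_t m, int64_t n) {
  return round_up(m, 128) * round_up(n / 16, 4);
}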
@kaixih kaixih (Contributor) commented Feb 5, 2025

Hi, we have decided to extract the fp4 quantization part into a separate PR. This PR will be based on it and will focus only on the fp4 gemm.

@pavanimajety @kushanam

" Tensor! b, Tensor! block_scale_a,"
" Tensor! block_scale_b, Tensor! gscale,"
" Tensor! workspace, int workspace_bytes) -> ()");
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
Contributor

I think we are missing a header definition for these in csrc/ops.h. I'm getting this compiler error:

/opt/vllm/vllm-src/csrc/torch_bindings.cpp: In function ‘void TORCH_LIBRARY_init__C(torch::Library&)’:
/opt/vllm/vllm-src/csrc/torch_bindings.cpp:390:52: error: ‘cutlass_scaled_fp4_mm’ was not declared in this scope; did you mean ‘cutlass_scaled_mm’?
  390 |   ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
      |                                                    ^~~~~~~~~~~~~~~~~~~~~
      |                                                    cutlass_scaled_mm
/opt/vllm/vllm-src/csrc/torch_bindings.cpp:397:47: error: ‘scaled_fp4_quant’ was not declared in this scope; did you mean ‘static_scaled_fp8_quant’?
  397 |   ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
      |                                               ^~~~~~~~~~~~~~~~
      |                                               static_scaled_fp8_quant
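
A hedged sketch of the declarations that could go into csrc/ops.h to resolve this; the cutlass_scaled_fp4_mm parameter list is inferred from the op schema quoted above, while the scaled_fp4_quant list is a guess, since its schema is not visible in this thread.

#include <torch/torch.h>

// Hypothetical declarations for csrc/ops.h; names and parameter lists are
// inferred from the registered schemas and may not match the PR exactly.
void cutlass_scaled_fp4_mm(torch::Tensor& out, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& block_scale_a,
                           torch::Tensor const& block_scale_b,
                           torch::Tensor const& gscale,
                           torch::Tensor& workspace, int64_t workspace_bytes);

// Guessed signature; the actual argument list is not shown in this thread.
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
                      torch::Tensor& output_sf, torch::Tensor const& scale);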

ChooseWithHeuristic,

// CTA configs for M=128
CtaShape128x128x64B,
@jiawenliu64 jiawenliu64 commented Feb 13, 2025

How were these CTA configs and the following ClusterShape selected? Is there any reasoning behind these choices for achieving the best performance across various M, N, K shapes? Curious whether only three CTA configs can already achieve the best performance on sm100?

@pavanimajety @kaixih @kushanam @LucasWilkinson

@mergify mergify bot commented Feb 13, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @pavanimajety.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
