[Kernel] Add Kernel Support for NVFP4 #12519
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these: 🚀
    
        
          
tests/kernels/test_nvfp4_gemm.py (Outdated)
Needs correction when m is < 128
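(Context for readers: the later fix commit mentions "rounded m / n", which suggests the scale-factor layout pads the row count up to a multiple of 128. A minimal sketch of that rounding; the helper name and the 128-row tile size are assumptions, not the PR's actual code:)

```python
def round_up(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple, e.g. round_up(64, 128) == 128.
    return ((x + multiple - 1) // multiple) * multiple

# Hypothetical: with m < 128, the swizzled block-scale tensor still needs
# storage for round_up(m, 128) rows, so tests must not assume exactly m rows.
m = 64
padded_m = round_up(m, 128)  # 128
```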
Exciting!!!!
Force-pushed from c0445c0 to af8205f.
            
          
vllm/_custom_ops.py (Outdated)
- Please call this cutlass_scaled_fp4_mm for naming consistency
- Please update the argument names to be consistent with cutlass_scaled_mm wherever possible
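(A minimal sketch of what the suggested rename might look like in vllm/_custom_ops.py; the parameter names and order are assumptions modeled on the existing cutlass_scaled_mm wrapper, not the PR's final signature:)

```python
import torch

def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
                          block_scale_a: torch.Tensor,
                          block_scale_b: torch.Tensor,
                          alpha: torch.Tensor,
                          out_dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical wrapper mirroring cutlass_scaled_mm's convention:
    # operands first, then their scales, then the output dtype.
    ...
```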
        
          
vllm/_custom_ops.py (Outdated)
workspace_bytes is unused?
        
          
vllm/_custom_ops.py (Outdated)
Probably better to have this in the C++?
        
          
vllm/_custom_ops.py (Outdated)
- This should be called 
scaled_fp4_quant - This should be next to 
scaled_fp8_quantbelow - I think we should create 
output_sfin this function (rather than have it be an argument). This will make the integration code more consistent withscaled_fp8_quantcode better and more consistent with - args should be called (
inputandscaleto be consistent withscaled_fp8_quant 
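(Putting those four suggestions together, a rough sketch of the proposed shape; the 2-values-per-byte packing, the 16-element scale blocks, and the torch.ops._C call are assumptions based on the NVFP4 format, not the PR's final code:)

```python
import torch

def scaled_fp4_quant(input: torch.Tensor,
                     scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    m, n = input.shape
    # FP4 packs two values per byte, so the quantized output holds n // 2
    # bytes per row; one FP8 scale factor covers each 16-element block.
    output = torch.empty((m, n // 2), dtype=torch.uint8, device=input.device)
    output_sf = torch.empty((m, n // 16), dtype=torch.float8_e4m3fn,
                            device=input.device)
    torch.ops._C.scaled_fp4_quant(output, input, output_sf, scale)
    return output, output_sf
```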
        
          
csrc/torch_bindings.cpp (Outdated)
move this next to cutlass_scaled_mm
        
          
csrc/torch_bindings.cpp (Outdated)
move this next to scaled_fp8_quant
Nice PR! Left some comments on the integration code. I will leave it to others to review the kernel.
        
          
tests/kernels/test_nvfp4_quant.py (Outdated)
what does the sf postfix stand for?
block scaling factor
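(For readers new to the format: NVFP4 stores one FP8 (E4M3) scale factor per contiguous block of 16 FP4 values, plus a global FP32 scale. A rough sketch of how such per-block scales could be derived, ignoring the kernel's actual rounding and swizzled layout:)

```python
import torch

FP4_MAX = 6.0  # largest magnitude representable in E2M1 FP4

def per_block_scales(x: torch.Tensor, block: int = 16) -> torch.Tensor:
    # One scale per block: the block's absolute max divided by FP4_MAX,
    # so dividing the block by its scale keeps values in the FP4 range.
    m, n = x.shape
    return x.abs().view(m, n // block, block).amax(dim=-1) / FP4_MAX
```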
This commit adds GEMMs for the NVFP4 datatype and quantization kernels to convert to NVFP4.
Co-authored-by: kahmadian@nvidia.com
Co-authored-by: kaixih@nvidia.com
Signed-off-by: Pavani Majety <pmajety@nvidia.com>

Correct usage of scaled_fp4_quant to use rounded m / n
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Force-pushed from af8205f to fdcf219.
Hi, we have decided to extract the FP4 quantization part into a separate PR. This PR will be based on it and will focus only on the FP4 GEMM.
| " Tensor! b, Tensor! block_scale_a," | ||
| " Tensor! block_scale_b, Tensor! gscale," | ||
| " Tensor! workspace, int workspace_bytes) -> ()"); | ||
| ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm); | 
I think we are missing a header definition for these in csrc/ops.h. I'm getting this compiler error:
/opt/vllm/vllm-src/csrc/torch_bindings.cpp: In function ‘void TORCH_LIBRARY_init__C(torch::Library&)’:
/opt/vllm/vllm-src/csrc/torch_bindings.cpp:390:52: error: ‘cutlass_scaled_fp4_mm’ was not declared in this scope; did you mean ‘cutlass_scaled_mm’?
  390 |   ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
      |                                                    ^~~~~~~~~~~~~~~~~~~~~
      |                                                    cutlass_scaled_mm
/opt/vllm/vllm-src/csrc/torch_bindings.cpp:397:47: error: ‘scaled_fp4_quant’ was not declared in this scope; did you mean ‘static_scaled_fp8_quant’?
  397 |   ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
      |                                               ^~~~~~~~~~~~~~~~
      |                                               static_scaled_fp8_quant
    ChooseWithHeuristic,

    // CTA configs for M=128
    CtaShape128x128x64B,
How were these CTA configs and the following ClusterShape selected? Is there reasoning behind the selection to achieve the best performance for various M, N, K shapes? Curious whether only three CTA configs can already achieve the best performance on sm100?
This pull request has merge conflicts that must be resolved before it can be merged.