Fix Float8Tensor quantize op kernel preference dispatch #2883
Conversation
          
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2883
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 4516f6e with merge base 2a53216.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
    
Force-pushed 6935cc8 to bacbe8c
Force-pushed 6cf26bd to 815a964
            
          
Resolved (outdated) review threads on:
test/quantization/quantize_/workflows/float8/test_float8_tensor.py
torchao/quantization/quantize_/workflows/float8/float8_tensor.py
torchao/quantization/quantize_/workflows/float8/float8_tensor.py
        
Force-pushed 815a964 to a78fc11
  
            
          
Resolved (outdated) review thread on torchao/quantization/quantize_/workflows/float8/float8_tensor.py
        
Force-pushed be69537 to 5f6ec32
  
```python
            kernel_choice = "fbgemm"
        elif weight_tensor.kernel_preference == KernelPreference.TRITON:
            # no triton gemm op is available, so we'll fallback to torch
            kernel_choice = "torch"
```
also weird, your kernel choice is doing double duty and is now a recipe. that recipe is also not very clear from your initial doc
do you mean kernel choice is used for both quantize and gemm?
what is the initial doc you are referring to?
Yeah exactly. The doc is just the code block and reading the kernel choice docstring
> kernel choice is used for both quantize and gemm

yeah, that's a decision we made before; according to Josh there is no need to have a kernel-level choice for now, just to keep things simple.
we did mention this in the KernelPreference doc I think?
IMO it should be:

- fbgemm - use all fbgemm kernels, error out if something is not supported
- torch - use all torch kernels, error out if something is not supported
- auto - torchao decides what to do

we should not use torch kernels in the fbgemm setting, as that is not honoring what the user asked for
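As a rough illustration of the strict policy proposed above, here is a minimal sketch (not the torchao implementation; the `resolve_quantize_backend` helper and the availability flag are made up for this example):

```python
from enum import Enum

class KernelPreference(Enum):
    FBGEMM = "fbgemm"  # only fbgemm/triton kernels; error out if something is unsupported
    TORCH = "torch"    # only torch ops; error out if something is unsupported
    AUTO = "auto"      # torchao decides what to do

def resolve_quantize_backend(pref: KernelPreference, fbgemm_available: bool) -> str:
    """Pick the backend for the quantize op under the strict 'honor the user's choice' policy."""
    if pref is KernelPreference.FBGEMM:
        if not fbgemm_available:
            # strict: never silently fall back to torch when the user asked for fbgemm
            raise RuntimeError("KernelPreference.FBGEMM requested but fbgemm_gpu_genai is not installed")
        return "fbgemm"
    if pref is KernelPreference.TORCH:
        return "torch"
    # AUTO: torchao decides; e.g. prefer fbgemm when it is available
    return "fbgemm" if fbgemm_available else "torch"
```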
Force-pushed 8ea051d to 74dd7dd
  
```python
    """Use triton quantize and quantized mm kernels (if available), requires fbgemm_gpu_genai library, if no triton kernel for the quantize op or mm kernel is available, we'll fallback to torch ops
    """
    TRITON = "triton"
```
this name isn't coherent with the rest of the enum. We already have an FBGEMM option which does not say anything about cutlass vs triton and therefore already includes these kernels. I think you have two options:

- have the fbgemm option pick the best kernel (cutlass vs triton) for the user. I prefer this one.
- make it clear that "FBGEMM" does not mean "FBGEMM", but really means "FBGEMM_CUTLASS", and also add "FBGEMM_TRITON". I don't really like this option.
 
OK thanks, yeah option 1 seems easiest for now, will update to that, unless there is a request to distinguish these in the future
Force-pushed 74dd7dd to 0b2ab3e
  
```python
    if (
        isinstance(granularity, PerTensor)
        and kernel_preference == KernelPreference.FBGEMM
    ):
```
Nit: let's xfail this
we are using unittest, seems like we can't do `return unittest.expectedFailure("...")`?

```
  File ".../ao/test/quantization/quantize_/workflows/float8/test_float8_tensor.py", line 92, in test_fp8_linear_variants
    return unittest.expectedFailure(
  File ".../python3.10/unittest/case.py", line 148, in expectedFailure
    test_item.__unittest_expecting_failure__ = True
AttributeError: 'str' object has no attribute '__unittest_expecting_failure__'
```

but let me know if there is an example of how to do expectedFailure conditionally instead of skipping the entire test
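For reference, `unittest.expectedFailure` is a decorator that must wrap the test callable itself, which is why returning it with a reason string from inside the test raises the `AttributeError` above. A minimal sketch of one way to apply it conditionally at definition time (the condition, helper, and test names here are illustrative, not the actual torchao test):

```python
import unittest

# illustrative flag; in the real test this would reflect whether the
# fbgemm per-tensor quantize op is expected to work
_FBGEMM_PER_TENSOR_WORKS = False

def expected_failure_if(condition):
    """Apply unittest.expectedFailure only when `condition` is true."""
    def decorator(test_func):
        return unittest.expectedFailure(test_func) if condition else test_func
    return decorator

class Float8TensorXfailExample(unittest.TestCase):
    @expected_failure_if(not _FBGEMM_PER_TENSOR_WORKS)
    def test_fp8_per_tensor_fbgemm(self):
        # stand-in for the real check; while the op is broken this fails,
        # and the decorator records it as an expected failure instead of an error
        self.assertTrue(_FBGEMM_PER_TENSOR_WORKS)

if __name__ == "__main__":
    unittest.main()
```

This only helps when the failing combination is known at definition time; for a parametrized test where it is only known inside the test body, `pytest.xfail(reason)` is the usual runtime alternative if the suite is run under pytest.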
          
Can you explain specifically what did not work, and why it works after this PR? It would also be good to have a test which fails before this PR and passes after this PR.
    
Force-pushed a1f4504 to a73fa51
  
@vkuzo sure, updated the PR summary and added a test for this one and the next PR as well
    
          
note that I don't think this is true, all values of
    
```python
        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
        self._test_moe_weight_reshape_ops(config)

    def test_expected_gpu_kernel_fbgemm(self):
```
I think this test should be together with the other tests we have which check the same thing for other settings of this config, currently in test_affine_quantized_float.py.  Can we add a TODO to unify?
yeah I think we can put everything here after we deprecate the AQT path in 9 months
Force-pushed a73fa51 to f685f8b
  
    
          
makes sense, it is user facing
    
Force-pushed f685f8b to 8bea88e
  
stack-info: PR: #2883, branch: jerryzh168/stack/59
Force-pushed 8bea88e to 4516f6e
  
    
Stacked PRs:
Fix Float8Tensor quantize op kernel preference dispatch
Summary:
Previously, if the user specifies kernel_preference == "fbgemm", we'll use torch ops like
`_choose_scale_float8` and `_quantize_affine_float8` to quantize the high precision Tensor into a float8 Tensor.
This PR makes sure we use fbgemm kernels when kernel_preference is "fbgemm", meaning:
`torch.ops.triton.quantize_fp8_row` for per row, and `torch.ops.fbgemm.quantize_fp8_per_tensor` for per tensor
(while `torch.ops.fbgemm.quantize_fp8_per_tensor` has some issues right now and we'll enable it later when it's fixed).
This doesn't have an impact on BC, meaning an old serialized model can still be loaded and run; the only change is that
fixing the kernel choice for the fbgemm kernel preference means users who requested the FBGEMM KernelPreference now actually run the fbgemm quantize op instead of the torch op.
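For illustration, a minimal sketch of the dispatch rule this summary describes (not the actual Float8Tensor code; the `expected_quantize_op` helper is made up, but the op names are the ones listed above):

```python
def expected_quantize_op(kernel_preference: str, per_row: bool) -> str:
    """Return which quantize kernel we expect to run for a given preference/granularity."""
    if kernel_preference == "fbgemm":
        if per_row:
            return "torch.ops.triton.quantize_fp8_row"
        # per-tensor fbgemm path; currently still pending an upstream fix
        return "torch.ops.fbgemm.quantize_fp8_per_tensor"
    # "torch" preference (and the old "fbgemm" behavior): torch reference ops
    return "_choose_scale_float8 + _quantize_affine_float8"

# e.g. the per-row fbgemm case exercised by test_expected_gpu_kernel_fbgemm
assert expected_quantize_op("fbgemm", per_row=True) == "torch.ops.triton.quantize_fp8_row"
```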
Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_expected_gpu_kernel_fbgemm
Reviewers:
Subscribers:
Tasks:
Tags: