Update Float8Tensor for GRPO training in unsloth #3158
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3158
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit 092ca75 with merge base f3fc5e7. The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed ed3c237 to 19500bf (Compare)
Force-pushed 3d4cb8d to 7e0749d (Compare)
Force-pushed 7e0749d to c0f4b4e (Compare)
Force-pushed 345bb63 to 9d27057 (Compare)
output_tensor = torch.matmul(input_tensor, weight_tensor.t())
output_tensor_fp8 = torch.matmul(input_tensor_fp8, weight_tensor_fp8.t())
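The fp8 output is presumably compared against the high-precision reference with an error metric such as SQNR, as is common in torchao tests. The helper below, `sqnr_db`, is a hypothetical pure-Python sketch of that metric, not the actual torchao utility:

```python
import math

def sqnr_db(ref, actual):
    """Signal-to-quantization-noise ratio in dB between a reference
    output and its quantized counterpart; higher means closer."""
    signal = sum(r * r for r in ref)
    noise = sum((r - a) ** 2 for r, a in zip(ref, actual))
    if noise == 0:
        return float("inf")
    return 10 * math.log10(signal / noise)

# A quantized result close to the reference yields a high SQNR,
# which a test can assert against a threshold (e.g. > 20 dB).
ref = [1.0, -2.0, 3.0, 0.5]
fp8_like = [1.01, -1.98, 3.02, 0.49]
assert sqnr_db(ref, fp8_like) > 20
```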
this is not used through the quantize_ API?
if this can be accessed through quantize_ API then we can merge the test with test_linear_variants I think
I don't think this can be accessed through the quantize_ API, unfortunately; nn.Linear will dispatch to F.linear first
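The dispatch issue above can be illustrated with a toy op table (pure Python, all names hypothetical; the real code registers implementations against aten ops on the tensor subclass): nn.Linear forwards through F.linear, so only the linear override is reachable from a quantize_-converted module, and the matmul override has to be exercised by calling torch.matmul directly, hence the standalone test.

```python
# Toy dispatch table mimicking how a tensor subclass overrides ops.
OP_TABLE = {}

def register(op_name):
    def decorator(fn):
        OP_TABLE[op_name] = fn
        return fn
    return decorator

@register("aten.linear")
def _linear_impl(x, w):
    return ("aten.linear", x, w)

@register("aten.matmul")
def _matmul_impl(x, w):
    return ("aten.matmul", x, w)

def nn_linear_forward(x, w):
    # nn.Linear.forward calls F.linear, which lowers to the linear op;
    # the matmul override is never hit on this path.
    return OP_TABLE["aten.linear"](x, w)

def direct_matmul(x, w):
    # Only a direct torch.matmul call reaches the matmul override,
    # so this test cannot be merged into test_linear_variants.
    return OP_TABLE["aten.matmul"](x, w)

assert nn_linear_forward("x", "w")[0] == "aten.linear"
assert direct_matmul("x", "w")[0] == "aten.matmul"
```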
output_tensor, input_tensor, weight_tensor = (
    args[0],
    args[1],
    args[2] if len(args) > 2 else None,
)
why is weight tensor optional?
should we add asserts for the last 2 kwargs as well? https://github.com/pytorch/pytorch/blob/82ff07c7884d478ddd5d638bebbb938e55c9bebf/aten/src/ATen/native/native_functions.yaml#L7214
also I thought one of mat1 and mat2 should be bias_tensor?
yeah the first tensor is the bias, also added the asserts
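Per aten.addmm's schema, `addmm(self, mat1, mat2, *, beta=1, alpha=1)`, the first positional argument is the bias and the two scaling kwargs default to 1. A pure-Python sketch of the agreed unpacking plus asserts (the helper name `unpack_addmm_args` is hypothetical):

```python
def unpack_addmm_args(args, kwargs):
    """Unpack aten.addmm arguments: addmm(self, mat1, mat2, *, beta=1, alpha=1).
    The first positional arg is the bias ("self"), not the output tensor."""
    assert len(args) == 3, f"expected (bias, mat1, mat2), got {len(args)} args"
    bias_tensor, input_tensor, weight_tensor = args
    # Assuming the fp8 path does not apply alpha/beta scaling, insist on defaults.
    assert kwargs.get("beta", 1) == 1, "non-default beta not supported"
    assert kwargs.get("alpha", 1) == 1, "non-default alpha not supported"
    return bias_tensor, input_tensor, weight_tensor

bias, x, w = unpack_addmm_args(("b", "x", "w"), {})
assert (bias, x, w) == ("b", "x", "w")
```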
        
          
torchao/quantization/quantize_/workflows/float8/float8_tensor.py: 4 outdated review threads (resolved)
if is_transposed:
    return _float8_linear_impl(input_tensor, weight_tensor.t())
else:
    return torch.matmul(input_tensor, weight_tensor.dequantize())
going from mm to matmul also redirects to a higher-level op; better to call torch.mm here
_float8_mm_impl seems confusing; IMO this should be refactored to cleanly override individual torch or aten ops and ensure that the logic of when to do weight-only vs dynamic quant is consistent everywhere
please clean up _float8_mm_impl
Force-pushed 060b217 to 092ca75 (Compare)
    
Summary: Support a few extra ops called during the GRPO loop in unsloth/vllm for Float8Tensor.
Test Plan: