Conversation

@andrewor14 andrewor14 commented Oct 12, 2025

Summary: Support a few extra ops that are called during the GRPO loop in unsloth/vllm for Float8Tensor.

Test Plan:

python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_matmul_variants
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_to_dtype_layout
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_has_compatible_shallow_copy_type
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_transpose

pytorch-bot bot commented Oct 12, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3158

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 092ca75 with merge base f3fc5e7:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@andrewor14 andrewor14 marked this pull request as draft October 12, 2025 22:21
@meta-cla meta-cla bot added the CLA Signed label Oct 12, 2025
@andrewor14 andrewor14 added the topic: improvement label Oct 12, 2025
@andrewor14 andrewor14 changed the title [HACK] Update Float8Tensor for GRPO training in unsloth [draft] Update Float8Tensor for GRPO training in unsloth Oct 13, 2025
@andrewor14 andrewor14 changed the title [draft] Update Float8Tensor for GRPO training in unsloth Update Float8Tensor for GRPO training in unsloth Oct 29, 2025
@andrewor14 andrewor14 requested a review from jerryzh168 October 29, 2025 20:15
@andrewor14 andrewor14 marked this pull request as ready for review October 29, 2025 20:15
Comment on lines +204 to +205
output_tensor = torch.matmul(input_tensor, weight_tensor.t())
output_tensor_fp8 = torch.matmul(input_tensor_fp8, weight_tensor_fp8.t())
Contributor:

this is not used through the quantize_ API?

Contributor:

if this can be accessed through the quantize_ API then we can merge the test with test_linear_variants, I think

Contributor (author):

I don't think this can be accessed through the quantize_ API, unfortunately; nn.Linear will dispatch to F.linear first
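
For illustration, a minimal sketch of the dispatch distinction (assuming the Float8DynamicActivationFloat8WeightConfig workflow; the exact config used in the test may differ, and this is not the PR's code):

import torch
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig

m = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(m, Float8DynamicActivationFloat8WeightConfig())
x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")

y1 = m(x)                           # nn.Linear dispatches to F.linear, not matmul
y2 = torch.matmul(x, m.weight.t())  # exercises the matmul/transpose overrides directly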

output_tensor, input_tensor, weight_tensor = (
    args[0],
    args[1],
    args[2] if len(args) > 2 else None,
)
Contributor:

why is weight_tensor optional?

Contributor:

also I thought one of mat1 and mat2 should be bias_tensor?

Contributor (author):

yeah, the first tensor is the bias; also added the asserts
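
For context, a sketch of the calling convention being handled here: aten.addmm(bias, mat1, mat2) passes the bias as the first positional argument, so the unpacking and asserts might look like the following (the exact asserts and the bias argument to _float8_linear_impl are assumptions, not necessarily the PR's code):

# aten.addmm(bias, mat1, mat2): the bias comes first in the ATen signature
bias_tensor, input_tensor, weight_tensor = args[0], args[1], args[2]
assert isinstance(weight_tensor, Float8Tensor), "expected the fp8 weight as mat2"
assert not isinstance(bias_tensor, Float8Tensor), "bias should stay in high precision"
# mat2 is [K, N]; _float8_linear_impl expects a linear-style [N, K] weight
return _float8_linear_impl(input_tensor, weight_tensor.t(), bias_tensor)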

if is_transposed:
    return _float8_linear_impl(input_tensor, weight_tensor.t())
else:
    return torch.matmul(input_tensor, weight_tensor.dequantize())
Contributor:

mm to matmul is also going to a higher-level thing; better to call torch.mm here
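
The suggested change, as a sketch:

if is_transposed:
    return _float8_linear_impl(input_tensor, weight_tensor.t())
else:
    # stay at the aten.mm level instead of re-entering the higher-level matmul dispatch
    return torch.mm(input_tensor, weight_tensor.dequantize())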

Contributor:

_float8_mm_impl seems confusing; IMO this should be refactored to cleanly override individual torch or aten ops and ensure that the logic of when to do weight-only vs dynamic quant is consistent everywhere
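
One possible shape for that refactor, as a sketch (using torchao's usual implements op-override pattern; the act_quant_kwargs check for weight-only vs dynamic quant and the _float8_mm_fp8 helper are assumptions, not the PR's actual code):

import torch

implements = Float8Tensor.implements

def _float8_mm_fp8(input_tensor, weight_tensor):
    # placeholder for the dynamic-quant path: quantize the activation to fp8,
    # then call the scaled-mm kernel (details elided)
    ...

def _mm_with_consistent_policy(input_tensor, weight_tensor):
    # single decision point for weight-only vs dynamic activation quant,
    # shared by every op override below
    if weight_tensor.act_quant_kwargs is None:
        # weight-only: dequantize and run the mm in high precision
        return torch.mm(input_tensor, weight_tensor.dequantize())
    return _float8_mm_fp8(input_tensor, weight_tensor)

@implements(torch.ops.aten.mm.default)
def _(func, types, args, kwargs):
    return _mm_with_consistent_policy(args[0], args[1])

@implements(torch.ops.aten.addmm.default)
def _(func, types, args, kwargs):
    bias, mat1, mat2 = args
    return _mm_with_consistent_policy(mat1, mat2) + bias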

@vkuzo vkuzo (Contributor) left a comment:

please clean up _float8_mm_impl

@andrewor14 andrewor14 requested a review from vkuzo October 30, 2025 23:46