
[Experimental] Float8 support in AQT #671

Merged · 35 commits · Aug 28, 2024

Conversation

@jainapurva (Contributor) commented Aug 14, 2024:

Add float8 inference support to the current Affine Quantized Tensor.

Test Plan: test/dtypes/test_affine_quantized.py

pytorch-bot commented Aug 14, 2024:

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/671

✅ No failures as of commit 482f537 with merge base 9a56e80.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Aug 14, 2024
@vkuzo (Contributor) commented Aug 14, 2024:

I think a good next step would be to add numerical tests, and ensure that this new object matches the numerical behavior of Float8Tensor.
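As an illustration of the kind of check this implies (not the tests added in test/dtypes/test_affine_quantized.py; the helper name and tolerances below are made up, and only public torch ops are used):

import torch

def _float8_roundtrip(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
    # Per-tensor symmetric scaling: map the max magnitude onto the largest
    # representable float8 value, cast down, then dequantize back.
    f8_max = torch.finfo(float8_dtype).max
    scale = x.abs().amax().clamp(min=1e-12) / f8_max
    return (x / scale).to(float8_dtype).to(x.dtype) * scale

def test_float8_roundtrip_close():
    x = torch.randn(64, 128)
    x_dq = _float8_roundtrip(x)
    # e4m3 keeps ~3 mantissa bits, so allow a few percent relative error.
    assert torch.allclose(x, x_dq, rtol=0.1, atol=1e-3)

A real comparison test would quantize the same input through both Float8Tensor and the new AQT path and compare the outputs, rather than relying on a standalone round-trip like this.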

@HDCharles (Contributor) commented:

Can I ask whether we're sure we should include float8 tensors in AQT instead of another paradigm?

@HDCharles closed this Aug 20, 2024
@jainapurva reopened this Aug 20, 2024
@jainapurva (Contributor, Author) commented:

> Can I ask whether we're sure we should include float8 tensors in AQT instead of another paradigm?

AQT conceptually aligns closely with fp8/fpx. Instead of writing a separate tensor subclass, it's more efficient to add float support to AQT: the concept of AQT is shared, and the major difference is the dtype. This is an experimental PR to test feasibility; the design will be modified.
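For context on the overlap: float8 conversion is the same affine scheme with the zero point fixed at 0 and the dtype cast doing the rounding. A minimal, generic sketch (the helper below is illustrative, not AQT code):

import torch

def affine_quantize(x: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype):
    # Shared scheme: q = round(x / scale) + zero_point.
    if target_dtype.is_floating_point:
        # fp8 path: zero_point is implicitly 0 and the cast performs the rounding.
        return (x / scale).to(target_dtype)
    # intx path: explicit round and clamp to the integer range (symmetric case shown).
    qmin, qmax = torch.iinfo(target_dtype).min, torch.iinfo(target_dtype).max
    return torch.clamp(torch.round(x / scale), qmin, qmax).to(target_dtype)

x = torch.randn(16, 32)
x_f8 = affine_quantize(x, x.abs().amax() / torch.finfo(torch.float8_e4m3fn).max, torch.float8_e4m3fn)
x_i8 = affine_quantize(x, x.abs().amax() / 127, torch.int8)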

@jainapurva marked this pull request as ready for review August 20, 2024 22:04
@HDCharles (Contributor) commented:

(Sorry, didn't mean to close this; misclick.)

@@ -269,6 +270,42 @@ def from_float_static(
            dtype=input_float.dtype,
        )

    @classmethod
    def from_float_float8(
Contributor:

@jerryzh168 can you help with a design on how to have from_float, from_float_static, etc. extend to this use case? Ideally we shouldn't special-case a set of dtypes (float8) to have their own function.

Author (@jainapurva):

A combined function is better; I'll refactor it after testing float8.

Contributor:

Yeah, sure. I think we could have the two following final states:

  • have separate from_float_fpx and from_float_intx, since they have slightly different arg lists
  • if we manage to generalize the arg list enough that merging the two is reasonable, then we can merge them. I will discuss the args with Apurva and Driss, but at first glance preserve_zero is probably always going to be true and zero_point_domain may not apply here (a rough sketch of such guards follows this list)
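A rough sketch of what those guards could look like if the two paths are merged (the helper and argument names are illustrative, not code from this PR):

import torch

_FLOAT8_DTYPES = {torch.float8_e4m3fn, torch.float8_e5m2}

def _check_floatx_args(target_dtype, preserve_zero, zero_point_domain, zero_point_dtype):
    # Reject argument combinations that only make sense for integer targets.
    if target_dtype not in _FLOAT8_DTYPES:
        return
    if not preserve_zero:
        raise ValueError("float8 always preserves zero; preserve_zero=False is invalid")
    if zero_point_domain is not None:
        raise ValueError("zero_point_domain does not apply to float8 targets")
    if zero_point_dtype is not None:
        raise ValueError("zero_point_dtype does not apply to float8 targets")

# A merged from_float would call this before dispatching to the fpx or intx branch:
_check_floatx_args(torch.float8_e4m3fn, True, None, None)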

Contributor:

Makes sense.

One thought: the from_float name will become more confusing if both the source and the target can be various floating-point bitwidths. To clarify this in torchao.float8, I went with the high_precision|hp and low_precision|lp naming scheme.

@cpuhrsch (Contributor) commented Aug 21, 2024:

This is also why I like the idea of extending .to() and having our own factory functions that we can pass dtype enums to. For example:

"to_nf4",

Contributor:

@vkuzo yeah, makes sense, we can rename from_float to from_high_precision as well. As @cpuhrsch mentioned, this is not a user-facing API; we'll have factory functions for the various dtypes as the user-facing API:

to_affine_quantized = AffineQuantizedTensor.from_float
to_affine_quantized_static = AffineQuantizedTensor.from_float_static
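Following that aliasing pattern, a float8 factory could be exposed the same way. The sketch below is purely illustrative: the import path and the partial-bound argument are assumptions, and the factory that eventually landed (to_affine_quantized_floatx) defines its own argument list.

import torch
from functools import partial

from torchao.dtypes import AffineQuantizedTensor  # assumed import path

# Hypothetical float8 alias, pinning the target dtype the same way the existing
# aliases pin the constructor; assumes from_float accepts float8 target dtypes.
to_affine_quantized_float8 = partial(
    AffineQuantizedTensor.from_float, target_dtype=torch.float8_e4m3fn
)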

Contributor:

@jainapurva as discussed in the meeting, let's merge this into from_float and add some guards on arguments.

@jainapurva (Contributor, Author) commented:

> @jainapurva can you add this to https://github.com/pytorch/ao/blob/main/torchao/_models/llama/generate.py and https://github.com/pytorch/ao/blob/main/torchao/_models/llama/eval.py as well and get some e2e accuracy and perf numbers for them

Testing perf in this PR: #732

def validate_float8_params(
    input_float, mapping_type, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain, layout_type, use_hqq
):
    assert input_float.is_floating_point(), "input_float must be a floating point tensor"
@jerryzh168 (Contributor) commented Aug 24, 2024:

Actually, based on this it seems most of the args are irrelevant for float8; maybe just splitting the op makes more sense?

Author (@jainapurva):

Yes, I created a new op, from_float_to_floatx. It calls the existing from_float with some pre-defined param values. In the future we'll also need to rename and refactor these methods.
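A sketch of the delegation described here, with the int-only knobs pinned up front. The argument names loosely follow the parameter list seen earlier in this thread, but the signature, defaults, and import paths are assumptions rather than the code that landed:

import torch
from torchao.dtypes import AffineQuantizedTensor                # assumed import path
from torchao.quantization.quant_primitives import MappingType   # assumed import path

def from_float_to_floatx(input_float, block_size, target_dtype=torch.float8_e4m3fn):
    # Float8 pins the int-specific parameters and reuses the shared from_float path.
    return AffineQuantizedTensor.from_float(
        input_float,
        MappingType.SYMMETRIC,    # float8 here is symmetric, with no zero point
        block_size,
        target_dtype,
        quant_min=None,           # range comes from torch.finfo(target_dtype)
        quant_max=None,
        zero_point_dtype=None,
        preserve_zero=True,
        zero_point_domain=None,
    )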

    Applies float8 weight-only symmetric per-channel quantization to linear layers.
    """
    def apply_float8wo_quant(weight):
        # avoid circular dep
Contributor:

Maybe you want to import to_affine_quantized_floatx here; I'm also refactoring this file to change the import so as to avoid the circular dep as well.
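A sketch of the deferred import being suggested (assuming to_affine_quantized_floatx is exported from torchao.dtypes; the construction of the quantized weight is left elided):

import torch

def apply_float8wo_quant(weight: torch.Tensor):
    # Deferred import: resolving to_affine_quantized_floatx at call time keeps this
    # module importable even though the dtypes module also imports from it.
    from torchao.dtypes import to_affine_quantized_floatx  # assumed export

    ...  # build the float8 weight-only quantized weight here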

@jainapurva merged commit 0916b5b into main Aug 28, 2024
16 checks passed