[Experimental] Float8 support in AQT #671
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/671.
Note: links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 482f537 with merge base 9a56e80. This comment was automatically generated by Dr. CI and updates every 15 minutes.
I think a good next step would be to add numerical tests, and ensure that this new object matches the numerical behavior of
Can I ask whether we're sure we should include float8 tensors in AQT instead of another paradigm?
AQT conceptually aligns closely with fp8/fpx. Instead of writing a separate tensor subclass, it's more efficient to add float support to AQT: the core concept of AQT is shared, and the major difference is the dtype. This is an experimental PR to test feasibility; the design will be refined.
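To make the overlap concrete, here is a minimal sketch (not this PR's implementation) of how symmetric per-channel float8 quantization fits the affine model with the zero point fixed at 0; the helper names and the e4m3 choice are illustrative assumptions.

import torch

def quantize_float8_per_channel(w, float8_dtype=torch.float8_e4m3fn):
    # Symmetric affine mapping: zero_point is implicitly 0, only a scale is stored.
    max_repr = torch.finfo(float8_dtype).max
    amax = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-12)
    scale = amax / max_repr                                   # one scale per output row
    w_fp8 = (w / scale).clamp(-max_repr, max_repr).to(float8_dtype)
    return w_fp8, scale

def dequantize_float8(w_fp8, scale, orig_dtype=torch.bfloat16):
    return w_fp8.to(orig_dtype) * scale

w = torch.randn(256, 128, dtype=torch.bfloat16)
w_fp8, scale = quantize_float8_per_channel(w)
w_dq = dequantize_float8(w_fp8, scale)                        # close to w, up to fp8 rounding

In AQT terms this corresponds roughly to a symmetric mapping with block_size = (1, in_features).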
(sorry, didn't mean to close this, misclick)
@@ -269,6 +270,42 @@ def from_float_static(
            dtype=input_float.dtype,
        )

    @classmethod
    def from_float_float8(
@jerryzh168 can you help with a design on how to have from_float, from_float_static, etc. extend to this use case? Ideally we shouldn't special-case a set of dtypes (float8) to have their own function.
A combined function is better; I'll refactor it after testing float8.
Yeah sure, I think we could have the two following final states:
- have separate from_float_fpx and from_float_intx, since they have somewhat different arg lists
- if we manage to generalize the arg list enough that it is reasonable to merge the two, then we can merge them as well

I will discuss the args with Apurva and Driss, but at first glance preserve_zero is probably always going to be true and zero_point_domain may not apply here.
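As a rough illustration of that split, the two factory methods might end up with signatures along these lines; the names and defaults below are hypothetical, not the final API.

class _AQTSketch:  # stand-in for AffineQuantizedTensor, for illustration only
    @classmethod
    def from_float_intx(cls, input_float, mapping_type, block_size, target_dtype,
                        quant_min=None, quant_max=None, eps=None,
                        scale_dtype=None, zero_point_dtype=None,
                        preserve_zero=True, zero_point_domain=None,
                        layout_type=None, use_hqq=False):
        ...  # integer path keeps the full affine parameterization

    @classmethod
    def from_float_fpx(cls, input_float, mapping_type, block_size, target_dtype,
                       eps=None, scale_dtype=None, layout_type=None):
        # preserve_zero is effectively always True for the float path and
        # zero_point_domain does not apply, so neither is exposed here.
        ...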
Makes sense. One thought: the from_float name will become more confusing if both the source and the target can be various floating-point bitwidths. To clarify this in torchao.float8, I went with the high_precision|hp and low_precision|lp naming scheme.
This is also why I like the idea of extending to
and having our own factory functions that we can pass dtype enums. For example:
Line 16 in 5c0e060:
"to_nf4",
@vkuzo yeah, makes sense, we can rename from_float to from_high_precision as well. As @cpuhrsch mentioned, this is not a user-facing API; we'll have factory functions for the various dtypes as the user-facing API:
ao/torchao/dtypes/affine_quantized_tensor.py, lines 990 to 991 in 5c0e060:
to_affine_quantized = AffineQuantizedTensor.from_float
to_affine_quantized_static = AffineQuantizedTensor.from_float_static
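For illustration, a floatx factory could be exposed the same way as the aliases quoted above; the placeholder class and the to_affine_quantized_floatx / from_float_to_floatx names are assumptions, since the final shape of the API was still being discussed here.

class AffineQuantizedTensor:  # placeholder for the real class in affine_quantized_tensor.py
    @classmethod
    def from_float(cls, input_float, *args, **kwargs):
        ...

    @classmethod
    def from_float_to_floatx(cls, input_float, *args, **kwargs):
        ...

to_affine_quantized = AffineQuantizedTensor.from_float                    # existing alias
to_affine_quantized_floatx = AffineQuantizedTensor.from_float_to_floatx   # hypothetical floatx alias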
@jainapurva as discussed in the meeting, let's merge this into from_float and add some guards on the arguments.
Testing perf in this PR: #732
def validate_float8_params(
    input_float, mapping_type, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain, layout_type, use_hqq
):
    assert input_float.is_floating_point(), "input_float must be a floating point tensor"
Actually, based on this it seems most of the args are irrelevant for float8; maybe just splitting the op makes more sense?
Yes, I created a new op, from_float_to_floatx. It calls the existing from_float but with some pre-defined param values. In the future we'll also need to rename and refactor these methods.
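A hedged sketch of what that wrapper could look like; the pre-defined values below (symmetric mapping, quant range from finfo, no zero-point domain) are assumptions rather than the PR's final choices.

import torch
from torchao.quantization.quant_primitives import MappingType

def from_float_to_floatx(cls, input_float, block_size, target_dtype=torch.float8_e4m3fn):
    # Forward to the existing from_float, pinning the params that do not vary
    # for float8 to fixed (assumed) values.
    finfo = torch.finfo(target_dtype)
    return cls.from_float(
        input_float,
        mapping_type=MappingType.SYMMETRIC,  # assumed: float8 is quantized symmetrically
        block_size=block_size,
        target_dtype=target_dtype,
        quant_min=finfo.min,
        quant_max=finfo.max,
        eps=torch.finfo(torch.float32).eps,
        zero_point_dtype=None,               # no integer zero point for float8
        preserve_zero=True,                  # always true for the float path
        zero_point_domain=None,              # assumed not to apply to float8
    )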
    Applies float8 weight-only symmetric per-channel quantization to linear layers.
    """
    def apply_float8wo_quant(weight):
        # avoid circular dep
Maybe you want to import to_affine_quantized_floatx here. I'm also refactoring this file to change the import to import the file, to avoid the circular dep as well.
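A sketch of that pattern, with the import done inside the inner function so this module does not pull in affine_quantized_tensor.py at import time; the to_affine_quantized_floatx import path and call signature are assumptions here.

import torch

def float8_weight_only():
    def apply_float8wo_quant(weight):
        # Local import to avoid the circular dependency mentioned above
        # (import path and signature are assumed, not the final API).
        from torchao.dtypes.affine_quantized_tensor import to_affine_quantized_floatx
        block_size = (1, weight.shape[1])  # per-channel: one scale per output row
        return to_affine_quantized_floatx(weight, block_size, torch.float8_e4m3fn)
    return apply_float8wo_quant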
Add float8 inference support to the current Affine Quantized Tensor.
Test Plan: test/dtypes/test_affine_quantized.py
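A hedged sketch of the kind of numerical check suggested earlier in the thread; the quantize_ / float8_weight_only entry points follow the direction of this PR, but the exact API and the tolerance are assumptions.

import copy
import torch
from torchao.quantization import quantize_, float8_weight_only

def test_float8_weight_only_matches_bf16_reference():
    torch.manual_seed(0)
    model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16)
    ref = copy.deepcopy(model)
    quantize_(model, float8_weight_only())
    x = torch.randn(4, 128, dtype=torch.bfloat16)
    # float8_e4m3 carries only ~3 mantissa bits, so compare with a loose
    # tolerance rather than expecting bit-exact results.
    assert torch.allclose(model(x), ref(x), rtol=0.1, atol=0.05)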