Float8 dynamic autoquant #946
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/946
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit ae18023 with merge base fbe97a0: FLAKY - the following job failed but was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from ffca55c to 58fe60d (Compare)
torchao/quantization/autoquant.py (Outdated)
return weight

@classmethod
def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
Did you test this in practice? We may need different constants for int8 and float8 dynamic; how does it perform in benchmarks? If you haven't really tested this on 2-3 models, it may be better to just remove it and use the default method, which will be very conservative under the interpolation mode and will still work reasonably under the relu mode.
I've tested it on Llama; the numbers aren't great, but I can push it to the next PR with more benchmarks.
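For context, here is a minimal, self-contained sketch (not the torchao implementation) of what an `_autoquant_test`-style check does: quantize the weight and time the candidate linear op against the best time recorded so far. The function name and the `quantize_fn` argument are hypothetical stand-ins for `cls.from_float`.

```python
import torch
import torch.utils.benchmark as benchmark

def autoquant_test_sketch(act_mat, weight, bias, best_time, quantize_fn):
    # quantize_fn is a hypothetical stand-in for cls.from_float in the real class
    q_weight = quantize_fn(weight)
    timer = benchmark.Timer(
        stmt="torch.nn.functional.linear(a, w, b)",
        globals={"torch": torch, "a": act_mat, "w": q_weight, "b": bias},
    )
    # median runtime of the candidate quantized kernel, in milliseconds
    res_ms = timer.blocked_autorange().median * 1e3
    print(f"candidate: {res_ms:0.3f} ms, best so far: {best_time:0.3f} ms")
    return res_ms
```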
torchao/quantization/utils.py (Outdated)
@@ -139,13 +139,12 @@ def _get_per_token_block_size(x: torch.Tensor) -> List[int]:
 # taken from
 # https://github.com/mit-han-lab/smoothquant/blob/2f87951dacfb9238d8d657f52ae83a82a3c9ba0c/smoothquant/fake_quant.py#L26
 # and slightly modified
-def quantize_activation_per_token_absmax(t):
+def quantize_activation_per_token_absmax(t, dtype=torch.int8):
Does this actually work in practice with non-int8 dtypes? We're still using the same quant min/max of ±128, which seems inadvisable.
I also don't think we should extend this function; it should probably just call into whatever quant function is normally used for that dtype. This is a specific instance of a function where the mapping types and quant min/max are hard-coded to specific values, so it shouldn't be extended.
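For illustration, a hedged sketch (not torchao code) of the dtype-aware alternative the comment describes: derive the quantization range from `torch.iinfo`/`torch.finfo` of the target dtype instead of reusing the hard-coded int8 bounds. The function name is hypothetical.

```python
import torch

def per_token_absmax_quant_sketch(t: torch.Tensor, dtype: torch.dtype = torch.int8):
    # one scale per token: reduce over the last (feature) dimension
    amax = t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    if dtype.is_floating_point:
        # float8 dtypes, e.g. torch.float8_e4m3fn: range comes from finfo, no rounding
        qmax = torch.finfo(dtype).max
        scale = amax / qmax
        q = (t / scale).clamp(-qmax, qmax).to(dtype)
    else:
        # integer dtypes, e.g. torch.int8: range comes from iinfo, values are rounded
        qmax = torch.iinfo(dtype).max
        scale = amax / qmax
        q = torch.round(t / scale).clamp(torch.iinfo(dtype).min, qmax).to(dtype)
    return q, scale
```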
We don't need this anymore; I'm reverting the changes to this method, as there's another implementation in float8 that I'll be using.
See comments. I would probably skip the quant test bit unless it gets tested on 2-3 models, and the way the activation quantization is implemented seems like it's going to cause issues, because it has only superficially been altered away from its normal int8 quantization.
Force-pushed from ebcfb9e to bfe1eee (Compare)
torchao/quantization/autoquant.py (Outdated)
@@ -492,6 +494,46 @@ def from_float(cls, weight):
     block_size = (1, weight.shape[1])
     return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())

+class AQFloat8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
I think this looks good. Let's name this PerRow scaling, and let's only import PerRow above.
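For reference, this is roughly how the same per-row float8 dynamic scheme is expressed through torchao's eager quantization API; the import paths and the availability of `float8_dynamic_activation_float8_weight` with a `granularity` argument in a given torchao version are assumptions, so treat this as a sketch rather than the PR's code.

```python
import torch
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.granularity import PerRow

# toy model; float8 kernels generally need a recent GPU (e.g. SM89+)
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# float8 weight + dynamically quantized float8 activations, one scale per row
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
```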
Added support in autoquant for a float8 dynamically quantized linear weight, i.e. float8 weight and activation.
Added a fallback path in safe_int_mm to support float8 with torch.compile.
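A hedged usage sketch of how the new option is exercised through autoquant; the class-list name `OTHER_AUTOQUANT_CLASS_LIST` and the exact `autoquant` keyword arguments are assumptions about the torchao version that includes this PR.

```python
import torch
import torchao
from torchao.quantization.autoquant import OTHER_AUTOQUANT_CLASS_LIST  # assumed name

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()
example_input = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")

# autoquant benchmarks each candidate subclass (now including the float8 dynamic
# option) per linear layer and keeps the fastest; torch.compile is applied first,
# which is where the safe_int_mm fallback path comes into play
model = torchao.autoquant(torch.compile(model), qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST)
model(example_input)
```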
Benchmark:
On llama3.1b - bfloat16 model
For Float8 Dynamic Quant
For Float8 weight only