[Quantization] Add Quanto backend #10756
docs/source/en/quantization/quanto.md (new file)

@@ -0,0 +1,148 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Quanto

[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum](https://huggingface.co/docs/optimum/en/index). It has been designed with versatility and simplicity in mind:

- All features are available in eager mode (works with non-traceable models)
- Supports quantization-aware training
- Quantized models are compatible with `torch.compile`
- Quantized models are device agnostic (e.g. CUDA, XPU, MPS, CPU)

In order to use the Quanto backend, you will first need to install `optimum-quanto>=0.2.6` and `accelerate`.

```shell
pip install optimum-quanto accelerate
```
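Since the backend expects `optimum-quanto>=0.2.6`, it can be worth confirming what pip actually installed before going further. A minimal check using only the standard library (package names taken from the install command above):

```python
from importlib.metadata import version

# The Quanto backend in Diffusers expects optimum-quanto >= 0.2.6
print("optimum-quanto:", version("optimum-quanto"))
print("accelerate:", version("accelerate"))
```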
Now you can quantize a model by passing the `QuantoConfig` object to the `from_pretrained()` method. Although the Quanto library does allow quantizing `nn.Conv2d` and `nn.LayerNorm` modules, Diffusers currently only supports quantizing the weights in the `nn.Linear` layers of a model. The following snippet demonstrates how to apply `float8` quantization with Quanto.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```

## Skipping Quantization on specific modules

It is possible to skip applying quantization on certain modules using the `modules_to_not_convert` argument in the `QuantoConfig`. Please ensure that the modules passed to this argument match the keys of the modules in the `state_dict`.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8", modules_to_not_convert=["proj_out"])
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```
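If you are unsure which keys to pass, the module names can be listed with `named_modules()`. A minimal sketch that reuses the `transformer` from the snippet above (illustrative of the approach rather than an exhaustive recipe):

```python
import torch.nn as nn

# Print the names of linear layers; strings passed to `modules_to_not_convert`
# should match these module names (for example "proj_out").
for name, module in transformer.named_modules():
    if isinstance(module, nn.Linear):
        print(name)
```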
## Using `from_single_file` with the Quanto Backend

`QuantoConfig` is compatible with `~FromOriginalModelMixin.from_single_file`.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
```
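Continuing from the snippet above, the same call also accepts a local checkpoint; the path below is a placeholder for illustration only:

```python
# A local .safetensors checkpoint works the same way as the Hub URL above
# (the path is a placeholder, not a file shipped with this PR).
local_ckpt = "/path/to/flux1-dev.safetensors"
transformer = FluxTransformer2DModel.from_single_file(
    local_ckpt, quantization_config=quantization_config, torch_dtype=torch.bfloat16
)
```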
## Saving Quantized models

Diffusers supports serializing Quanto models using the `~ModelMixin.save_pretrained` method.

The serialization and loading requirements are different for models quantized directly with the Quanto library and models quantized with Diffusers using Quanto as the backend. It is currently not possible to load models quantized directly with Quanto into Diffusers using `~ModelMixin.from_pretrained`.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
# save the quantized model so it can be reused
transformer.save_pretrained("<your quantized model save path>")

# the quantized model can then be reloaded with
model = FluxTransformer2DModel.from_pretrained("<your quantized model save path>")
```
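The reloaded transformer can be dropped into a pipeline just like a freshly quantized one; a short sketch reusing the pattern from the first example (the save path is the placeholder used above):

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

# Reload the serialized quantized transformer and build the pipeline around it
transformer = FluxTransformer2DModel.from_pretrained("<your quantized model save path>")
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
image = pipe("A cat holding a sign that says hello world").images[0]
```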
## Using `torch.compile` with Quanto

Currently the Quanto backend supports `torch.compile` for the following quantization types:

- `int8` weights

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="int8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)

pipe = FluxPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
images = pipe("A cat holding a sign that says hello").images[0]
images.save("flux-quanto-compile.png")
```

Reviewer comment on the `torch.compile` line: Great that this works.
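To see whether compilation actually pays off, a rough wall-clock comparison can be made between the first call (which includes compilation) and subsequent calls. A minimal sketch continuing from the snippet above (the step count is arbitrary):

```python
import time

# The first call triggers compilation, so it is expected to be much slower than later calls.
for i in range(3):
    start = time.perf_counter()
    _ = pipe("A cat holding a sign that says hello", num_inference_steps=10).images[0]
    print(f"run {i}: {time.perf_counter() - start:.1f}s")
```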
## Supported Quantization Types

### Weights

- float8
- int8
- int4
- int2
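Any of these values can be passed as `weights_dtype`. For instance, a more aggressive 4-bit configuration follows the same pattern as the `float8` examples above (whether the quality trade-off is acceptable is model dependent):

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

# Same loading pattern as above, but with 4-bit integer weights
quantization_config = QuantoConfig(weights_dtype="int4")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```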
Changes to `load_model_dict_into_meta`:
```diff
@@ -245,6 +245,9 @@ def load_model_dict_into_meta(
         ):
             param = param.to(torch.float32)
             set_module_kwargs["dtype"] = torch.float32
+        # For quantizers that have saved weights using torch.float8_e4m3fn
+        elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
+            pass
         else:
             param = param.to(dtype)
             set_module_kwargs["dtype"] = dtype
```

Comment on lines +249 to +250:

> How would the param be handled in that case?

> We wouldn't apply any casting and the parameter would be loaded as is into the model.
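As a standalone illustration of the behaviour described in that reply, the branch effectively amounts to the following (hypothetical helper name, simplified from the real function):

```python
import torch

def _cast_param_for_loading(param, dtype, hf_quantizer=None):
    # Simplified sketch: float8 weights saved by a quantizer are left untouched,
    # so the quantizer can consume the raw float8 tensor later; everything else
    # is cast to the target dtype as before.
    if hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
        return param
    return param.to(dtype)
```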
```diff
@@ -292,7 +295,9 @@ def load_model_dict_into_meta(
         elif is_quantized and (
             hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
         ):
-            hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
+            hf_quantizer.create_quantized_param(
+                model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
+            )
         else:
             set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
```
Review discussion on `torch.compile` support:

> Have we verified this? Last time I checked only weight-quantized models were compatible with `torch.compile`. Cc: @dacorvo.

> True, but this should be fixed in pytorch 2.6 (I did not check though).

> @dacorvo I tried to run torch compile with float8 weights in the following way and hit an error during inference.
>
> Traceback:
>
> The `torch.compile` step seems to work. The error is raised during the forward pass.

> Same with nightly?

> Yeah, same errors with nightly.

> Let's be specific that only int8 supports torch.compile for now?

> Mentioned in the compile section of the docs:
> diffusers/docs/source/en/quantization/quanto.md, Line 105 in bb7fb66
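The exact snippet the reviewer ran is not shown above; an attempt of that kind would presumably look like the `int8` compile example from the docs with the dtype swapped to `float8` (purely illustrative, and expected to fail at the first forward pass according to the discussion):

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

# Hypothetical reconstruction of a float8 + torch.compile attempt; per the
# discussion above, compilation succeeds but inference raises an error.
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipe("A cat holding a sign that says hello").images[0]  # error reported here
```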