
[Quantization] Add Quanto backend #10756


Merged: 41 commits (Mar 10, 2025)
Commits (41):
- ff50418 update (DN6, Feb 5, 2025)
- ba5bba7 updaet (DN6, Feb 5, 2025)
- aa8cdaf update (DN6, Feb 5, 2025)
- 39e20e2 update (DN6, Feb 8, 2025)
- f52050a update (DN6, Feb 8, 2025)
- f4c14c2 update (DN6, Feb 10, 2025)
- f67d97c update (DN6, Feb 10, 2025)
- 5cff237 update (DN6, Feb 10, 2025)
- f734c09 update (DN6, Feb 10, 2025)
- 7472f18 update (DN6, Feb 10, 2025)
- e96686e update (DN6, Feb 10, 2025)
- 4ae8691 update (DN6, Feb 10, 2025)
- 7b841dc Update docs/source/en/quantization/quanto.md (DN6, Feb 11, 2025)
- e090177 update (DN6, Feb 11, 2025)
- 559f124 Merge https://github.com/huggingface/diffusers into add-quanto (DN6, Feb 11, 2025)
- 9e5a3d0 Merge branch 'add-quanto' of https://github.com/huggingface/diffusers… (DN6, Feb 11, 2025)
- b136d23 update (DN6, Feb 11, 2025)
- 2c7f303 update (DN6, Feb 11, 2025)
- c80d4d4 update (DN6, Feb 12, 2025)
- d355e6a update (DN6, Feb 13, 2025)
- 9a72fef Merge branch 'main' into add-quanto (DN6, Feb 13, 2025)
- 79901e4 update (DN6, Feb 18, 2025)
- c4b6e24 update (DN6, Feb 20, 2025)
- c29684f Merge branch 'main' into add-quanto (DN6, Feb 20, 2025)
- 6cf9a78 update (DN6, Feb 20, 2025)
- 0736f87 update (DN6, Feb 20, 2025)
- 4eabed7 update (DN6, Feb 25, 2025)
- f512c28 update (DN6, Feb 25, 2025)
- dbaef7c update (DN6, Feb 25, 2025)
- 963559f update (DN6, Feb 25, 2025)
- 156db08 Merge branch 'main' into add-quanto (DN6, Mar 3, 2025)
- 4516f22 update (DN6, Mar 7, 2025)
- 830b734 update (DN6, Mar 7, 2025)
- 8afff1b Merge branch 'main' into add-quanto (DN6, Mar 7, 2025)
- 8163687 update (DN6, Mar 7, 2025)
- bb7fb66 update (DN6, Mar 7, 2025)
- 6cad1d5 update (DN6, Mar 7, 2025)
- d5ab9ca Update src/diffusers/quantizers/quanto/utils.py (DN6, Mar 7, 2025)
- deebc22 update (DN6, Mar 7, 2025)
- cf4694e Merge branch 'add-quanto' of https://github.com/huggingface/diffusers… (DN6, Mar 7, 2025)
- 1b46a32 update (DN6, Mar 10, 2025)
2 changes: 2 additions & 0 deletions .github/workflows/nightly_tests.yml
@@ -418,6 +418,8 @@ jobs:
test_location: "gguf"
- backend: "torchao"
test_location: "torchao"
- backend: "optimum_quanto"
test_location: "quanto"
runs-on:
group: aws-g6e-xlarge-plus
container:
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -173,6 +173,8 @@
title: gguf
- local: quantization/torchao
title: torchao
- local: quantization/quanto
title: quanto
title: Quantization Methods
- sections:
- local: optimization/fp16
5 changes: 5 additions & 0 deletions docs/source/en/api/quantization.md
@@ -31,6 +31,11 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
## GGUFQuantizationConfig

[[autodoc]] GGUFQuantizationConfig

## QuantoConfig

[[autodoc]] QuantoConfig

## TorchAoConfig

[[autodoc]] TorchAoConfig
1 change: 1 addition & 0 deletions docs/source/en/quantization/overview.md
@@ -36,5 +36,6 @@ Diffusers currently supports the following quantization methods.
- [BitsandBytes](./bitsandbytes)
- [TorchAO](./torchao)
- [GGUF](./gguf)
- [Quanto](./quanto.md)

[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
148 changes: 148 additions & 0 deletions docs/source/en/quantization/quanto.md
@@ -0,0 +1,148 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

-->

# Quanto

[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum](https://huggingface.co/docs/optimum/en/index). It has been designed with versatility and simplicity in mind:

- All features are available in eager mode (works with non-traceable models)
- Supports quantization aware training
- Quantized models are compatible with `torch.compile`

Member: Have we verified this? Last time I checked only weight-quantized models were compatible with torch.compile. Cc: @dacorvo.

Reply: True, but this should be fixed in pytorch 2.6 (I did not check though).

Collaborator Author: @dacorvo I tried to run torch.compile with float8 weights in the following way and hit an error during inference:

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
from optimum.quanto import quantize, freeze, qint8, qint4, qfloat8

model_id = "black-forest-labs/FLUX.1-dev"
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
quantize(transformer, weights=qfloat8)
freeze(transformer)

transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
pipe = FluxPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
images = pipe("A cat holding a sign that says hello").images[0]
images.save("flux-quanto-compile.png")
```

Traceback:

  File "/home/dhruv/miniconda3/envs/mochi/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2082, in validate
    raise AssertionError(
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(1, 4096, 64), dtype=torch.bfloat16
), MarlinF8QBytesTensor(MarlinF8PackedTensor(FakeTensor(..., device='cuda:0', size=(4, 12288), dtype=torch.int32)), scale=FakeTensor(..., device='cuda:0', size=(1, 3072
), dtype=torch.bfloat16), dtype=torch.bfloat16)), **{'bias': Parameter(FakeTensor(..., device='cuda:0', size=(3072,), dtype=torch.bfloat16,
           requires_grad=True))}):
Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in quanto.gemm_f16f8_marlin.default(FakeTensor(..., de
vice='cuda:0', size=(4096, 64), dtype=torch.bfloat16), FakeTensor(..., device='cuda:0', size=(4, 12288), dtype=torch.int32), FakeTensor(..., device='cuda:0', size=(1, 3
072), dtype=torch.bfloat16), tensor([...], device='cuda:0', size=(768,), dtype=torch.int32), 8, 4096, 3072, 64)

from user code:
   File "/home/dhruv/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 482, in forward
    hidden_states = self.x_embedder(hidden_states)
  File "/home/dhruv/miniconda3/envs/mochi/lib/python3.11/site-packages/optimum/quanto/nn/qlinear.py", line 50, in forward
    return torch.nn.functional.linear(input, self.qweight, bias=self.bias)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

The torch.compile step seems to work. The error is raised during the forward pass.

Member: Same with nightly?

Collaborator Author: Yeah, same errors with nightly.

Member: Let's be specific that only int8 supports torch.compile for now?

Collaborator Author: Mentioned in the compile section of the docs.

- Quantized models are device agnostic (e.g. CUDA, XPU, MPS, CPU)

To use the Quanto backend, you will first need to install `optimum-quanto>=0.2.6` and `accelerate`:

```shell
pip install optimum-quanto accelerate
```

Now you can quantize a model by passing a `QuantoConfig` object to the `from_pretrained()` method. Although the Quanto library allows quantizing `nn.Conv2d` and `nn.LayerNorm` modules, Diffusers currently only supports quantizing the weights in a model's `nn.Linear` layers. The following snippet demonstrates how to apply `float8` quantization with Quanto.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```

## Skipping Quantization on specific modules

It is possible to skip applying quantization on certain modules using the `modules_to_not_convert` argument in `QuantoConfig`. Ensure that the module names passed to this argument match the keys of the modules in the model's `state_dict`.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8", modules_to_not_convert=["proj_out"])
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```
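
To find candidate module names, you can inspect the model's `nn.Linear` layers before quantizing. The snippet below is a minimal sketch; the name filtering and the `"proj_out"` example are illustrative, not an exhaustive recommendation.

```python
import torch
from diffusers import FluxTransformer2DModel

model_id = "black-forest-labs/FLUX.1-dev"

# Load without a quantization config purely to inspect module names;
# these names are what `modules_to_not_convert` expects.
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
linear_module_names = [
    name for name, module in transformer.named_modules()
    if isinstance(module, torch.nn.Linear)
]
print(linear_module_names[:10])  # e.g. includes names such as "proj_out"
```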

## Using `from_single_file` with the Quanto Backend

`QuantoConfig` is compatible with `~FromOriginalModelMixin.from_single_file`.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
```

## Saving Quantized models

Diffusers supports serializing Quanto models using the `~ModelMixin.save_pretrained` method.

The serialization and loading requirements are different for models quantized directly with the Quanto library and models quantized with Diffusers using Quanto as the backend. It is currently not possible to load models quantized directly with Quanto into Diffusers using `~ModelMixin.from_pretrained`.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
# save quantized model to reuse
transformer.save_pretrained("<your quantized model save path>")

# you can reload your quantized model with
model = FluxTransformer2DModel.from_pretrained("<your quantized model save path>")
```

## Using `torch.compile` with Quanto

Currently the Quanto backend supports `torch.compile` for the following quantization types:

- `int8` weights

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="int8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
```

Member: Great that this works.

```python
pipe = FluxPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
images = pipe("A cat holding a sign that says hello").images[0]
images.save("flux-quanto-compile.png")
```

## Supported Quantization Types

### Weights

- float8
- int8
- int4
- int2
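
For reference, the weight dtype is selected through the `weights_dtype` argument of `QuantoConfig`. The sketch below assumes the `"int4"` and `"int2"` values follow the same string naming convention as the `"float8"` and `"int8"` examples shown earlier.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"

# One of "float8", "int8", "int4", or "int2"; lower bit widths trade
# output quality for additional memory savings.
quantization_config = QuantoConfig(weights_dtype="int4")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```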


9 changes: 9 additions & 0 deletions setup.py
@@ -128,6 +128,10 @@
"GitPython<3.1.19",
"scipy",
"onnx",
"optimum_quanto>=0.2.6",
"gguf>=0.10.0",
"torchao>=0.7.0",
"bitsandbytes>=0.43.3",
"regex!=2019.12.17",
"requests",
"tensorboard",
@@ -235,6 +239,11 @@ def run(self):
)
extras["torch"] = deps_list("torch", "accelerate")

extras["bitsandbytes"] = deps_list("bitsandbytes", "accelerate")
extras["gguf"] = deps_list("gguf", "accelerate")
extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
extras["torchao"] = deps_list("torchao", "accelerate")

if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
else:
94 changes: 92 additions & 2 deletions src/diffusers/__init__.py
@@ -2,6 +2,15 @@

from typing import TYPE_CHECKING

from diffusers.quantizers import quantization_config
from diffusers.utils import dummy_gguf_objects
from diffusers.utils.import_utils import (
is_bitsandbytes_available,
is_gguf_available,
is_optimum_quanto_version,
is_torchao_available,
)

from .utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
@@ -11,6 +20,7 @@
is_librosa_available,
is_note_seq_available,
is_onnx_available,
is_optimum_quanto_available,
is_scipy_available,
is_sentencepiece_available,
is_torch_available,
@@ -32,7 +42,7 @@
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
"quantizers.quantization_config": [],
"schedulers": [],
"utils": [
"OptionalDependencyNotAvailable",
@@ -54,6 +64,55 @@
],
}

try:
if not is_bitsandbytes_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_bitsandbytes_objects

_import_structure["utils.dummy_bitsandbytes_objects"] = [
name for name in dir(dummy_bitsandbytes_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")

try:
if not is_gguf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_gguf_objects

_import_structure["utils.dummy_gguf_objects"] = [
name for name in dir(dummy_gguf_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")

try:
if not is_torchao_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torchao_objects

_import_structure["utils.dummy_torchao_objects"] = [
name for name in dir(dummy_torchao_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("TorchAoConfig")

try:
if not is_optimum_quanto_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_optimum_quanto_objects

_import_structure["utils.dummy_optimum_quanto_objects"] = [
name for name in dir(dummy_optimum_quanto_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("QuantoConfig")


try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -598,7 +657,38 @@

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .configuration_utils import ConfigMixin
from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig

try:
if not is_bitsandbytes_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_bitsandbytes_objects import *
else:
from .quantizers.quantization_config import BitsAndBytesConfig

try:
if not is_gguf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_gguf_objects import *
else:
from .quantizers.quantization_config import GGUFQuantizationConfig

try:
if not is_torchao_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torchao_objects import *
else:
from .quantizers.quantization_config import TorchAoConfig

try:
if not is_optimum_quanto_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_optimum_quanto_objects import *
else:
from .quantizers.quantization_config import QuantoConfig

try:
if not is_onnx_available():
4 changes: 4 additions & 0 deletions src/diffusers/dependency_versions_table.py
@@ -35,6 +35,10 @@
"GitPython": "GitPython<3.1.19",
"scipy": "scipy",
"onnx": "onnx",
"optimum_quanto": "optimum_quanto>=0.2.6",
"gguf": "gguf>=0.10.0",
"torchao": "torchao>=0.7.0",
"bitsandbytes": "bitsandbytes>=0.43.3",
"regex": "regex!=2019.12.17",
"requests": "requests",
"tensorboard": "tensorboard",
7 changes: 6 additions & 1 deletion src/diffusers/models/model_loading_utils.py
@@ -245,6 +245,9 @@ def load_model_dict_into_meta(
):
param = param.to(torch.float32)
set_module_kwargs["dtype"] = torch.float32
# For quantizers that have saved weights using torch.float8_e4m3fn
elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
pass
Member (on lines +249 to +250): How would the param be handled in that case?

Collaborator Author: We wouldn't apply any casting; the parameter would be loaded as-is into the model.

else:
param = param.to(dtype)
set_module_kwargs["dtype"] = dtype
@@ -292,7 +295,9 @@ def load_model_dict_into_meta(
elif is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
):
hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
hf_quantizer.create_quantized_param(
model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
)
else:
set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)

4 changes: 4 additions & 0 deletions src/diffusers/quantizers/auto.py
@@ -26,22 +26,26 @@
GGUFQuantizationConfig,
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
TorchAoConfig,
)
from .quanto import QuantoQuantizer
from .torchao import TorchAoHfQuantizer


AUTO_QUANTIZER_MAPPING = {
"bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
"bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
"gguf": GGUFQuantizer,
"quanto": QuantoQuantizer,
"torchao": TorchAoHfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
"bitsandbytes_4bit": BitsAndBytesConfig,
"bitsandbytes_8bit": BitsAndBytesConfig,
"gguf": GGUFQuantizationConfig,
"quanto": QuantoConfig,
"torchao": TorchAoConfig,
}
