[Quantization] Add TRT-ModelOpt as a Backend #11173

Open · wants to merge 25 commits into base: main

2 changes: 2 additions & 0 deletions setup.py
@@ -132,6 +132,7 @@
    "gguf>=0.10.0",
    "torchao>=0.7.0",
    "bitsandbytes>=0.43.3",
    "nvidia_modelopt[torch, hf]>=0.27.0",
@kevalmorabia97 (Apr 29, 2025):

Suggested change:
-    "nvidia_modelopt[torch, hf]>=0.27.0",
+    "nvidia_modelopt[torch]>=0.27.0",

I think only the torch optional deps are needed here. Diffusers is being installed here anyway, and having it pull in transformers, accelerate, datasets, and the other hf dependencies indirectly through nvidia-modelopt[hf] doesn't make sense. nvidia-modelopt[hf] might also pin those dependencies differently than diffusers does, which could cause conflicts.

Member:

> Diffusers is anyways installed here and I think diffusers depending indirectly on transformers, accelerate, datasets, etc hf dependencies doesn't make sense.

Just as an FYI, transformers and datasets are not required packages of diffusers. We prefer keeping our dependencies lean.

Contributor Author:
Keeping this as it is for now, let me know your thoughts?

"regex!=2019.12.17",
"requests",
"tensorboard",
@@ -244,6 +245,7 @@ def run(self):
extras["gguf"] = deps_list("gguf", "accelerate")
extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
extras["torchao"] = deps_list("torchao", "accelerate")
extras["nvidia_modelopt"] = deps_list("nvidia_modelopt[torch, hf]")

if os.name == "nt":  # windows
    extras["flax"] = []  # jax is not supported on windows
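For reference, `deps_list("nvidia_modelopt[torch, hf]")` resolves the pin through the deps table (see dependency_versions_table.py below), so the extra ends up as `nvidia_modelopt[torch, hf]>=0.27.0`; assuming the extra lands as defined above, users would then presumably install the backend's dependencies with something like `pip install "diffusers[nvidia_modelopt]"` (or, if the review suggestion is adopted, with only ModelOpt's torch extras).
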
21 changes: 21 additions & 0 deletions src/diffusers/__init__.py
@@ -13,6 +13,7 @@
    is_k_diffusion_available,
    is_librosa_available,
    is_note_seq_available,
    is_nvidia_modelopt_available,
    is_onnx_available,
    is_opencv_available,
    is_optimum_quanto_available,
@@ -108,6 +109,18 @@
else:
    _import_structure["quantizers.quantization_config"].append("QuantoConfig")

try:
    if not is_torch_available() and not is_accelerate_available() and not is_nvidia_modelopt_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_nvidia_modelopt_objects

    _import_structure["utils.dummy_nvidia_modelopt_objects"] = [
        name for name in dir(dummy_nvidia_modelopt_objects) if not name.startswith("_")
    ]
else:
    _import_structure["quantizers.quantization_config"].append("NVIDIAModelOptConfig")

try:
    if not is_onnx_available():
        raise OptionalDependencyNotAvailable()
@@ -725,6 +738,14 @@
    else:
        from .quantizers.quantization_config import QuantoConfig

    try:
        if not is_nvidia_modelopt_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from .utils.dummy_nvidia_modelopt_objects import *
    else:
        from .quantizers.quantization_config import NVIDIAModelOptConfig

    try:
        if not is_onnx_available():
            raise OptionalDependencyNotAvailable()
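A quick way to see the lazy-import plumbing above in action, assuming `is_nvidia_modelopt_available` stays exported from `diffusers.utils` as in this diff, is a sketch along these lines:

import diffusers
from diffusers.utils import is_nvidia_modelopt_available

# With nvidia-modelopt installed, this resolves to the real config class added in
# quantizers/quantization_config.py; otherwise it falls back to the dummy object,
# which only raises an ImportError when it is actually instantiated.
print(is_nvidia_modelopt_available())
print(diffusers.NVIDIAModelOptConfig)
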
1 change: 1 addition & 0 deletions src/diffusers/dependency_versions_table.py
@@ -39,6 +39,7 @@
    "gguf": "gguf>=0.10.0",
    "torchao": "torchao>=0.7.0",
    "bitsandbytes": "bitsandbytes>=0.43.3",
    "nvidia_modelopt[torch, hf]": "nvidia_modelopt[torch, hf]>=0.27.0",
    "regex": "regex!=2019.12.17",
    "requests": "requests",
    "tensorboard": "tensorboard",
4 changes: 4 additions & 0 deletions src/diffusers/quantizers/auto.py
@@ -21,9 +21,11 @@

from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .gguf import GGUFQuantizer
from .modelopt import NVIDIAModelOptQuantizer
from .quantization_config import (
    BitsAndBytesConfig,
    GGUFQuantizationConfig,
    NVIDIAModelOptConfig,
    QuantizationConfigMixin,
    QuantizationMethod,
    QuantoConfig,
@@ -39,6 +41,7 @@
    "gguf": GGUFQuantizer,
    "quanto": QuantoQuantizer,
    "torchao": TorchAoHfQuantizer,
    "modelopt": NVIDIAModelOptQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -47,6 +50,7 @@
    "gguf": GGUFQuantizationConfig,
    "quanto": QuantoConfig,
    "torchao": TorchAoConfig,
    "modelopt": NVIDIAModelOptConfig,
}


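With both mappings extended, the existing auto machinery should be able to resolve the new "modelopt" method string. A minimal sketch of that path, assuming DiffusersAutoQuantizer.from_config keeps dispatching on the config as it does for the other backends; the constructor arguments of NVIDIAModelOptConfig are not part of this diff, so quant_type="FP8" is an assumption taken from the quantizer code further down:

from diffusers import NVIDIAModelOptConfig
from diffusers.quantizers.auto import DiffusersAutoQuantizer

config = NVIDIAModelOptConfig(quant_type="FP8")  # assumed constructor argument
quantizer = DiffusersAutoQuantizer.from_config(config)
# The "modelopt" entry in AUTO_QUANTIZER_MAPPING makes this an NVIDIAModelOptQuantizer.
print(type(quantizer).__name__)
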
1 change: 1 addition & 0 deletions src/diffusers/quantizers/modelopt/__init__.py
@@ -0,0 +1 @@
from .modelopt_quantizer import NVIDIAModelOptQuantizer
157 changes: 157 additions & 0 deletions src/diffusers/quantizers/modelopt/modelopt_quantizer.py
@@ -0,0 +1,157 @@
from typing import TYPE_CHECKING, Any, Dict, List, Union

from ...utils import (
    get_module_from_name,
    is_accelerate_available,
    is_nvidia_modelopt_available,
    is_nvidia_modelopt_version,
    is_torch_available,
    logging,
)
from ..base import DiffusersQuantizer


if TYPE_CHECKING:
    from ...models.modeling_utils import ModelMixin


if is_torch_available():
    import torch

if is_accelerate_available():
    from accelerate.utils import set_module_tensor_to_device


logger = logging.get_logger(__name__)


class NVIDIAModelOptQuantizer(DiffusersQuantizer):
r"""
Diffusers Quantizer for TensorRT Model Optimizer
"""

use_keep_in_fp32_modules = True
requires_calibration = False
required_packages = ["modelopt"]

def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)

def validate_environment(self, *args, **kwargs):
if not is_nvidia_modelopt_available():
raise ImportError(
"Loading an nvidia-modelopt quantized model requires nvidia-modelopt library (`pip install nvidia-modelopt`)"
)
if not is_nvidia_modelopt_version(">=", "0.25.0"):
raise ImportError(
"Loading an nvidia-modelopt quantized model requires `nvidia-modelopt>=0.25.0`. "
"Please upgrade your installation with `pip install --upgrade nvidia-modelopt"
)

self.offload = False

device_map = kwargs.get("device_map", None)
if isinstance(device_map, dict):
if "cpu" in device_map.values() or "disk" in device_map.values():
if self.pre_quantized:
raise ValueError(
"You are attempting to perform cpu/disk offload with a pre-quantized modelopt model "
"This is not supported yet. Please remove the CPU or disk device from the `device_map` argument."
)
else:
self.offload = True

    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        # ModelOpt imports diffusers internally. This is here to prevent circular imports
        from modelopt.torch.quantization.qtensor import BaseQuantizedTensor
        from modelopt.torch.quantization.utils import is_quantized

        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized and any(isinstance(module, t) for t in [BaseQuantizedTensor]):
            return True
        elif is_quantized(module) and "weight" in tensor_name:
            return True
        return False

    def create_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        *args,
        **kwargs,
    ):
        """
        Create the quantized parameter by calling .calibrate() after setting it to the module.
        """
        # ModelOpt imports diffusers internally. This is here to prevent circular imports
        import modelopt.torch.quantization as mtq

        dtype = kwargs.get("dtype", torch.float32)
        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized:
            setattr(module, tensor_name, param_value)
        else:
            set_module_tensor_to_device(model, param_name, target_device, param_value, dtype)
            mtq.calibrate(module, self.quantization_config.modelopt_config["algorithm"], self.quantization_config.forward_loop)
            mtq.compress(module)

Review comment:
mtq.compress compresses the model weights into lower-bit representations, allowing users to leverage it directly at the Torch level. However, as previously mentioned, to achieve actual speed improvements we need to use the TensorRT runtime rather than the Torch runtime.

            module.weight.requires_grad = False

Review comment:
Thanks, this should actually be part of ModelOpt.

Member:
I think you meant L113-116 should be a part of ModelOpt?

Reply:
Yes.

Contributor Author @ishan-modi (May 5, 2025):
Thanks @realAsma, any specific reason why you think it should be part of ModelOpt? I kept it separate to follow the convention used in diffusers:

We initially want to process the model and add all the quant layers, and only then actually calibrate/compress the modules. There are two methods inside the NVIDIAModelOptQuantizer class for these two steps.

Let me know if I am missing something. Do you mean using mtq.quantize instead, since compress is no longer part of calibrate?

Reply:
Please ignore my comments. That was an internal note for ModelOpt.
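
For context on the mtq.quantize question above, here is a minimal sketch of ModelOpt's standalone one-shot flow that the thread is contrasting with the two-phase quantizer in this file (the toy model and calibration data are placeholders; mtq.FP8_DEFAULT_CFG is one of ModelOpt's stock configs):

import torch
import modelopt.torch.quantization as mtq

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
calib_data = [torch.randn(8, 16) for _ in range(4)]

def forward_loop(m):
    # Calibration pass: run a few representative batches through the model.
    for batch in calib_data:
        m(batch)

# One-shot: mtq.quantize inserts the quantizer modules, runs calibration, and returns
# the quantized model, as opposed to the split process/calibrate steps used above.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)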


    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
        return max_memory

    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        if self.quantization_config.quant_type == "FP8":
            target_dtype = torch.float8_e4m3fn
        return target_dtype

    def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype":
        if torch_dtype is None:
            logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
            torch_dtype = torch.float32
        return torch_dtype

    def _process_model_before_weight_loading(
        self,
        model: "ModelMixin",
        device_map,
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
    ):
        # ModelOpt imports diffusers internally. This is here to prevent circular imports
        import modelopt.torch.opt as mto

        modules_to_not_convert = self.quantization_config.modules_to_not_convert

        if modules_to_not_convert is None:
            modules_to_not_convert = []
        if isinstance(modules_to_not_convert, str):
            modules_to_not_convert = [modules_to_not_convert]
        modules_to_not_convert.extend(keep_in_fp32_modules)

        for module in modules_to_not_convert:
            self.quantization_config.modelopt_config["quant_cfg"]["*" + module + "*"] = {"enable": False}
        self.quantization_config.modules_to_not_convert = modules_to_not_convert

        mto.apply_mode(model, mode=[("quantize", self.quantization_config.modelopt_config)])
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model, **kwargs):
        return model

    @property
    def is_trainable(self):
        return True

    @property
    def is_serializable(self):
        return True
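
Putting the pieces together, a rough sketch of how this backend would be exercised once merged; the checkpoint, subfolder, and quant_type value are illustrative assumptions, and the config's full constructor signature is not shown in this diff:

import torch
from diffusers import FluxTransformer2DModel, NVIDIAModelOptConfig

quant_config = NVIDIAModelOptConfig(quant_type="FP8")  # assumed constructor argument
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=quant_config,  # routed via DiffusersAutoQuantizer to NVIDIAModelOptQuantizer
    torch_dtype=torch.bfloat16,
)
# The transformer now carries ModelOpt quantizer modules; when not loading a
# pre-quantized checkpoint, weights are calibrated and compressed during loading
# via create_quantized_param above.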