[WIP] Add TRT as a Backend #11173

Draft · wants to merge 6 commits into base: main
4 changes: 4 additions & 0 deletions setup.py
@@ -119,6 +119,7 @@
"peft>=0.6.0",
"protobuf>=3.20.3,<4",
"pytest",
"pulp",
"pytest-timeout",
"pytest-xdist",
"python>=3.8.0",
Expand All @@ -128,10 +129,12 @@
"GitPython<3.1.19",
"scipy",
"onnx",
"torchprofile>=0.0.4",
"optimum_quanto>=0.2.6",
"gguf>=0.10.0",
"torchao>=0.7.0",
"bitsandbytes>=0.43.3",
"nvidia_modelopt>=0.27.0",
"regex!=2019.12.17",
"requests",
"tensorboard",
@@ -243,6 +246,7 @@ def run(self):
extras["gguf"] = deps_list("gguf", "accelerate")
extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
extras["torchao"] = deps_list("torchao", "accelerate")
extras["nvidia_modelopt"] = deps_list("nvidia_modelopt", "onnx", "pulp", "torchprofile", "accelerate")

if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
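With the new extra in place, the optional stack installs with pip install diffusers[nvidia_modelopt] (or pip install -e ".[nvidia_modelopt]" from a source checkout). As a minimal sketch, assuming deps_list simply resolves names against the pinned-version table earlier in setup.py, the extra expands roughly as follows; the accelerate entry is illustrative rather than copied from this diff.

deps = {
    "nvidia_modelopt": "nvidia_modelopt>=0.27.0",
    "onnx": "onnx",
    "pulp": "pulp",
    "torchprofile": "torchprofile>=0.0.4",
    "accelerate": "accelerate",  # illustrative; the real table pins a minimum version
}

def deps_list(*pkgs):
    # Resolve package names to their pinned specifiers.
    return [deps[pkg] for pkg in pkgs]

# extras["nvidia_modelopt"] then becomes:
print(deps_list("nvidia_modelopt", "onnx", "pulp", "torchprofile", "accelerate"))
# ['nvidia_modelopt>=0.27.0', 'onnx', 'pulp', 'torchprofile>=0.0.4', 'accelerate']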
21 changes: 21 additions & 0 deletions src/diffusers/__init__.py
@@ -13,6 +13,7 @@
is_k_diffusion_available,
is_librosa_available,
is_note_seq_available,
is_nvidia_modelopt_available,
is_onnx_available,
is_optimum_quanto_available,
is_scipy_available,
@@ -107,6 +108,18 @@
else:
_import_structure["quantizers.quantization_config"].append("QuantoConfig")

try:
if not is_torch_available() and not is_accelerate_available() and not is_nvidia_modelopt_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_nvidia_modelopt_objects

_import_structure["utils.dummy_nvidia_modelopt_objects"] = [
name for name in dir(dummy_nvidia_modelopt_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("NVIDIAModelOptConfig")

try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
Expand Down Expand Up @@ -693,6 +706,14 @@
else:
from .quantizers.quantization_config import QuantoConfig

try:
if not is_nvidia_modelopt_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_nvidia_modelopt_objects import *
else:
from .quantizers.quantization_config import NVIDIAModelOptConfig

try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
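Once NVIDIAModelOptConfig is exported at the top level (when nvidia-modelopt is installed), end-user code would presumably mirror the existing QuantoConfig/TorchAoConfig flow. A minimal, hypothetical sketch; the checkpoint id and subfolder are illustrative and not taken from this PR.

import torch

from diffusers import FluxTransformer2DModel, NVIDIAModelOptConfig

# Hypothetical usage: pass the config to from_pretrained like any other quantization_config.
quant_config = NVIDIAModelOptConfig(quant_type="FP8_WO")  # weight-only FP8
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)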
3 changes: 3 additions & 0 deletions src/diffusers/dependency_versions_table.py
@@ -26,6 +26,7 @@
"peft": "peft>=0.6.0",
"protobuf": "protobuf>=3.20.3,<4",
"pytest": "pytest",
"pulp": "pulp",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.8.0",
@@ -35,10 +36,12 @@
"GitPython": "GitPython<3.1.19",
"scipy": "scipy",
"onnx": "onnx",
"torchprofile": "torchprofile>=0.0.4",
"optimum_quanto": "optimum_quanto>=0.2.6",
"gguf": "gguf>=0.10.0",
"torchao": "torchao>=0.7.0",
"bitsandbytes": "bitsandbytes>=0.43.3",
"nvidia_modelopt": "nvidia_modelopt>=0.27.0",
"regex": "regex!=2019.12.17",
"requests": "requests",
"tensorboard": "tensorboard",
4 changes: 4 additions & 0 deletions src/diffusers/quantizers/auto.py
@@ -21,9 +21,11 @@

from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .gguf import GGUFQuantizer
from .modelopt import NVIDIAModelOptQuantizer
from .quantization_config import (
BitsAndBytesConfig,
GGUFQuantizationConfig,
NVIDIAModelOptConfig,
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
@@ -39,6 +41,7 @@
"gguf": GGUFQuantizer,
"quanto": QuantoQuantizer,
"torchao": TorchAoHfQuantizer,
"modelopt": NVIDIAModelOptQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -47,6 +50,7 @@
"gguf": GGUFQuantizationConfig,
"quanto": QuantoConfig,
"torchao": TorchAoConfig,
"modelopt": NVIDIAModelOptConfig,
}


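The two "modelopt" entries hook the new backend into the auto machinery: a quantization config whose quant_method is "modelopt" resolves to NVIDIAModelOptConfig and NVIDIAModelOptQuantizer. A small illustrative lookup (a sketch, not code from this PR):

from diffusers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING

method = "modelopt"
config_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[method]  # NVIDIAModelOptConfig
quantizer_cls = AUTO_QUANTIZER_MAPPING[method]         # NVIDIAModelOptQuantizer
quantizer = quantizer_cls(config_cls(quant_type="INT8_WO"))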
1 change: 1 addition & 0 deletions src/diffusers/quantizers/modelopt/__init__.py
@@ -0,0 +1 @@
from .modelopt_quantizer import NVIDIAModelOptQuantizer
162 changes: 162 additions & 0 deletions src/diffusers/quantizers/modelopt/modelopt_quantizer.py
@@ -0,0 +1,162 @@
from typing import TYPE_CHECKING, Any, Dict, List, Union

from ...utils import (
get_module_from_name,
is_accelerate_available,
is_nvidia_modelopt_available,
is_nvidia_modelopt_version,
is_torch_available,
logging,
)
from ..base import DiffusersQuantizer


if TYPE_CHECKING:
from ...models.modeling_utils import ModelMixin


if is_torch_available():
import torch

if is_accelerate_available():
from accelerate.utils import set_module_tensor_to_device


logger = logging.get_logger(__name__)


class NVIDIAModelOptQuantizer(DiffusersQuantizer):
r"""
Diffusers Quantizer for TensorRT Model Optimizer
"""

use_keep_in_fp32_modules = True
requires_calibration = False
required_packages = ["modelopt"]

def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)

def validate_environment(self, *args, **kwargs):
if not is_nvidia_modelopt_available():
raise ImportError(
"Loading an nvidia-modelopt quantized model requires the nvidia-modelopt library (`pip install nvidia-modelopt`)"
)
if not is_nvidia_modelopt_version(">=", "0.25.0"):
raise ImportError(
"Loading an nvidia-modelopt quantized model requires `nvidia-modelopt>=0.25.0`. "
"Please upgrade your installation with `pip install --upgrade nvidia-modelopt`."
)

self.offload = False

device_map = kwargs.get("device_map", None)
if isinstance(device_map, dict):
if "cpu" in device_map.values() or "disk" in device_map.values():
if self.pre_quantized:
raise ValueError(
"You are attempting to perform cpu/disk offload with a pre-quantized modelopt model "
"This is not supported yet. Please remove the CPU or disk device from the `device_map` argument."
)
else:
self.offload = True

def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
):
# ModelOpt imports diffusers internally. This is here to prevent circular imports
from modelopt.torch.quantization.nn import QuantInputBase, SequentialQuantizer, TensorQuantizer
from modelopt.torch.quantization.qtensor import BaseQuantizedTensor

def is_param_quantized(module):
for _module in module.modules():
if isinstance(_module, TensorQuantizer) and not _module._dequantize:
return True
elif isinstance(_module, SequentialQuantizer):
for q in _module:
if isinstance(q, TensorQuantizer) and not q._dequantize:
return True
return False

module, tensor_name = get_module_from_name(model, param_name)
if self.pre_quantized and any(isinstance(module, t) for t in [BaseQuantizedTensor]):
return True
elif isinstance(module, QuantInputBase) and "weight" in tensor_name:
return is_param_quantized(module)
return False

def create_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
*args,
**kwargs,
):
"""
Create the quantized parameter by calling .calibrate() after setting it to the module.
"""
# ModelOpt imports diffusers internally. This is here to prevent circular imports
import modelopt.torch.quantization as mtq

dtype = kwargs.get("dtype", torch.float32)
module, tensor_name = get_module_from_name(model, param_name)
if self.pre_quantized:
setattr(module, tensor_name, param_value)
else:
set_module_tensor_to_device(model, param_name, target_device, param_value, dtype)
mtq.compress(module)
module.weight.requires_grad = False

def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
return max_memory

def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
if self.quantization_config.quant_type == "FP8":
target_dtype = torch.float8_e4m3fn
return target_dtype

def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype":
if torch_dtype is None:
logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
torch_dtype = torch.float32
return torch_dtype

def _process_model_before_weight_loading(
self,
model: "ModelMixin",
device_map,
keep_in_fp32_modules: List[str] = [],
**kwargs,
):
# ModelOpt imports diffusers internally. This is here to prevent circular imports
import modelopt.torch.quantization as mtq

self.modules_to_not_convert = self.quantization_config.modules_to_not_convert

if not isinstance(self.modules_to_not_convert, list):
self.modules_to_not_convert = [self.modules_to_not_convert]

self.modules_to_not_convert.extend(keep_in_fp32_modules)

config = self.quantization_config.get_config_from_quant_type()
mtq.quantize(model, config)
model.config.quantization_config = self.quantization_config

def _process_model_after_weight_loading(self, model, **kwargs):
return model

@property
def is_trainable(self):
return True

@property
def is_serializable(self):
return True
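For orientation, the hooks above are expected to be driven by the model loader roughly in the order sketched below. This is an assumption about ModelMixin.from_pretrained internals rather than code from this PR, and model / state_dict are placeholders for the objects built during loading.

import torch

from diffusers import NVIDIAModelOptConfig
from diffusers.quantizers.modelopt import NVIDIAModelOptQuantizer

quantizer = NVIDIAModelOptQuantizer(NVIDIAModelOptConfig(quant_type="FP8_WO"))
quantizer.validate_environment(device_map=None)   # fails early if nvidia-modelopt is missing
torch_dtype = quantizer.update_torch_dtype(None)  # falls back to torch.float32

# model and state_dict stand in for the partially built model and the checkpoint weights.
quantizer._process_model_before_weight_loading(model, device_map=None)  # wraps mtq.quantize
for name, tensor in state_dict.items():
    if quantizer.check_if_quantized_param(model, tensor, name, state_dict):
        quantizer.create_quantized_param(model, tensor, name, torch.device("cuda"))
model = quantizer._process_model_after_weight_loading(model)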
95 changes: 95 additions & 0 deletions src/diffusers/quantizers/quantization_config.py
@@ -46,6 +46,7 @@ class QuantizationMethod(str, Enum):
GGUF = "gguf"
TORCHAO = "torchao"
QUANTO = "quanto"
MODELOPT = "modelopt"


if is_torchao_available():
@@ -722,3 +723,97 @@ def post_init(self):
accepted_weights = ["float8", "int8", "int4", "int2"]
if self.weights_dtype not in accepted_weights:
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")


@dataclass
class NVIDIAModelOptConfig(QuantizationConfigMixin):
"""This is a config class to use nvidia modelopt for quantization.

Args:
QuantizationConfigMixin (_type_): _description_
Review comment on lines +732 to +733 (Member): TODO?
Reply (Contributor, author): yes it is in progress

"""

def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]] = None, **kwargs) -> None:
self.quant_method = QuantizationMethod.MODELOPT
self.quant_type = quant_type
QUANT_TYPES = [
"FP8_WO",
# "FP8_AINT8",
"INT8_WO",
# "INT8_AFP8",
# "INT8_AFP8_QKVFP8",
"INT4_WO",
# "INT4_AFP8",
# "INT4_AFP8_QKVFP8",
Review comment on lines +741 to +747 (Member): Are these not supported?
Reply (Contributor, author): I am currently testing these and will clean it up
]
if quant_type not in QUANT_TYPES:
logger.warning(
f"Quantization type {quant_type} not supported. Supported types are {QUANT_TYPES}, picking FP8_WO as default"
)
self.quant_type = "FP8_WO"
self.modules_to_not_convert = modules_to_not_convert
self.advanced_quant = kwargs

def get_config_from_quant_type(self) -> Dict[str, Any]:
"""
Get the config from the quantization type.
"""
# ModelOpt imports diffusers internally. This is here to prevent circular imports
external_conf = self.advanced_quant.pop("modelopt_config", None)
if external_conf:
return external_conf

BASE_CONFIG = {
"quant_cfg": {
"*weight_quantizer": {"fake_quant": False},
"*input_quantizer": {},
"*output_quantizer": {"enable": False},
"*q_bmm_quantizer": {},
"*k_bmm_quantizer": {},
"*v_bmm_quantizer": {},
"*softmax_quantizer": {},
"default": {"enable": False},
},
"algorithm": "max",
}

quant_cfg = BASE_CONFIG["quant_cfg"]
if "FP8" in self.quant_type:
for k in quant_cfg:
if "enable" not in quant_cfg[k]:
quant_cfg[k]["num_bits"] = (4, 3)
elif "INT8" in self.quant_type:
for k in quant_cfg:
if "enable" not in quant_cfg[k]:
quant_cfg[k]["num_bits"] = 8
elif "INT4" in self.quant_type:
for k in quant_cfg:
if "enable" not in quant_cfg[k]:
quant_cfg[k]["num_bits"] = 4
else:
raise ValueError(f"Unknown quantization type: {self.quant_type}")

if "WO" in self.quant_type:
for k in quant_cfg:
if "*weight_quantizer" not in k:
quant_cfg[k]["enable"] = False

per_channel = self.advanced_quant.pop("per_channel", False)
axis = self.advanced_quant.pop("axis", -1)
block_size = self.advanced_quant.pop("block_size", 128)
if per_channel:
quant_cfg["*weight_quantizer"]["axis"] = axis
quant_cfg["*input_quantizer"]["axis"] = axis

block_quantize = self.advanced_quant.pop("block_quantize", False)
if block_quantize:
quant_cfg["*weight_quantizer"]["block_sizes"] = {axis: block_size}
quant_cfg["*input_quantizer"]["block_sizes"] = {axis: block_size}

if self.modules_to_not_convert is not None:
for module in self.modules_to_not_convert:
quant_cfg["*" + module + "*"] = {"enable": False}

return BASE_CONFIG
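For concreteness, tracing get_config_from_quant_type above for the default NVIDIAModelOptConfig(quant_type="FP8_WO") with no advanced kwargs gives the dictionary below: the weight quantizer stays enabled with FP8 (E4M3) settings and real (non-fake) quantization, while every activation and attention quantizer is disabled.

fp8_wo_config = {
    "quant_cfg": {
        "*weight_quantizer": {"fake_quant": False, "num_bits": (4, 3)},
        "*input_quantizer": {"num_bits": (4, 3), "enable": False},
        "*output_quantizer": {"enable": False},
        "*q_bmm_quantizer": {"num_bits": (4, 3), "enable": False},
        "*k_bmm_quantizer": {"num_bits": (4, 3), "enable": False},
        "*v_bmm_quantizer": {"num_bits": (4, 3), "enable": False},
        "*softmax_quantizer": {"num_bits": (4, 3), "enable": False},
        "default": {"enable": False},
    },
    "algorithm": "max",
}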
2 changes: 2 additions & 0 deletions src/diffusers/utils/__init__.py
@@ -78,6 +78,8 @@
is_librosa_available,
is_matplotlib_available,
is_note_seq_available,
is_nvidia_modelopt_available,
is_nvidia_modelopt_version,
is_onnx_available,
is_optimum_quanto_available,
is_optimum_quanto_version,