-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Add TRT as a Backend #11173
Draft
ishan-modi
wants to merge
6
commits into
huggingface:main
Choose a base branch
from
ishan-modi:add-trtquant-backend
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
[WIP] Add TRT as a Backend #11173
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .modelopt_quantizer import NVIDIAModelOptQuantizer |
162 changes: 162 additions & 0 deletions
162
src/diffusers/quantizers/modelopt/modelopt_quantizer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
from typing import TYPE_CHECKING, Any, Dict, List, Union | ||
|
||
from ...utils import ( | ||
get_module_from_name, | ||
is_accelerate_available, | ||
is_nvidia_modelopt_available, | ||
is_nvidia_modelopt_version, | ||
is_torch_available, | ||
logging, | ||
) | ||
from ..base import DiffusersQuantizer | ||
|
||
|
||
if TYPE_CHECKING: | ||
from ...models.modeling_utils import ModelMixin | ||
|
||
|
||
if is_torch_available(): | ||
import torch | ||
|
||
if is_accelerate_available(): | ||
from accelerate.utils import set_module_tensor_to_device | ||
|
||
|
||
logger = logging.get_logger(__name__) | ||
|
||
|
||
class NVIDIAModelOptQuantizer(DiffusersQuantizer): | ||
r""" | ||
Diffusers Quantizer for TensorRT Model Optimizer | ||
""" | ||
|
||
use_keep_in_fp32_modules = True | ||
requires_calibration = False | ||
required_packages = ["modelopt"] | ||
|
||
def __init__(self, quantization_config, **kwargs): | ||
super().__init__(quantization_config, **kwargs) | ||
|
||
def validate_environment(self, *args, **kwargs): | ||
if not is_nvidia_modelopt_available(): | ||
raise ImportError( | ||
"Loading an nvidia-modelopt quantized model requires nvidia-modelopt library (`pip install nvidia-modelopt`)" | ||
) | ||
if not is_nvidia_modelopt_version(">=", "0.25.0"): | ||
raise ImportError( | ||
"Loading an nvidia-modelopt quantized model requires `nvidia-modelopt>=0.25.0`. " | ||
"Please upgrade your installation with `pip install --upgrade nvidia-modelopt" | ||
) | ||
|
||
self.offload = False | ||
|
||
device_map = kwargs.get("device_map", None) | ||
if isinstance(device_map, dict): | ||
if "cpu" in device_map.values() or "disk" in device_map.values(): | ||
if self.pre_quantized: | ||
raise ValueError( | ||
"You are attempting to perform cpu/disk offload with a pre-quantized modelopt model " | ||
"This is not supported yet. Please remove the CPU or disk device from the `device_map` argument." | ||
) | ||
else: | ||
self.offload = True | ||
|
||
def check_if_quantized_param( | ||
self, | ||
model: "ModelMixin", | ||
param_value: "torch.Tensor", | ||
param_name: str, | ||
state_dict: Dict[str, Any], | ||
**kwargs, | ||
): | ||
# ModelOpt imports diffusers internally. This is here to prevent circular imports | ||
from modelopt.torch.quantization.nn import QuantInputBase, SequentialQuantizer, TensorQuantizer | ||
from modelopt.torch.quantization.qtensor import BaseQuantizedTensor | ||
|
||
def is_param_quantized(module): | ||
for _module in module.modules(): | ||
if isinstance(_module, TensorQuantizer) and not _module._dequantize: | ||
return True | ||
elif isinstance(_module, SequentialQuantizer): | ||
for q in _module: | ||
if isinstance(q, TensorQuantizer) and not q._dequantize: | ||
return True | ||
return False | ||
|
||
module, tensor_name = get_module_from_name(model, param_name) | ||
if self.pre_quantized and any(isinstance(module, t) for t in [BaseQuantizedTensor]): | ||
return True | ||
elif isinstance(module, QuantInputBase) and "weight" in tensor_name: | ||
return is_param_quantized(module) | ||
return False | ||
|
||
def create_quantized_param( | ||
self, | ||
model: "ModelMixin", | ||
param_value: "torch.Tensor", | ||
param_name: str, | ||
target_device: "torch.device", | ||
*args, | ||
**kwargs, | ||
): | ||
""" | ||
Create the quantized parameter by calling .calibrate() after setting it to the module. | ||
""" | ||
# ModelOpt imports diffusers internally. This is here to prevent circular imports | ||
import modelopt.torch.quantization as mtq | ||
|
||
dtype = kwargs.get("dtype", torch.float32) | ||
module, tensor_name = get_module_from_name(model, param_name) | ||
if self.pre_quantized: | ||
setattr(module, tensor_name, param_value) | ||
else: | ||
set_module_tensor_to_device(model, param_name, target_device, param_value, dtype) | ||
mtq.compress(module) | ||
module.weight.requires_grad = False | ||
|
||
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: | ||
max_memory = {key: val * 0.90 for key, val in max_memory.items()} | ||
return max_memory | ||
|
||
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": | ||
if self.quantization_config.quant_type == "FP8": | ||
target_dtype = torch.float8_e4m3fn | ||
return target_dtype | ||
|
||
def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype": | ||
if torch_dtype is None: | ||
logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.") | ||
torch_dtype = torch.float32 | ||
return torch_dtype | ||
|
||
def _process_model_before_weight_loading( | ||
self, | ||
model: "ModelMixin", | ||
device_map, | ||
keep_in_fp32_modules: List[str] = [], | ||
**kwargs, | ||
): | ||
# ModelOpt imports diffusers internally. This is here to prevent circular imports | ||
import modelopt.torch.quantization as mtq | ||
|
||
self.modules_to_not_convert = self.quantization_config.modules_to_not_convert | ||
|
||
if not isinstance(self.modules_to_not_convert, list): | ||
self.modules_to_not_convert = [self.modules_to_not_convert] | ||
|
||
self.modules_to_not_convert.extend(keep_in_fp32_modules) | ||
|
||
config = self.quantization_config.get_config_from_quant_type() | ||
mtq.quantize(model, config) | ||
model.config.quantization_config = self.quantization_config | ||
|
||
def _process_model_after_weight_loading(self, model, **kwargs): | ||
return model | ||
|
||
@property | ||
def is_trainable(self): | ||
return True | ||
|
||
@property | ||
def is_serializable(self): | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,6 +46,7 @@ class QuantizationMethod(str, Enum): | |
GGUF = "gguf" | ||
TORCHAO = "torchao" | ||
QUANTO = "quanto" | ||
MODELOPT = "modelopt" | ||
|
||
|
||
if is_torchao_available(): | ||
|
@@ -722,3 +723,97 @@ def post_init(self): | |
accepted_weights = ["float8", "int8", "int4", "int2"] | ||
if self.weights_dtype not in accepted_weights: | ||
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}") | ||
|
||
|
||
@dataclass | ||
class NVIDIAModelOptConfig(QuantizationConfigMixin): | ||
"""This is a config class to use nvidia modelopt for quantization. | ||
|
||
Args: | ||
QuantizationConfigMixin (_type_): _description_ | ||
""" | ||
|
||
def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]] = None, **kwargs) -> None: | ||
self.quant_method = QuantizationMethod.MODELOPT | ||
self.quant_type = quant_type | ||
QUANT_TYPES = [ | ||
"FP8_WO", | ||
# "FP8_AINT8", | ||
"INT8_WO", | ||
# "INT8_AFP8", | ||
# "INT8_AFP8_QKVFP8", | ||
"INT4_WO", | ||
# "INT4_AFP8", | ||
# "INT4_AFP8_QKVFP8", | ||
Comment on lines
+741
to
+747
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are these not supported? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am currently testing these and will clean it up |
||
] | ||
if quant_type not in QUANT_TYPES: | ||
logger.warning( | ||
f"Quantization type {quant_type} not supported. Supported types are {QUANT_TYPES}, picking FP8_WO as default" | ||
) | ||
self.quant_type = "FP8_WO" | ||
self.modules_to_not_convert = modules_to_not_convert | ||
self.advanced_quant = kwargs | ||
|
||
def get_config_from_quant_type(self) -> Dict[str, Any]: | ||
""" | ||
Get the config from the quantization type. | ||
""" | ||
# ModelOpt imports diffusers internally. This is here to prevent circular imports | ||
external_conf = self.advanced_quant.pop("modelopt_config", None) | ||
if external_conf: | ||
return external_conf | ||
|
||
BASE_CONFIG = { | ||
"quant_cfg": { | ||
"*weight_quantizer": {"fake_quant": False}, | ||
"*input_quantizer": {}, | ||
"*output_quantizer": {"enable": False}, | ||
"*q_bmm_quantizer": {}, | ||
"*k_bmm_quantizer": {}, | ||
"*v_bmm_quantizer": {}, | ||
"*softmax_quantizer": {}, | ||
"default": {"enable": False}, | ||
}, | ||
"algorithm": "max", | ||
} | ||
|
||
quant_cfg = BASE_CONFIG["quant_cfg"] | ||
if "FP8" in self.quant_type: | ||
for k in quant_cfg: | ||
if "enable" not in quant_cfg[k]: | ||
quant_cfg[k]["num_bits"] = (4, 3) | ||
elif "INT8" in self.quant_type: | ||
for k in quant_cfg: | ||
if "enable" not in quant_cfg[k]: | ||
quant_cfg[k]["num_bits"] = 8 | ||
elif "INT4" in self.quant_type: | ||
for k in quant_cfg: | ||
if "enable" not in quant_cfg[k]: | ||
quant_cfg[k]["num_bits"] = 4 | ||
else: | ||
raise ValueError(f"Unknown quantization type: {self.quant_type}") | ||
|
||
if "WO" in self.quant_type: | ||
for k in quant_cfg: | ||
if "*weight_quantizer" not in k: | ||
quant_cfg[k]["enable"] = False | ||
|
||
per_channel = self.advanced_quant.pop("per_channel", False) | ||
if per_channel: | ||
quant_cfg["*weight_quantizer"]["axis"] = self.advanced_quant.pop("axis", -1) | ||
quant_cfg["*input_quantizer"]["axis"] = self.advanced_quant.pop("axis", -1) | ||
|
||
block_quantize = self.advanced_quant.pop("block_quantize", False) | ||
if block_quantize: | ||
quant_cfg["*weight_quantizer"]["block_sizes"] = { | ||
self.advanced_quant.pop("axis", -1): self.advanced_quant.pop("block_size", 128) | ||
} | ||
quant_cfg["*input_quantizer"]["block_sizes"] = { | ||
self.advanced_quant.pop("axis", -1): self.advanced_quant.pop("block_size", 128) | ||
} | ||
|
||
if self.modules_to_not_convert is not None: | ||
for module in self.modules_to_not_convert: | ||
quant_cfg["*" + module + "*"] = {"enable": False} | ||
|
||
return BASE_CONFIG |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes it is in progress