From 3546d741ee81280d86282f7e7880dd1820eefc26 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Tue, 20 Aug 2024 12:59:37 -0400
Subject: [PATCH 1/2] additional fixes for HFQuantizer compatibility

---
 .../compressors/model_compressor.py       |  4 ++++
 .../quantization/observers/helpers.py     |  6 +++---
 src/compressed_tensors/utils/offload.py   | 16 ++++++++++++++--
 3 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/src/compressed_tensors/compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressor.py
index 9807cbec..ed4197e1 100644
--- a/src/compressed_tensors/compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressor.py
@@ -188,6 +188,10 @@ def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]:
     if hasattr(compression_config, QUANTIZATION_CONFIG_NAME):
         # for loaded HFQuantizer config
         return getattr(compression_config, QUANTIZATION_CONFIG_NAME)
+    elif isinstance(compression_config, dict) and (
+        QUANTIZATION_CONFIG_NAME in compression_config
+    ):
+        return compression_config[QUANTIZATION_CONFIG_NAME]

     # SparseAutoModel format
     quantization_config = deepcopy(compression_config)
diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py
index e33839e1..13c05991 100644
--- a/src/compressed_tensors/quantization/observers/helpers.py
+++ b/src/compressed_tensors/quantization/observers/helpers.py
@@ -38,9 +38,9 @@ def get_observer_token_count(module: torch.nn.Module) -> Counter:
     token_counts = Counter()
     for name, module in module.named_modules():
         if name.endswith(".input_observer"):
-            token_counts[name.replace(".input_observer", "")] = (
-                module._num_observed_tokens
-            )
+            token_counts[
+                name.replace(".input_observer", "")
+            ] = module._num_observed_tokens
     return token_counts


diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py
index 1d1f7245..7b0b765d 100644
--- a/src/compressed_tensors/utils/offload.py
+++ b/src/compressed_tensors/utils/offload.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import torch
-from torch.nn import Module
+from torch.nn import Module, Parameter


 __all__ = [
@@ -100,7 +100,19 @@ def update_parameter_data(

     parameter = getattr(module, param_name, None)
     dtype = parameter.dtype
-    parameter.data = new_param_data.to(device).to(dtype)
+    try:
+        parameter.data = new_param_data.to(device).to(dtype)
+    except RuntimeError:
+        # exception may occur when trying to overwrite meta device, overriding
+        # parameter directly
+        setattr(
+            module,
+            param_name,
+            Parameter(
+                data=new_param_data.to(device).to(dtype),
+                requires_grad=parameter.requires_grad,
+            ),
+        )

     if offloaded:
         prefix_dict = module._hf_hook.weights_map.dataset

From bc0caac5f8d96c53ec3558fba253371cc1a05916 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Tue, 20 Aug 2024 14:10:50 -0400
Subject: [PATCH 2/2] sharded files fix

---
 src/compressed_tensors/utils/safetensors_load.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py
index 9cdac782..b31c6f77 100644
--- a/src/compressed_tensors/utils/safetensors_load.py
+++ b/src/compressed_tensors/utils/safetensors_load.py
@@ -47,6 +47,10 @@ def get_safetensors_folder(
         model will be searched for in the default TRANSFORMERS_CACHE
     :return: local folder containing model data
     """
+    if isinstance(pretrained_model_name_or_path, list):
+        # assume sharded files, referencing first file is sufficient
+        pretrained_model_name_or_path = pretrained_model_name_or_path[0]
+
     if os.path.exists(pretrained_model_name_or_path):
         # argument is a path to a local folder
         return os.path.abspath(pretrained_model_name_or_path)
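
A minimal, self-contained sketch of the dispatch order the first hunk establishes in parse_quantization_config (not part of the patches; the constant value and the elided SparseAutoModel fallback are assumptions based on the hunk, not the full library source):

    from typing import Dict, Optional

    QUANTIZATION_CONFIG_NAME = "quantization_config"  # assumed to match the library constant

    def parse_quantization_config(compression_config) -> Optional[Dict]:
        if hasattr(compression_config, QUANTIZATION_CONFIG_NAME):
            # loaded HFQuantizer config object: stored as an attribute
            return getattr(compression_config, QUANTIZATION_CONFIG_NAME)
        elif isinstance(compression_config, dict) and (
            QUANTIZATION_CONFIG_NAME in compression_config
        ):
            # raw dict, e.g. read straight from config.json: stored as a key
            return compression_config[QUANTIZATION_CONFIG_NAME]
        return None  # SparseAutoModel-format handling elided from this sketch

    # hasattr() does not see dict keys, so a raw dict falls through to the new branch
    assert parse_quantization_config({"quantization_config": {"bits": 8}}) == {"bits": 8}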
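The offload.py hunk wraps the in-place .data assignment and, on RuntimeError (which can occur when the parameter still lives on the meta device and has no storage to overwrite), re-registers a fresh Parameter instead. A standalone demonstration of that fallback path, assuming torch>=2.0 for the device context manager:

    import torch
    from torch.nn import Linear, Parameter

    with torch.device("meta"):
        layer = Linear(2, 2)  # parameters allocated on the meta device, no real storage

    new_weight = torch.zeros(2, 2)
    # the fallback: replace the Parameter object wholesale rather than
    # mutating parameter.data in place, preserving the original requires_grad
    layer.weight = Parameter(new_weight, requires_grad=layer.weight.requires_grad)
    assert layer.weight.device.type == "cpu"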
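And a sketch of the guard from the second patch: sharded checkpoints can arrive as a list of shard paths, and since every shard sits in the same folder, resolving the first entry is sufficient. The helper name below is hypothetical; in the library the guard sits at the top of get_safetensors_folder:

    def normalize_checkpoint_arg(pretrained_model_name_or_path):
        if isinstance(pretrained_model_name_or_path, list):
            # assume sharded files; the first shard determines the folder
            pretrained_model_name_or_path = pretrained_model_name_or_path[0]
        return pretrained_model_name_or_path

    shards = ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]
    assert normalize_checkpoint_arg(shards) == "model-00001-of-00002.safetensors"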