diff --git a/src/compressed_tensors/compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressor.py
index dfbff781..ff01e352 100644
--- a/src/compressed_tensors/compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressor.py
@@ -205,6 +205,10 @@ def parse_quantization_config(
     if hasattr(compression_config, QUANTIZATION_CONFIG_NAME):
         # for loaded HFQuantizer config
         return getattr(compression_config, QUANTIZATION_CONFIG_NAME)
+    elif isinstance(compression_config, dict) and (
+        QUANTIZATION_CONFIG_NAME in compression_config
+    ):
+        return compression_config[QUANTIZATION_CONFIG_NAME]
 
     if QUANTIZATION_CONFIG_NAME in compression_config:
         # for loaded HFQuantizer config from dict
diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py
index 9dd7b22d..d2c5f2cd 100644
--- a/src/compressed_tensors/utils/offload.py
+++ b/src/compressed_tensors/utils/offload.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import torch
-from torch.nn import Module
+from torch.nn import Module, Parameter
 
 
 __all__ = [
@@ -106,7 +106,19 @@ def update_parameter_data(
         raise ValueError("Attempted to update uninitialized parameter")
 
     dtype = parameter.dtype
-    parameter.data = new_param_data.to(device).to(dtype)
+    try:
+        parameter.data = new_param_data.to(device).to(dtype)
+    except RuntimeError:
+        # exception may occur when trying to overwrite meta device, overriding
+        # parameter directly
+        setattr(
+            module,
+            param_name,
+            Parameter(
+                data=new_param_data.to(device).to(dtype),
+                requires_grad=parameter.requires_grad,
+            ),
+        )
 
     if offloaded:
         prefix_dict = module._hf_hook.weights_map.dataset
diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py
index 4fdb3007..0592592b 100644
--- a/src/compressed_tensors/utils/safetensors_load.py
+++ b/src/compressed_tensors/utils/safetensors_load.py
@@ -47,6 +47,10 @@ def get_safetensors_folder(
         model will be searched for in the default TRANSFORMERS_CACHE
     :return: local folder containing model data
     """
+    if isinstance(pretrained_model_name_or_path, list):
+        # assume sharded files, referencing first file is sufficient
+        pretrained_model_name_or_path = pretrained_model_name_or_path[0]
+
     if os.path.exists(pretrained_model_name_or_path):
         # argument is a path to a local folder
         return os.path.abspath(pretrained_model_name_or_path)
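
The first hunk lets `parse_quantization_config` accept a plain dict as well as a loaded HFQuantizer config object. A minimal standalone sketch of the dispatch order, assuming `QUANTIZATION_CONFIG_NAME == "quantization_config"` and using a hypothetical stand-in config class (neither is taken from the diff itself):

```python
# Standalone sketch of the dispatch order in parse_quantization_config.
# QUANTIZATION_CONFIG_NAME's value and FakeQuantizerConfig are assumptions
# for illustration; the real constant lives in compressed_tensors.
QUANTIZATION_CONFIG_NAME = "quantization_config"


class FakeQuantizerConfig:
    # mimics a loaded HFQuantizer config object exposing an attribute
    quantization_config = {"format": "dense"}


def parse(compression_config):
    if hasattr(compression_config, QUANTIZATION_CONFIG_NAME):
        # config object: read the attribute
        return getattr(compression_config, QUANTIZATION_CONFIG_NAME)
    elif isinstance(compression_config, dict) and (
        QUANTIZATION_CONFIG_NAME in compression_config
    ):
        # plain dict (e.g. parsed from a config.json): read the key
        return compression_config[QUANTIZATION_CONFIG_NAME]
    return None


assert parse(FakeQuantizerConfig()) == {"format": "dense"}
assert parse({QUANTIZATION_CONFIG_NAME: {"format": "dense"}}) == {"format": "dense"}
```

Because the new `elif` returns early, dict inputs carrying the key never reach the pre-existing `if QUANTIZATION_CONFIG_NAME in compression_config:` membership test further down in the function.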
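
The offload hunk guards the in-place `.data` assignment in `update_parameter_data`: per the diff's own comment, the assignment can raise a `RuntimeError` when the target parameter still lives on the meta device, so the fallback re-registers a fresh `Parameter` instead of mutating the old one. A minimal sketch of just the fallback path, built around a `torch.nn.Linear` constructed on the meta device (the module and tensor here are illustrative):

```python
import torch
from torch.nn import Linear, Parameter

# A Linear whose weight is a meta tensor: it carries shape and dtype but no
# storage, so it cannot simply be filled with real data in place.
module = Linear(4, 4, device="meta")
materialized = torch.zeros(4, 4)

# The fallback from the diff: replace the Parameter object outright rather
# than assigning to parameter.data. nn.Module.__setattr__ re-registers the
# new Parameter under the same name, preserving requires_grad.
setattr(
    module,
    "weight",
    Parameter(data=materialized, requires_grad=module.weight.requires_grad),
)

print(module.weight.device)  # cpu: the meta placeholder has been swapped out
```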
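
The safetensors hunk makes `get_safetensors_folder` tolerate a list of shard paths. A small sketch of the assumption behind it, with hypothetical file names: every shard of a checkpoint lives in the same folder, so the first entry is enough to resolve it:

```python
import os

# Hypothetical shard list for a two-file checkpoint; the paths are illustrative.
pretrained_model_name_or_path = [
    "/tmp/model/model-00001-of-00002.safetensors",
    "/tmp/model/model-00002-of-00002.safetensors",
]

if isinstance(pretrained_model_name_or_path, list):
    # assume sharded files; all shards share a directory, so the first suffices
    pretrained_model_name_or_path = pretrained_model_name_or_path[0]

print(os.path.dirname(pretrained_model_name_or_path))  # /tmp/model
```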