Revert "Always enable safetensors for save_quantized() (huggingface#20)"
This reverts commit 0d1be56.
Qubitium committed Jun 16, 2024
1 parent c69ed3e commit 5dd4930
Showing 1 changed file with 54 additions and 32 deletions.
86 changes: 54 additions & 32 deletions auto_gptq/modeling/_base.py
@@ -456,6 +456,7 @@ def save_quantized(
        save_dir: str,
        safetensors_metadata: Optional[Dict[str, str]] = None,
        format: Optional[str] = None,
+       use_safetensors: bool = True,
    ):
        """save quantized model and configs to local disk"""
        os.makedirs(save_dir, exist_ok=True)
@@ -522,46 +523,53 @@ def save_quantized(
        model.to(CPU)

        if quantize_config.model_file_base_name is None:
-           model_base_name = "model"
+           if use_safetensors:
+               model_base_name = "model"
+           else:
+               model_base_name = "pytorch_model"
        else:
            model_base_name = quantize_config.model_file_base_name

-       model_save_name = model_base_name + ".safetensors"
-       state_dict = model.state_dict()
-       state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
-       if safetensors_metadata is None:
-           safetensors_metadata = {}
-       elif not isinstance(safetensors_metadata, dict):
-           raise TypeError("safetensors_metadata must be a dictionary.")
-       else:
-           logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
-           new_safetensors_metadata = {}
-           converted_keys = False
-           for key, value in safetensors_metadata.items():
-               if not isinstance(key, str) or not isinstance(value, str):
-                   converted_keys = True
-               try:
-                   new_key = str(key)
-                   new_value = str(value)
-               except Exception as e:
-                   raise TypeError(
-                       f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}"
-                   )
-               if new_key in new_safetensors_metadata:
-                   logger.warning(
-                       f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting."
-                   )
-               new_safetensors_metadata[new_key] = new_value
-           safetensors_metadata = new_safetensors_metadata
-           if converted_keys:
-               logger.debug(
-                   f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}"
-               )
+       if use_safetensors:
+           model_save_name = model_base_name + ".safetensors"
+           state_dict = model.state_dict()
+           state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
+           if safetensors_metadata is None:
+               safetensors_metadata = {}
+           elif not isinstance(safetensors_metadata, dict):
+               raise TypeError("safetensors_metadata must be a dictionary.")
+           else:
+               logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
+               new_safetensors_metadata = {}
+               converted_keys = False
+               for key, value in safetensors_metadata.items():
+                   if not isinstance(key, str) or not isinstance(value, str):
+                       converted_keys = True
+                   try:
+                       new_key = str(key)
+                       new_value = str(value)
+                   except Exception as e:
+                       raise TypeError(
+                           f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}"
+                       )
+                   if new_key in new_safetensors_metadata:
+                       logger.warning(
+                           f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting."
+                       )
+                   new_safetensors_metadata[new_key] = new_value
+               safetensors_metadata = new_safetensors_metadata
+               if converted_keys:
+                   logger.debug(
+                       f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}"
+                   )
+
+           # Format is required to enable Accelerate to load the metadata
+           # otherwise it raises an OSError
+           safetensors_metadata["format"] = "pt"
+           safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
+       else:
+           model_save_name = model_base_name + ".bin"
+           torch.save(model.state_dict(), join(save_dir, model_save_name))

config.quantization_config = quantize_config.to_dict()
config.save_pretrained(save_dir)
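
For context, the restored flag chooses between the two save paths in the hunk above. A minimal usage sketch, assuming a model quantized with AutoGPTQ in the same session (the model id, output paths, and the `examples` list are illustrative, not part of this commit):

    from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

    # Illustrative setup: load and quantize a small model.
    quantize_config = BaseQuantizeConfig(bits=4, group_size=128)
    model = AutoGPTQForCausalLM.from_pretrained("facebook/opt-125m", quantize_config)
    model.quantize(examples)  # `examples`: tokenized calibration samples prepared elsewhere

    # With model_file_base_name unset, this writes model.safetensors via safe_save():
    model.save_quantized("opt-125m-4bit", use_safetensors=True)
    # With use_safetensors=False, it writes pytorch_model.bin via torch.save():
    model.save_quantized("opt-125m-4bit-bin", use_safetensors=False)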
@@ -1048,13 +1056,27 @@ def skip(*args, **kwargs):

        # Any post-initialization that require device information, for example buffers initialization on device.
        model = autogptq_post_init(model, use_act_order=quantize_config.desc_act)
+
+       model.eval()

        # == step6: (optional) warmup triton == #
        if use_triton and warmup_triton:
            from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear
            QuantLinear.warmup(model, seqlen=model.seqlen)

+
+       # == step7: make model compatible with peft
+       # cls.make_sure_compatible_with_peft(
+       #     model,
+       #     use_triton,
+       #     quantize_config.desc_act,
+       #     quantize_config.group_size,
+       #     bits=quantize_config.bits,
+       #     disable_exllama=disable_exllama,
+       #     disable_exllamav2=disable_exllamav2,
+       #     use_marlin=use_marlin,
+       # )

        return cls(
            model,
            True,
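
Step 6 above eagerly compiles the Triton kernels so the first inference call does not pay the JIT cost. A hedged sketch of how a caller reaches this path (argument names as exposed by AutoGPTQ's from_quantized; the path is illustrative):

    # Illustrative: warmup only runs when both flags are set.
    model = AutoGPTQForCausalLM.from_quantized(
        "opt-125m-4bit",
        use_triton=True,
        warmup_triton=True,  # triggers QuantLinear.warmup(model, seqlen=model.seqlen)
    )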
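Two details of the restored safetensors branch are easy to miss: non-string metadata keys and values are coerced with str() (with a warning when coerced keys collide), and "format": "pt" is always stamped so Accelerate can load the file without an OSError. A small read-back sketch (hypothetical path; safe_open is the standard safetensors reader):

    from safetensors import safe_open

    # A non-string key such as 42 is stored as the string "42";
    # "format" is always present because save_quantized() adds it before safe_save().
    with safe_open("opt-125m-4bit/model.safetensors", framework="pt") as f:
        metadata = f.metadata()
    assert metadata["format"] == "pt"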
