We recently switched to leveraging Safetensors by default for the PyTorchModelHubMixin class in huggingface_hub (huggingface/huggingface_hub#2033), which is a minimal class that adds from_pretrained and push_to_hub methods to any custom nn.Module.
However, when trying out this class on the Gemma series of models by Google, I get the following error when calling push_to_hub (which first saves the tensors in the safetensors format before uploading the files to the hub):
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
[<ipython-input-8-eac8a21155c9>](https://localhost:8080/#) in <cell line: 1>()
----> 1 model.push_to_hub(f"nielsr/gemma-2b-it")
8 frames
[/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_deprecation.py](https://localhost:8080/#) in inner_f(*args, **kwargs)
99 message += "\n\n" + custom_message
100 warnings.warn(message, FutureWarning)
--> 101 return f(*args, **kwargs)
102
103 return inner_f
[/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_validators.py](https://localhost:8080/#) in _inner_fn(*args, **kwargs)
117 kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
118
--> 119 return fn(*args, **kwargs)
120
121 return _inner_fn # type: ignore
[/usr/local/lib/python3.10/dist-packages/huggingface_hub/hub_mixin.py](https://localhost:8080/#) in push_to_hub(self, repo_id, config, commit_message, private, token, branch, create_pr, allow_patterns, ignore_patterns, delete_patterns, api_endpoint)
517 with SoftTemporaryDirectory() as tmp:
518 saved_path = Path(tmp) / repo_id
--> 519 self.save_pretrained(saved_path, config=config)
520 return api.upload_folder(
521 repo_id=repo_id,
[/usr/local/lib/python3.10/dist-packages/huggingface_hub/hub_mixin.py](https://localhost:8080/#) in save_pretrained(self, save_directory, config, repo_id, push_to_hub, **push_to_hub_kwargs)
247
248 # save model weights/files (framework-specific)
--> 249 self._save_pretrained(save_directory)
250
251 # save config (if provided)
[/usr/local/lib/python3.10/dist-packages/huggingface_hub/hub_mixin.py](https://localhost:8080/#) in _save_pretrained(self, save_directory)
590 """Save weights from a Pytorch model to a local directory."""
591 model_to_save = self.module if hasattr(self, "module") else self # type: ignore
--> 592 save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
593
594 @classmethod
[/usr/local/lib/python3.10/dist-packages/safetensors/torch.py](https://localhost:8080/#) in save_model(model, filename, metadata, force_contiguous)
153 """
154 state_dict = model.state_dict()
--> 155 to_removes = _remove_duplicate_names(state_dict)
156
157 for kept_name, to_remove_group in to_removes.items():
[/usr/local/lib/python3.10/dist-packages/safetensors/torch.py](https://localhost:8080/#) in _remove_duplicate_names(state_dict, preferred_names, discard_names)
98 to_remove = defaultdict(list)
99 for shared in shareds:
--> 100 complete_names = set([name for name in shared if _is_complete(state_dict[name])])
101 if not complete_names:
102 raise RuntimeError(
[/usr/local/lib/python3.10/dist-packages/safetensors/torch.py](https://localhost:8080/#) in <listcomp>(.0)
98 to_remove = defaultdict(list)
99 for shared in shareds:
--> 100 complete_names = set([name for name in shared if _is_complete(state_dict[name])])
101 if not complete_names:
102 raise RuntimeError(
[/usr/local/lib/python3.10/dist-packages/safetensors/torch.py](https://localhost:8080/#) in _is_complete(tensor)
79
80 def _is_complete(tensor: torch.Tensor) -> bool:
---> 81 return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _SIZE[tensor.dtype] == storage_size(tensor)
82
83
KeyError: torch.complex64
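The last frame suggests the failure is a plain dictionary lookup: `_SIZE[tensor.dtype]` has no entry for `torch.complex64`. The sketch below mimics that pattern with a hypothetical table keyed by dtype names as strings. `SIZE_TABLE`, `expected_bytes`, and `try_expected_bytes` are illustrative names, not part of safetensors, and the table entries are assumptions rather than a copy of the library's.

```python
# Hypothetical stand-in for a dtype -> bytes-per-element table.
# The real table lives in safetensors/torch.py; these entries are illustrative.
SIZE_TABLE = {
    "float32": 4,
    "float16": 2,
    "int64": 8,
}


def expected_bytes(dtype_name: str, num_elements: int) -> int:
    """Return the storage size implied by a dtype and an element count."""
    # A dtype missing from the table raises KeyError, as in the last frame above.
    return num_elements * SIZE_TABLE[dtype_name]


def try_expected_bytes(dtype_name: str, num_elements: int) -> str:
    """Describe the outcome instead of letting the KeyError propagate."""
    try:
        return f"{expected_bytes(dtype_name, num_elements)} bytes"
    except KeyError as err:
        return f"KeyError: {err.args[0]}"
```

Under this reading, any dtype missing from the table fails the same way, before any data is written.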
System Info
safetensors v0.4.2
huggingface_hub v0.22.0.dev0
Reproduction
Call push_to_hub on a Gemma model that inherits from PyTorchModelHubMixin; saving the weights in the safetensors format fails with the KeyError shown above. Here's a notebook for reproduction.
Expected behavior
This model has some tensors of type torch.complex64; it would be great to be able to save those.
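Until complex dtypes can be saved directly, one possible workaround is to store each complex tensor as a real tensor with a trailing dimension of size 2 and rebuild it after loading. The sketch below uses only `torch.view_as_real` and `torch.view_as_complex`; the helper names `split_complex` and `merge_complex` are hypothetical, and wiring them into `PyTorchModelHubMixin` or safetensors is not shown.

```python
import torch


def split_complex(state_dict):
    """Replace each complex tensor with a real copy of shape (..., 2)."""
    real_dict, complex_keys = {}, []
    for name, tensor in state_dict.items():
        if tensor.is_complex():
            # view_as_real exposes (real, imag) pairs along a new last axis;
            # clone() avoids sharing memory with the original tensor.
            real_dict[name] = torch.view_as_real(tensor).clone()
            complex_keys.append(name)
        else:
            real_dict[name] = tensor
    return real_dict, complex_keys


def merge_complex(real_dict, complex_keys):
    """Rebuild complex tensors from their (..., 2) real representation."""
    restored = dict(real_dict)
    for name in complex_keys:
        # view_as_complex needs a last dimension of size 2 with stride 1.
        restored[name] = torch.view_as_complex(real_dict[name].contiguous())
    return restored
```

The list of converted keys has to be kept alongside the weights (for example in a config file) so the loader knows which tensors to convert back.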