add tests + comments + load_model
Wauplin committed Mar 5, 2024
1 parent 2849871 commit 3cc9f71
Showing 3 changed files with 41 additions and 10 deletions.
docs/source/en/guides/download.md (2 changes: 1 addition & 1 deletion)
@@ -144,7 +144,7 @@ repo. For example if `filename="data/train.csv"` and `local_dir="path/to/folder"
 - If `local_dir_use_symlinks=True` is set, all files are symlinked for an optimal disk space optimization. This is
   for example useful when downloading a huge dataset with thousands of small files.
 - Finally, if you don't want symlinks at all you can disable them (`local_dir_use_symlinks=False`). The cache directory
-  will still be used to check wether the file is already cached or not. If already cached, the file is **duplicated**
+  will still be used to check whether the file is already cached or not. If already cached, the file is **duplicated**
   from the cache (i.e. saves bandwidth but increases disk usage). If the file is not already cached, it will be
   downloaded and moved directly to the local dir. This means that if you need to reuse it somewhere else later, it
   will be **re-downloaded**.
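The no-symlink behavior documented above boils down to a single `hf_hub_download` call. A minimal sketch, assuming any public repo (here `bert-base-uncased`) and a writable `./weights` directory:

```python
from huggingface_hub import hf_hub_download

# With local_dir set and symlinks disabled, the file ends up as a real copy
# inside ./weights: duplicated from the cache if already cached, otherwise
# downloaded and moved there directly.
path = hf_hub_download(
    repo_id="bert-base-uncased",
    filename="config.json",
    local_dir="./weights",
    local_dir_use_symlinks=False,
)
print(path)  # e.g. weights/config.json
```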
src/huggingface_hub/hub_mixin.py (19 changes: 10 additions & 9 deletions)
@@ -27,8 +27,8 @@
     import torch  # type: ignore
 
 if is_safetensors_available():
-    from safetensors import safe_open
-    from safetensors.torch import save_model
+    from safetensors.torch import load_model as load_model_as_safetensor
+    from safetensors.torch import save_model as save_model_as_safetensor
 
 
 logger = logging.get_logger(__name__)
@@ -463,7 +463,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
     def _save_pretrained(self, save_directory: Path) -> None:
         """Save weights from a Pytorch model to a local directory."""
         model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
-        save_model(model_to_save, save_directory / SAFETENSORS_SINGLE_FILE)
+        save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
 
     @classmethod
     def _from_pretrained(
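Both sides of this change rely on safetensors' `save_model` (here aliased as `save_model_as_safetensor`); it matters because the lower-level `save_file` rejects tensors that share memory, while `save_model` deduplicates them before writing. A small sketch of that difference, assuming a hypothetical `Tied` module whose two attributes point at the same `nn.Linear`:

```python
import torch.nn as nn
from safetensors.torch import save_file, save_model

class Tied(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(4, 4)
        self.b = self.a  # self.b shares a.weight and a.bias

    def forward(self, x):
        return self.b(self.a(x))

model = Tied()

# save_file operates on a raw state_dict and refuses shared storage.
try:
    save_file(model.state_dict(), "tied.safetensors")
except RuntimeError as err:
    print("save_file rejected shared tensors:", err)

# save_model drops the duplicate names and writes a single copy of the data.
save_model(model, "tied.safetensors")
```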
@@ -524,10 +524,11 @@ def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: b
 
     @classmethod
     def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
-        state_dict = {}
-        with safe_open(model_file, framework="pt", device=map_location) as f:  # type: ignore [attr-defined]
-            for k in f.keys():
-                state_dict[k] = f.get_tensor(k)
-        model.load_state_dict(state_dict, strict=strict)  # type: ignore
-        model.eval()  # type: ignore
+        if map_location != "cpu":
+            logger.warning(
+                f"Loading model weights on '{map_location}' is not supported by `PytorchHubMixin`."
+                " Loading on CPU instead. Loading on other devices is planned to be supported in future releases."
+                " See https://github.com/huggingface/huggingface_hub/pull/2086 for more details."
+            )
+        load_model_as_safetensor(model, model_file, strict=strict)  # type: ignore [arg-type]
         return model
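With this hunk, safetensors weights are always materialized on CPU and a non-CPU `map_location` only triggers the warning above; callers who need another device move the model after loading. A usage sketch, assuming a hypothetical `MyModel` defined with the mixin and saved to a local `./my-model` directory:

```python
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class MyModel(nn.Module, PyTorchModelHubMixin):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

# Round-trip through a local directory: _save_pretrained writes model.safetensors,
# _load_as_safetensor reads it back with load_model_as_safetensor.
MyModel().save_pretrained("./my-model")
reloaded = MyModel.from_pretrained("./my-model")  # weights loaded on CPU

# Loading directly onto another device is not supported yet, so move afterwards.
if torch.cuda.is_available():
    reloaded = reloaded.to("cuda")
```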
tests/test_hub_mixin_pytorch.py (30 changes: 30 additions & 0 deletions)
@@ -234,3 +234,33 @@ def test_push_to_hub(self):
 
         # Delete repo
         self._api.delete_repo(repo_id=repo_id)
+
+    def test_save_model_with_shared_tensors(self):
+        """
+        Regression test for #2086. Shared tensors should be saved correctly.
+        See https://github.com/huggingface/huggingface_hub/pull/2086 for more details.
+        """
+
+        class ModelWithSharedTensors(nn.Module, PyTorchModelHubMixin):
+            def __init__(self):
+                super().__init__()
+                self.a = nn.Linear(100, 100)
+                self.b = self.a
+
+            def forward(self, x):
+                return self.b(self.a(x))
+
+        # Save and reload model
+        model = ModelWithSharedTensors()
+        model.save_pretrained(self.cache_dir)
+        reloaded = ModelWithSharedTensors.from_pretrained(self.cache_dir)
+
+        # Linear layers should share weights and biases in memory
+        state_dict = reloaded.state_dict()
+        a_weight_ptr = state_dict["a.weight"].storage().data_ptr()
+        b_weight_ptr = state_dict["b.weight"].storage().data_ptr()
+        a_bias_ptr = state_dict["a.bias"].storage().data_ptr()
+        b_bias_ptr = state_dict["b.bias"].storage().data_ptr()
+        assert a_weight_ptr == b_weight_ptr
+        assert a_bias_ptr == b_bias_ptr
+