Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serialization] Add is_main_process argument to save_torch_state_dict() #2648

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions src/huggingface_hub/serialization/_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def save_torch_model(
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
metadata: Optional[Dict[str, str]] = None,
safe_serialization: bool = True,
is_main_process: bool = True,
):
"""
Saves a given torch model to disk, handling sharding and shared tensors issues.
Expand Down Expand Up @@ -88,6 +89,10 @@ def save_torch_model(
Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
in a future version.
is_main_process (`bool`, *optional*):
Whether the process calling this is the main process or not. Useful when in distributed training like
TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
the main process to avoid race conditions. Defaults to True.

Example:

Expand All @@ -112,6 +117,7 @@ def save_torch_model(
metadata=metadata,
safe_serialization=safe_serialization,
save_directory=save_directory,
is_main_process=is_main_process,
)


Expand All @@ -124,6 +130,7 @@ def save_torch_state_dict(
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
metadata: Optional[Dict[str, str]] = None,
safe_serialization: bool = True,
is_main_process: bool = True,
) -> None:
"""
Save a model state dictionary to the disk, handling sharding and shared tensors issues.
Expand Down Expand Up @@ -171,7 +178,10 @@ def save_torch_state_dict(
Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
in a future version.

is_main_process (`bool`, *optional*):
Whether the process calling this is the main process or not. Useful when in distributed training like
TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
the main process to avoid race conditions. Defaults to True.
Example:

```py
Expand Down Expand Up @@ -222,15 +232,18 @@ def save_torch_state_dict(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
)

# Clean the folder from previous save
existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
for filename in os.listdir(save_directory):
if existing_files_regex.match(filename):
try:
logger.debug(f"Removing existing file '{filename}' from folder.")
os.remove(os.path.join(save_directory, filename))
except Exception as e:
logger.warning(f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing...")
# Only main process should clean up existing files to avoid race conditions in distributed environment
if is_main_process:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
for filename in os.listdir(save_directory):
if existing_files_regex.match(filename):
try:
logger.debug(f"Removing existing file '{filename}' from folder.")
os.remove(os.path.join(save_directory, filename))
except Exception as e:
logger.warning(
f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing..."
)

# Save each shard
per_file_metadata = {"format": "pt"}
Expand Down Expand Up @@ -442,7 +455,7 @@ def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

if is_traceable_wrapper_subclass(tensor):
return _get_unique_id(tensor)
return _get_unique_id(tensor) # type: ignore
except ImportError:
# for torch version less than 2.1, we can fallback to original implementation
pass
Expand Down
26 changes: 26 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def test_save_torch_model(mocker: MockerFixture, tmp_path: Path) -> None:
max_shard_size="3GB",
metadata={"foo": "bar"},
safe_serialization=True,
is_main_process=True,
)
safe_state_dict_mock.assert_called_once_with(
state_dict=model_mock.state_dict.return_value,
Expand All @@ -273,6 +274,7 @@ def test_save_torch_model(mocker: MockerFixture, tmp_path: Path) -> None:
max_shard_size="3GB",
metadata={"foo": "bar"},
safe_serialization=True,
is_main_process=True,
)


Expand Down Expand Up @@ -472,3 +474,27 @@ def test_save_torch_state_dict_delete_existing_files(
assert (tmp_path / "pytorch_model-00001-of-00003.bin").is_file()
assert (tmp_path / "pytorch_model-00002-of-00003.bin").is_file()
assert (tmp_path / "pytorch_model-00003-of-00003.bin").is_file()


def test_save_torch_state_dict_not_main_process(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice test!

tmp_path: Path,
torch_state_dict: Dict[str, "torch.Tensor"],
) -> None:
"""
Test that previous files in the directory are not deleted when is_main_process=False.
When is_main_process=True, previous files should be deleted,
this is already tested in `test_save_torch_state_dict_delete_existing_files`.
"""
# Create some .safetensors files before saving a new state dict.
(tmp_path / "model.safetensors").touch()
(tmp_path / "model-00001-of-00002.safetensors").touch()
(tmp_path / "model-00002-of-00002.safetensors").touch()
(tmp_path / "model.safetensors.index.json").touch()
# Save with is_main_process=False
save_torch_state_dict(torch_state_dict, tmp_path, is_main_process=False)

# Previous files should still exist (not deleted)
assert (tmp_path / "model.safetensors").is_file()
assert (tmp_path / "model-00001-of-00002.safetensors").is_file()
assert (tmp_path / "model-00002-of-00002.safetensors").is_file()
assert (tmp_path / "model.safetensors.index.json").is_file()
Loading