diff --git a/src/huggingface_hub/serialization/_torch.py b/src/huggingface_hub/serialization/_torch.py
index e87e728d23..58777d9947 100644
--- a/src/huggingface_hub/serialization/_torch.py
+++ b/src/huggingface_hub/serialization/_torch.py
@@ -41,6 +41,7 @@ def save_torch_model(
     max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
     metadata: Optional[Dict[str, str]] = None,
     safe_serialization: bool = True,
+    is_main_process: bool = True,
 ):
     """
     Saves a given torch model to disk, handling sharding and shared tensors issues.
@@ -88,6 +89,10 @@ def save_torch_model(
             Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
             Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
             in a future version.
+        is_main_process (`bool`, *optional*):
+            Whether the process calling this is the main process or not. Useful in distributed training (e.g. on
+            TPUs) where this function must be called from all processes. In that case, set `is_main_process=True`
+            only on the main process to avoid race conditions. Defaults to `True`.

     Example:

@@ -112,6 +117,7 @@ metadata=metadata,
         safe_serialization=safe_serialization,
         save_directory=save_directory,
+        is_main_process=is_main_process,
     )


@@ -124,6 +130,7 @@ def save_torch_state_dict(
     max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
     metadata: Optional[Dict[str, str]] = None,
     safe_serialization: bool = True,
+    is_main_process: bool = True,
 ) -> None:
     """
     Save a model state dictionary to the disk, handling sharding and shared tensors issues.
@@ -171,7 +178,10 @@ Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
             Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
             in a future version.
-
+        is_main_process (`bool`, *optional*):
+            Whether the process calling this is the main process or not. Useful in distributed training (e.g. on
+            TPUs) where this function must be called from all processes. In that case, set `is_main_process=True`
+            only on the main process to avoid race conditions. Defaults to `True`.

     Example:

     ```py
@@ -222,15 +232,18 @@ state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
     )

-    # Clean the folder from previous save
-    existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
-    for filename in os.listdir(save_directory):
-        if existing_files_regex.match(filename):
-            try:
-                logger.debug(f"Removing existing file '{filename}' from folder.")
-                os.remove(os.path.join(save_directory, filename))
-            except Exception as e:
-                logger.warning(f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing...")
+    # Only the main process should clean up existing files, to avoid race conditions in a distributed environment.
+    if is_main_process:
+        existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
+        for filename in os.listdir(save_directory):
+            if existing_files_regex.match(filename):
+                try:
+                    logger.debug(f"Removing existing file '{filename}' from folder.")
+                    os.remove(os.path.join(save_directory, filename))
+                except Exception as e:
+                    logger.warning(
+                        f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing..."
+                    )

     # Save each shard
     per_file_metadata = {"format": "pt"}
@@ -442,7 +455,7 @@ def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
         from torch.utils._python_dispatch import is_traceable_wrapper_subclass

         if is_traceable_wrapper_subclass(tensor):
-            return _get_unique_id(tensor)
+            return _get_unique_id(tensor)  # type: ignore
     except ImportError:
         # for torch version less than 2.1, we can fallback to original implementation
         pass
diff --git a/tests/test_serialization.py b/tests/test_serialization.py
index 019ec26f2d..d966bd478a 100644
--- a/tests/test_serialization.py
+++ b/tests/test_serialization.py
@@ -264,6 +264,7 @@ def test_save_torch_model(mocker: MockerFixture, tmp_path: Path) -> None:
         max_shard_size="3GB",
         metadata={"foo": "bar"},
         safe_serialization=True,
+        is_main_process=True,
     )
     safe_state_dict_mock.assert_called_once_with(
         state_dict=model_mock.state_dict.return_value,
@@ -273,6 +274,7 @@ max_shard_size="3GB",
         metadata={"foo": "bar"},
         safe_serialization=True,
+        is_main_process=True,
     )


@@ -472,3 +474,27 @@ def test_save_torch_state_dict_delete_existing_files(
     assert (tmp_path / "pytorch_model-00001-of-00003.bin").is_file()
     assert (tmp_path / "pytorch_model-00002-of-00003.bin").is_file()
     assert (tmp_path / "pytorch_model-00003-of-00003.bin").is_file()
+
+
+def test_save_torch_state_dict_not_main_process(
+    tmp_path: Path,
+    torch_state_dict: Dict[str, "torch.Tensor"],
+) -> None:
+    """
+    Test that existing files in the directory are not deleted when `is_main_process=False`.
+    The `is_main_process=True` case (existing files are deleted) is already covered by
+    `test_save_torch_state_dict_delete_existing_files`.
+    """
+    # Create some .safetensors files before saving a new state dict.
+    (tmp_path / "model.safetensors").touch()
+    (tmp_path / "model-00001-of-00002.safetensors").touch()
+    (tmp_path / "model-00002-of-00002.safetensors").touch()
+    (tmp_path / "model.safetensors.index.json").touch()
+    # Save with is_main_process=False
+    save_torch_state_dict(torch_state_dict, tmp_path, is_main_process=False)
+
+    # Previous files should still exist (not deleted).
+    assert (tmp_path / "model.safetensors").is_file()
+    assert (tmp_path / "model-00001-of-00002.safetensors").is_file()
+    assert (tmp_path / "model-00002-of-00002.safetensors").is_file()
+    assert (tmp_path / "model.safetensors.index.json").is_file()
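For reviewers who want to try the change out, here is a minimal sketch (not part of this PR) of how the new `is_main_process` flag might be wired into a distributed training job. It assumes a `torch.distributed` process group; the `save_checkpoint` helper and the rank-0 convention are illustrative, not library API.

```py
# Hypothetical usage sketch: every process calls the save function, but only
# rank 0 passes is_main_process=True, so only it deletes stale shards from a
# previous save and the other ranks cannot race against that cleanup.
import torch
import torch.distributed as dist

from huggingface_hub import save_torch_state_dict


def save_checkpoint(model: torch.nn.Module, save_directory: str) -> None:
    # Treat the process as "main" if no process group is initialized
    # (single-process run) or if it is rank 0 in the group.
    is_main = not dist.is_initialized() or dist.get_rank() == 0
    save_torch_state_dict(
        model.state_dict(),
        save_directory,
        is_main_process=is_main,
    )
```

Note that, as implemented in this diff, the flag only gates the pre-save cleanup step; every calling process still writes shards, so the caller remains responsible for deciding which processes actually save.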