Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serialization] Add is_main_process argument to save_torch_state_dict() #2648

Merged

Conversation

hanouticelina
Copy link
Contributor

@hanouticelina hanouticelina commented Oct 31, 2024

This small PR adds is_main_process parameter to save_torch_state_dict() to prevent race conditions during distributed environment. This aligns with accelerate's, transformers' and diffusers' implementations and will enable standardization of model saving across these libraries. See #2314 and this internal slack message for more context.

Once this is released in huggingface_hub==0.27.0, PRs will be opened to update accelerate's save_model(), transformers' save_pretrained() and diffusers' save_pretrained() to use save_torch_state_dict() directly.

Main changes:

(Following existing implementations in accelerate, transformers and diffusers)

  • Condition removing the files from the previous save on is_main_process=True to avoid race conditions during distributed environment.
  • Add is_main_process=True as default parameter.
  • Add a simple unit test for that.

cc @muellerzr, @SunMarc and @sayakpaul for visibility.

@hanouticelina hanouticelina requested a review from Wauplin October 31, 2024 16:06
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -472,3 +472,27 @@ def test_save_torch_state_dict_delete_existing_files(
assert (tmp_path / "pytorch_model-00001-of-00003.bin").is_file()
assert (tmp_path / "pytorch_model-00002-of-00003.bin").is_file()
assert (tmp_path / "pytorch_model-00003-of-00003.bin").is_file()


def test_save_torch_state_dict_not_main_process(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice test!

Copy link
Contributor

@Wauplin Wauplin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

Once this is released in huggingface_hub==0.27.0, PRs will be opened to update accelerate's save_model(), transformers' save_pretrained() and diffusers' save_pretrained() to use save_torch_state_dict() directly.

Super exciting!

src/huggingface_hub/serialization/_torch.py Outdated Show resolved Hide resolved
src/huggingface_hub/serialization/_torch.py Outdated Show resolved Hide resolved
except Exception as e:
logger.warning(f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing...")
# Only main process should clean up existing files to avoid race conditions in distributed environment
if is_main_process:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice ! Thanks for adding this !

@hanouticelina
Copy link
Contributor Author

Thanks for the reviews!

@hanouticelina hanouticelina merged commit 0c98fbd into main Nov 5, 2024
17 checks passed
@hanouticelina hanouticelina deleted the add-condition-to-cleaning-previous-saved-state-dict branch November 5, 2024 14:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants