Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for checkpoint rename race condition #28364

Merged

Conversation

tblattner
Copy link
Contributor

What does this PR do?

When running distributed training with deepspeed, I was encountered a race condition due to os.rename not being atomic on network filesystems. This rework, changes the logic for renaming to only run on the main processes, or a single main process depending on the save_on_each_node flag. Also added is the use of fsync to try to flush buffers, hopefully ensuring the rename is completed. fsync may have no effect in some filesystems, so a better mechanism may be required to ensure that the rename completed.

Fixes #27925

Before submitting

  • [No] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [Yes] Did you read the contributor guideline,
    Pull Request section?
  • [Discussed on Github issue] Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • [Yes] Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • [No] Did you write any new necessary tests?

Who can review?

@muellerzr
@pacman100

@tblattner
Copy link
Contributor Author

One thing to note, I attempted to reuse the existing with block:
with self.args.main_process_first(desc="Renaming model checkpoint folder to true location", local=self.args.save_on_each_node):

and include the fsync. Unfortunately fsync did not flush buffers related to the staging directory, so it still failed on other processes. This raises some concerns as to the behavior of fsync on network attached storage using NFS.

@siddartha-RE
Copy link
Contributor

Oops, missed this. I looked yesterday :) but I guess you poster after I looked. This is my version:
#28373

I don't see why existence check is required if it is happening only once per node.

@tblattner
Copy link
Contributor Author

Oops, missed this. I looked yesterday :) but I guess you poster after I looked. This is my version: #28373

I don't see why existence check is required if it is happening only once per node.

Could be a race if the output directory for the checkpoint is used sometime later in the code. If that is not the case, then shouldn't be an issue.

Copy link
Contributor

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi all, I think the best solution here is a mix of both. @siddartha-RE your approach of using existing functions in the state is better than Accelerate here for readability and focusing it on the Trainer state, and @tblattner using fsync here I think is more robust and better.

I propose that @siddartha-RE can you make your PR just make the modifications to trainer_callback.py and we can handle the OS issue in trainer in this PR?

Thank you both so much for your wonderful solutions and PR's.

src/transformers/trainer.py Outdated Show resolved Hide resolved
Comment on lines 2402 to 2403
if self.args.should_save:
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also move this to be under this if/else so that way we can have it be done just on a single process as needed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Above the wait_for_everyone)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually have a related question. I noticed that the current code has push_to_hub unguarded by any check for is_local / is_world zero. This seems incorrect. However, I don't use that option so I didn't want to touch it without understanding implications.

My guess is that it would have been best to do push_to_hub after wait_for_everyone from the final output_dir. Otherwise it seems like the push could end up shipping partially written state.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reworked the logic for this so that we include the rotate_checkpoints function. Also curious about push_to_hub and the save_to_json that is above line 2384.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed thinking more on it, it should be after for the reasons you mentioned

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, push_to_hub checks is_world_process_zero(), so it's fine :)

tblattner and others added 5 commits January 9, 2024 00:10
…o only operate with the main process.

Added fsync functionality to attempt to flush the write changes in case os.rename is not atomic.
Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Added with open usage to ensure better file closing as suggested from PR
Added rotate_checkpoints into main process logic
@tblattner tblattner force-pushed the fix-checkpoint-rename-race-condition branch from a9fe43c to eb58698 Compare January 9, 2024 05:10
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Contributor

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the great fix!

@muellerzr muellerzr requested a review from ArthurZucker January 9, 2024 09:33
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM we no longer get a warning saying that the model checkpoint is renamed but not an issue IMO

@ArthurZucker ArthurZucker merged commit cef2e40 into huggingface:main Jan 10, 2024
21 checks passed
@ArthurZucker
Copy link
Collaborator

Thanks @tblattner 🤗

@xiaojunjie
Copy link

  • if self.is_local_process_zero() if self.args.save_on_each_node else self.is_world_process_zero():
  • if self.args.should_save
    May I ask What is the difference between the two lines above?Can replace Line1 with Line2

staghado pushed a commit to staghado/transformers that referenced this pull request Jan 15, 2024
* Changed logic for renaming staging directory when saving checkpoint to only operate with the main process.
Added fsync functionality to attempt to flush the write changes in case os.rename is not atomic.

* Updated styling using make fixup

* Updated check for main process to use built-in versions from trainer

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* Fixed incorrect usage of trainer main process checks
Added with open usage to ensure better file closing as suggested from PR
Added rotate_checkpoints into main process logic

* Removed "with open" due to not working with directory. os.open seems to work for directories.

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
@siddartha-RE
Copy link
Contributor

  • if self.is_local_process_zero() if self.args.save_on_each_node else self.is_world_process_zero():
  • if self.args.should_save
    May I ask What is the difference between the two lines above?Can replace Line1 with Line2

Definition of should_save looks like they would be equivalent and the second one a little clearer. It will also allow the extra check below for should_save to be removed.

MadElf1337 pushed a commit to MadElf1337/transformers that referenced this pull request Jan 15, 2024
* Changed logic for renaming staging directory when saving checkpoint to only operate with the main process.
Added fsync functionality to attempt to flush the write changes in case os.rename is not atomic.

* Updated styling using make fixup

* Updated check for main process to use built-in versions from trainer

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* Fixed incorrect usage of trainer main process checks
Added with open usage to ensure better file closing as suggested from PR
Added rotate_checkpoints into main process logic

* Removed "with open" due to not working with directory. os.open seems to work for directories.

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
@Vechtomov
Copy link

Looks like fd = os.open(output_dir, os.O_RDONLY) doesn't work on Windows:
image

wgifford pushed a commit to wgifford/transformers that referenced this pull request Jan 21, 2024
* Changed logic for renaming staging directory when saving checkpoint to only operate with the main process.
Added fsync functionality to attempt to flush the write changes in case os.rename is not atomic.

* Updated styling using make fixup

* Updated check for main process to use built-in versions from trainer

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* Fixed incorrect usage of trainer main process checks
Added with open usage to ensure better file closing as suggested from PR
Added rotate_checkpoints into main process logic

* Removed "with open" due to not working with directory. os.open seems to work for directories.

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
@tblattner
Copy link
Contributor Author

Looks like fd = os.open(output_dir, os.O_RDONLY) doesn't work on Windows: image

This is indeed an issue. Windows handles directories differently than Linux. I'm not an expert at Windows dev in python, so I don't have a good solution sorry!

@Vechtomov
Copy link

I think this is critical, when your training fails at the first checkpoint. Maybe as a workaround we can add a condition for non-Windows systems?

if platform.system() != 'Windows':
    fd = os.open(output_dir, os.O_RDONLY)
    os.fsync(fd)
    os.close(fd)

@muellerzr @siddartha-RE

@muellerzr
Copy link
Contributor

Yes, quick PR going in a moment.

AjayP13 pushed a commit to AjayP13/transformers that referenced this pull request Jan 22, 2024
* Changed logic for renaming staging directory when saving checkpoint to only operate with the main process.
Added fsync functionality to attempt to flush the write changes in case os.rename is not atomic.

* Updated styling using make fixup

* Updated check for main process to use built-in versions from trainer

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* Fixed incorrect usage of trainer main process checks
Added with open usage to ensure better file closing as suggested from PR
Added rotate_checkpoints into main process logic

* Removed "with open" due to not working with directory. os.open seems to work for directories.

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
@muellerzr
Copy link
Contributor

@Vechtomov @tblattner #28637 fixed it.

Unsure how it affects multinode on windows but if a user has this situation and hits it then we can deal with it then as there's not really a clean solution for doing so in python :(

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Save model checkpoint error when multi-gpu training
7 participants