Fault-tolerant checklist #9130
Comments
Hey @ananthsub, as fault-tolerant training is quite complex, would you like someone on your side to be involved too? Best,
Do we have a design document for the proposed work?
@awaelchli @tchaton - we should cover metrics here as well. I am not clear on the assumptions for where/when metric states are synced, whether they're synced to rank 0 before saving, and whether we assume that we save the checkpoint from rank 0 only. This does not hold for DeepSpeed or other use cases where we save a part of the LightningModule state dict across ranks, does it?
Hey @aazzolini, I entirely agree, this is definitely very error prone! We are working on this with @awaelchli and will share it with you ASAP.
Hey @ananthsub, we haven't investigated the impact of sharding on TorchMetrics yet and only added fault tolerance for non-sharded models in our fault-tolerant V0. My intuition is that this assumption should still hold. Before saving, we accumulate the states on all ranks, so each rank contains the accumulated states. On reload, we just reset the metric on non-rank-0 processes. But it might become more intricate if metric states get sharded, which might result in state collisions. Best,
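For illustration, here is a minimal sketch of the accumulate-before-save / reset-on-reload approach described above (not Lightning's actual implementation; the helper names and the assumption of a purely additive metric state are illustrative):

```python
import torch
import torch.distributed as dist


def sync_metric_state_before_save(metric_state: torch.Tensor) -> torch.Tensor:
    """Accumulate an additive metric state across all ranks before checkpointing."""
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(metric_state, op=dist.ReduceOp.SUM)
    return metric_state  # every rank now holds the accumulated state


def reset_metric_state_after_load(metric_state: torch.Tensor) -> torch.Tensor:
    """On reload, keep the accumulated state on rank 0 only and reset it elsewhere."""
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        metric_state.zero_()  # avoid double counting on the next sync
    return metric_state
```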
🚀 Fault-tolerant training progress tracker
Fault-tolerant training in PL can be activated by setting the PL_FAULT_TOLERANT_TRAINING=1 environment variable.
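For reference, a minimal sketch of how opting in could look from user code, assuming the PL_FAULT_TOLERANT_TRAINING environment variable exposed by the experimental feature:

```python
import os

from pytorch_lightning import Trainer

# Opt in before the Trainer starts; "1" enables the experimental fault-tolerant mode.
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"

trainer = Trainer(max_epochs=10)
# trainer.fit(model, datamodule=dm)  # on failure, rerunning the script resumes from the auto-saved state
```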
Description of the architecture around fault-tolerant training
progress tracking in loops
progress restart of loops
capturing state of iterable datasets
capturing state of map-style datasets (a rough sketch of the general idea follows the TODO links below)
mid epoch restart
failure on mid training epoch loop
failure on mid validation epoch loop: [Feat] Add Fault Tolerant Training for ValidationLoop. #9563, Support skipping to validation #9681
epoch end restart
reloading multiple train dataloaders with individual random states (see add fault-tolerance for global random state in map-style datasets #8950 for todos in added tests)
Fault-tolerant training with multiple dataloaders (Fault tolerant training with multiple dataloaders #11349)
parametrize the important tests in test_auto_restart.py with num_workers > 0 and run as part of slow tests (Slow CI #9086)
end-user guide - how to use and description of limitations ([doc] Add Fault Tolerant Documentation Page #9256)
benchmark iterations/s for small and large models with fault-tolerant training enabled vs. disabled
add test case for different batch structures (dict, list, etc.) fix state extraction from batch when fault-tolerant training #9281
fix progress bar tracking on restart fix progress bar restart with fault-tolerant training enabled #9310
Add logic to resume OptimizerLoop? Is this needed/wanted?
Fix num_workers > 0 causing repeated random state when resuming
Fix TODOs
https://github.com/PyTorchLightning/pytorch-lightning/blob/ce00053002a1bb5385f7e44fceea97af50313d4c/tests/loops/test_loops.py#L899-L900
https://github.com/PyTorchLightning/pytorch-lightning/blob/5841ca97825bd9786ab84d70f0abfa6e673528b4/tests/trainer/connectors/test_signal_connector.py#L27
https://github.com/PyTorchLightning/pytorch-lightning/blob/5841ca97825bd9786ab84d70f0abfa6e673528b4/tests/utilities/test_auto_restart.py#L1181-L1183
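As referenced in the map-style dataset item above, a rough sketch of the general idea behind capturing dataset/sampler state: snapshot the global RNG state from the start of the epoch plus the number of batches already consumed, then fast-forward on restart. The helper names below are illustrative, not the actual Lightning internals:

```python
import torch
from torch.utils.data import RandomSampler


def capture_sampler_state(epoch_rng_state: torch.Tensor, num_batches_fetched: int) -> dict:
    """Snapshot taken at failure time: the global RNG state recorded at the start
    of the epoch plus how many batches were already consumed."""
    return {"rng_state": epoch_rng_state, "num_batches_fetched": num_batches_fetched}


def restore_sampler(state: dict, dataset, batch_size: int):
    """Rebuild the sampler under the original epoch RNG, then skip what was consumed."""
    torch.set_rng_state(state["rng_state"])
    sampler_iter = iter(RandomSampler(dataset))  # draws its seed from the restored global RNG
    for _ in range(state["num_batches_fetched"] * batch_size):
        next(sampler_iter)  # same permutation as the original run, fast-forwarded
    return sampler_iter
```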
Open questions:
Should we capture torch.cuda.get_rng_state too?
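If the answer turns out to be yes, capturing the CUDA state alongside the CPU state could look roughly like this (public torch APIs only; the dict layout is illustrative):

```python
import torch


def collect_rng_states() -> dict:
    states = {"torch": torch.get_rng_state()}
    if torch.cuda.is_available():
        states["torch.cuda"] = torch.cuda.get_rng_state_all()  # one state per visible device
    return states


def set_rng_states(states: dict) -> None:
    torch.set_rng_state(states["torch"])
    if torch.cuda.is_available() and "torch.cuda" in states:
        torch.cuda.set_rng_state_all(states["torch.cuda"])
```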
Supplementary documents:
State of fault-tolerant training in PyTorch Lightning
cc @Borda @carmocca @justusschock @awaelchli @ninginthecloud