Fully Sharded Data Parallel #3740 (Merged)

Commits (33, all by stephenroller):
a5b4da4  Implement zero2 and zero3
ea9390c  Implement overflow syncing.
378eacc  Tweak log statements.
0ea7d3a  Use free ports rather than random ports
fc3e668  Refactor test_distributed
65ad526  More refactor.
3153dd8  Fixup checkpoints.
82f8b01  Merge branch 'freeport' into fsdp
44fcdfc  Get tests working.
281efd1  GPU only
4146d86  Sigh
5c6755a  Moar.
dc5edc3  Trying to sync grad norms
7e12292  Correctly implement gnorm syncing.
66d53d3  Update comment.
0bb9995  Merge branch 'master' into fsdp
1cb30d1  Try zero3.
5cea3b2  Okay got zero3 working.
490f5d8  Refactor.
31dfeb5  Get FSDP Zero3 working, except during validation.
ded3708  Merge branch 'master' into fsdp
d095f51  Check in missing code. Carve out notimplemented.
f17abb2  Lint.
231e88d  Er.
4a3ce86  Add a test to ensure we keep track of zero3 not working.
98a90b7  Remove debugs, add docstrings, rename variable.
a2f84c1  Silly
e45b149  Merge branch 'master' into fsdp
61b64dc  Reviewer comments.
16374c9  Lint.
074be0a  We disabled zero3 as an option, so don't need the test.
0814c99  Bug caught by Kurt.
c5a82aa  Rofl
```diff
@@ -24,7 +24,6 @@
 """

 import torch
-import random
 import os
 import signal
 import traceback
@@ -55,10 +54,12 @@ def multiprocess_train(
         raise


-def launch_and_train(opt, port):
+def launch_and_train(opt, port=None):
     """
     Perform a fork() to many processes.
     """
+    if port is None:
+        port = distributed_utils.find_free_port()
     # Launch multiple subprocesses
     spawncontext = torch.multiprocessing.start_processes(
         multiprocess_train,
@@ -99,7 +100,7 @@ def setup_args(cls):

     def run(self):
         if self.opt['port'] is None:
-            port = random.randint(32000, 48000)
+            port = None
         else:
             port = self.opt['port']
         return launch_and_train(self.opt, port)
```

Review comment on `def launch_and_train(opt, port=None)`: will we ever specify a port here?
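Two of these hunks work together: `run()` now passes `port=None` when the user did not set one (an explicit `--port` still flows through, which answers the review question above), and `launch_and_train` resolves `None` to a free port at spawn time. Helpers like `distributed_utils.find_free_port()` are conventionally implemented by binding to port 0 and letting the kernel choose; a minimal sketch of the standard idiom (ParlAI's actual implementation may differ in detail):

```python
import socket

def find_free_port() -> int:
    """Ask the OS for a currently unused TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))  # port 0 tells the kernel to pick any free port
        return s.getsockname()[1]  # the port the kernel actually assigned
```

Unlike `random.randint(32000, 48000)`, this can never return a port that is already bound, though a small race window remains between returning the port and the process group binding it.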
```diff
@@ -442,17 +442,20 @@ def save_model(self, suffix=None):
         """
         Save the model to disk, possibly with a suffix.
         """
-        if not is_primary_worker():
-            # never do IO as a non-primary worker
-            return
-
         if not self.opt.get('model_file'):
             # nothing to save to, just exit
             return

         fn = self.opt['model_file']
         if suffix:
             fn += suffix

+        if not is_primary_worker():
+            # never do IO as a non-primary worker
+            if hasattr(self.agent, 'save_nonprimary'):
+                self.agent.save_nonprimary(fn)
+            return
+
         while True:
             # don't ever let a ctrl-c interrupt saving
             try:
```
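The reordering is deliberate: `fn` is now computed before the primary-worker check, so a non-primary worker can be handed the same checkpoint path. Under zero2/zero3 sharding, rank 0 no longer holds the complete optimizer state, so other ranks may need to persist their own shard. A hypothetical sketch of what a `save_nonprimary` hook could do (the method name comes from the diff; the body here is illustrative, not ParlAI's implementation):

```python
import torch
import torch.distributed as dist

def save_nonprimary(self, path: str) -> None:
    """Illustrative only: write this rank's shard of the optimizer state.

    With sharded (zero2/zero3) training, each rank owns only a slice of
    the optimizer state, so rank 0 cannot write a complete checkpoint alone.
    """
    rank = dist.get_rank()
    shard = self.optimizer.state_dict()  # assumed to return this rank's shard
    torch.save(shard, f'{path}.optim_shard_{rank}')
```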
```diff
@@ -543,7 +546,7 @@ def validate(self):
             )
             self.best_valid = new_valid
             self.impatience = 0
-            if opt.get('model_file') and is_primary_worker():
+            if opt.get('model_file'):
                 logging.info(f"saving best valid model: {opt['model_file']}")
                 self.save_model()
                 self.saved = True
```

Review comment on the changed condition: just making sure I understand: we can get rid of this check because it's handled in `save_model`?
Reply: We need to be able to save on the non-primary worker, actually.
```diff
@@ -566,11 +569,7 @@ def validate(self):
         self.validate_time.reset()

         # saving
-        if (
-            opt.get('model_file')
-            and opt.get('save_after_valid')
-            and is_primary_worker()
-        ):
+        if opt.get('model_file') and opt.get('save_after_valid'):
             logging.info(f"saving model checkpoint: {opt['model_file']}.checkpoint")
             self.save_model('.checkpoint')
```
```diff
@@ -720,24 +719,26 @@ def _get_time(self, world: World) -> Tuple[float, float, float]:
             self._total_epochs = self._preempted_epochs + sum(
                 all_gather_list(world.get_total_epochs())
             )
-            train_time, log_time, validate_time = sync_object(
+            train_time, log_time, validate_time, save_time = sync_object(
                 (
                     self.train_time.time(),
                     self.log_time.time(),
                     self.validate_time.time(),
+                    self.save_time.time(),
                 )
             )
         else:
-            train_time, log_time, validate_time = (
+            train_time, log_time, validate_time, save_time = (
                 self.train_time.time(),
                 self.log_time.time(),
                 self.validate_time.time(),
+                self.save_time.time(),
             )
             self._total_epochs = self._preempted_epochs + (
                 num_workers() * world.get_total_epochs()
             )

-        return train_time, log_time, validate_time
+        return train_time, log_time, validate_time, save_time

     def log(self):
         """
```
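`save_time` now joins the tuple of synchronized timers, so every rank computes the checkpoint-now condition from identical numbers. `sync_object` distributes a picklable value from the primary worker; a minimal sketch of the general pattern using `torch.distributed` (ParlAI's actual implementation may serialize differently):

```python
import torch.distributed as dist

def sync_object(obj):
    """Broadcast a picklable object from rank 0 so every rank sees it.

    Keeping the timers identical across ranks guarantees all workers take
    the same save/validate branches in lockstep.
    """
    if not dist.is_available() or not dist.is_initialized():
        return obj  # single-process run: nothing to synchronize
    container = [obj]
    dist.broadcast_object_list(container, src=0)
    return container[0]
```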
```diff
@@ -810,7 +811,7 @@ def train_steps(self):
                 self._last_log_steps += 1 / self.update_freq

             # the following additionally updates self._total_epochs
-            train_time, log_time, validate_time = self._get_time(world)
+            train_time, log_time, validate_time, save_time = self._get_time(world)
             # get the total training examples done, compute epochs
             exs_per_epoch = world.num_examples()
             self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))
```
```diff
@@ -859,11 +860,7 @@ def train_steps(self):
                     break
                 # make sure metrics are clean before we log
                 world.reset_metrics()
-                if (
-                    self.save_time.time() > self.save_every_n_secs
-                    and opt.get('model_file')
-                    and is_primary_worker()
-                ):
+                if save_time > self.save_every_n_secs and opt.get('model_file'):
                     logging.info(
                         f"saving model checkpoint: {opt['model_file']}.checkpoint"
                     )
```
```diff
@@ -872,7 +869,7 @@ def train_steps(self):
                     self.save_model('.checkpoint')
                     self.save_time.reset()

-        if not self.saved and is_primary_worker():
+        if not sync_object(self.saved):
            # save agent
            self.save_model()
```
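The recurring pattern across these hunks: any branch that leads into `save_model` must be taken by every rank, because a sharded (zero2/zero3) save involves collective communication; a branch that only rank 0 takes would deadlock the group. Schematically (illustrative, not ParlAI code):

```python
# Hazard: if ranks disagree on `saved`, only some of them enter
# save_model(), and the collective gather inside it hangs forever.
if not self.saved and is_primary_worker():   # old: rank 0 only
    self.save_model()

# Fix: broadcast the flag so every rank takes the same branch and
# participates in the collective save.
if not sync_object(self.saved):              # new: all ranks agree
    self.save_model()
```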
Review comment: nit: could make this a helper function too? like `should_sync_gradnorm` (not necessary of course)
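The suggested helper does not exist in the diff; one hypothetical shape for it, assuming it would wrap the "are we sharding, so per-rank gradient norms are only partial" check (the `ddp_backend` option values are an assumption based on this PR's zero2/zero3 modes):

```python
import torch.distributed as dist

def should_sync_gradnorm(opt) -> bool:
    """Hypothetical helper: must gradient norms be reduced across ranks?

    With zero2/zero3 sharding, each rank computes the norm of its own
    gradient shard, so the per-rank value is partial and needs combining.
    """
    return (
        opt.get('ddp_backend') in ('zero2', 'zero3')  # assumed option name
        and dist.is_available()
        and dist.is_initialized()
    )
```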