Fully Sharded Data Parallel #3740
Changes from 31 commits
First changed file:

@@ -35,6 +35,7 @@
 import parlai.utils.logging as logging
 from parlai.core.metrics import SumMetric, AverageMetric, FairseqBleuMetric
 from parlai.utils.fp16 import FP16SafeCrossEntropy
+import parlai.utils.fsdp as fsdp_utils
 from parlai.utils.torch import (
     neginf,
     total_parameters,
@@ -479,8 +480,10 @@ def __init__(self, opt: Opt, shared=None):
         else:
             # this is not a shared instance of this class, so do full init
             self.criterion = self.build_criterion()
-            # ensure all distributed copies will always be in sync
-            self.model = self.build_model()
+            with fsdp_utils.maybe_fsdp_wrap(opt):
+                self.model = fsdp_utils.fsdp_wrap(self.build_model())
+                if self.fp16 and not fsdp_utils.should_use_fsdp(opt):
+                    self.model = self.model.half()

         # load the block_list for beam search
         self.beam_block_list = self._load_beam_block_list()

Review comments on this change:
- remember that bug with the instability stuff? is this not re-introducing it?
- (because we moved the model.half() call?)
- Okay, I think this needs to use my utility.
- We haven't really moved the moment of halving: the operations between these two points don't do much, and the original code path should be about the same. The catch is that FSDP expects the model pre-halved if we're doing safe optimization, and post-halved if we're doing memory-efficient. (Similar to the optimizer wrappers, it looks at the parameters' types to decide what type the gradients are.) This is the desired pattern.
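For context, here is a minimal sketch of how the helpers referenced above could fit together. This is not the PR's actual parlai/utils/fsdp.py: it assumes fairscale's FullyShardedDataParallel and a hypothetical 'zero2' value for the ddp_backend option, and it only illustrates why halving is decided inside the wrapping context.

```python
# Illustrative sketch only; the real parlai/utils/fsdp.py may differ.
import contextlib

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, wrap

DEFAULT_DDP_BACKEND = 'ddp'


def should_use_fsdp(opt) -> bool:
    # Shard only when an FSDP-style backend is requested ('zero2' is an assumed name).
    return opt.get('ddp_backend', DEFAULT_DDP_BACKEND) == 'zero2'


@contextlib.contextmanager
def maybe_fsdp_wrap(opt):
    if not should_use_fsdp(opt):
        # Plain DDP / single GPU: a no-op context, so the agent keeps the old code
        # path and calls model.half() itself when fp16 is enabled.
        yield
        return
    # Inside this context, fsdp_wrap() below actually shards modules. Precision
    # options (e.g. fairscale's mixed_precision / compute_dtype) would be picked
    # here from the fp16 implementation, because FSDP infers the gradient dtype
    # from the dtype of the parameters it is handed; that is why the order of
    # wrapping and halving matters.
    with enable_wrap(wrapper_cls=FSDP, mixed_precision=bool(opt.get('fp16'))):
        yield


def fsdp_wrap(module):
    # fairscale's wrap() is a no-op outside an enable_wrap() context, so it is
    # safe to call unconditionally, exactly as the diff above does.
    return wrap(module)
```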
@@ -498,16 +501,15 @@ def __init__(self, opt: Opt, shared=None):
             self.model.cuda()
             self.criterion.cuda()

-        sync_parameters(self.model)
+        if not fsdp_utils.is_fsdp(self.model):
+            sync_parameters(self.model)

         train_params = trainable_parameters(self.model)
         total_params = total_parameters(self.model)
         logging.info(
             f"Total parameters: {total_params:,d} ({train_params:,d} trainable)"
         )

-        if self.fp16:
-            self.model = self.model.half()
-
         if init_model is not None:
             # load model parameters if available
             logging.info(f'Loading existing model params from {init_model}')
@@ -530,7 +532,11 @@ def __init__(self, opt: Opt, shared=None):
                 logging.warning("Optimizer was reset. Also resetting LR scheduler.")
             self.build_lr_scheduler(states, hard_reset=is_finetune or was_reset)

-        if shared is None and is_distributed():
+        if (
+            shared is None
+            and is_distributed()
+            and opt.get('ddp_backend', fsdp_utils.DEFAULT_DDP_BACKEND) == 'ddp'
+        ):
             device_ids = None if self.model_parallel else [self.opt['gpu']]
             self.model = torch.nn.parallel.DistributedDataParallel(
                 self.model, device_ids=device_ids, broadcast_buffers=False
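The extra condition means the opt's ddp_backend value now decides which wrapper owns the model; a rough illustration follows (the 'zero2' backend name is an assumption, only 'ddp' and DEFAULT_DDP_BACKEND appear in the diff).

```python
# Hypothetical opt values showing the two paths through the check above.
opt = {'ddp_backend': 'ddp'}     # old behaviour: wrap the model in DistributedDataParallel
# opt = {'ddp_backend': 'zero2'} # assumed FSDP backend: skip the DDP wrap, since the model
#                                # was already sharded by fsdp_wrap() at build time and
#                                # wrapping it again in DDP would be redundant.
```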
Second changed file:

@@ -24,7 +24,6 @@
 """

 import torch
-import random
 import os
 import signal
 import traceback
@@ -55,10 +54,12 @@ def multiprocess_train(
         raise


-def launch_and_train(opt, port):
+def launch_and_train(opt, port=None):
     """
     Perform a fork() to many processes.
     """
+    if port is None:
+        port = distributed_utils.find_free_port()
     # Launch multiple subprocesses
     spawncontext = torch.multiprocessing.start_processes(
         multiprocess_train,

Review comment on this change:
- will we ever specify a port here?
@@ -99,7 +100,7 @@ def setup_args(cls):

     def run(self):
         if self.opt['port'] is None:
-            port = random.randint(32000, 48000)
+            port = None
         else:
             port = self.opt['port']
         return launch_and_train(self.opt, port)
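The diff only shows the call to distributed_utils.find_free_port(), not its body. A typical implementation of such a helper (an assumption, not necessarily ParlAI's) binds to port 0 and lets the OS pick an unused port:

```python
import socket


def find_free_port() -> int:
    # Bind to port 0 so the OS assigns any currently free port, then report its number.
    # The socket is closed right away; the chosen port is handed to the spawned workers,
    # which use it for the distributed rendezvous.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]
```

Compared with the old random.randint(32000, 48000) in run(), asking the OS avoids picking a port that is already taken.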
Third changed file:

@@ -442,17 +442,20 @@ def save_model(self, suffix=None):
         """
         Save the model to disk, possibly with a suffix.
         """
-        if not is_primary_worker():
-            # never do IO as a non-primary worker
-            return
-
         if not self.opt.get('model_file'):
             # nothing to save to, just exit
             return

         fn = self.opt['model_file']
         if suffix:
             fn += suffix

+        if not is_primary_worker():
+            # never do IO as a non-primary worker
+            if hasattr(self.agent, 'save_nonprimary'):
+                self.agent.save_nonprimary(fn)
+            return
+
         while True:
             # don't ever let a ctrl-c interrupt saving
             try:
@@ -543,7 +546,7 @@ def validate(self):
                 )
                 self.best_valid = new_valid
                 self.impatience = 0
-                if opt.get('model_file') and is_primary_worker():
+                if opt.get('model_file'):
                     logging.info(f"saving best valid model: {opt['model_file']}")
                     self.save_model()
                     self.saved = True

Review comments on this change:
- just making sure I understand - we can get rid of this check because it's handled in
- We need to be able to do save_on_nonprimary_worker actually
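A sketch of the hook being discussed. Only the name save_nonprimary and its call site appear in the diff; the body below is an assumption, meant to show why non-primary workers can no longer simply return: with a sharded backend, each rank holds state (for example its optimizer shard) that the primary worker cannot write on its behalf.

```python
import torch
import torch.distributed as dist


class MyAgent:
    """Hypothetical agent fragment illustrating a save_nonprimary(fn) hook."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def save_nonprimary(self, path: str) -> None:
        # Persist only what this rank uniquely owns (here, its optimizer shard);
        # the primary worker still writes the main checkpoint in save_model().
        rank = dist.get_rank()
        torch.save(self.optimizer.state_dict(), f'{path}.optim_shard{rank}')
```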
@@ -566,11 +569,7 @@
             self.validate_time.reset()

             # saving
-            if (
-                opt.get('model_file')
-                and opt.get('save_after_valid')
-                and is_primary_worker()
-            ):
+            if opt.get('model_file') and opt.get('save_after_valid'):
                 logging.info(f"saving model checkpoint: {opt['model_file']}.checkpoint")
                 self.save_model('.checkpoint')
@@ -720,24 +719,26 @@ def _get_time(self, world: World) -> Tuple[float, float, float]:
             self._total_epochs = self._preempted_epochs + sum(
                 all_gather_list(world.get_total_epochs())
             )
-            train_time, log_time, validate_time = sync_object(
+            train_time, log_time, validate_time, save_time = sync_object(
                 (
                     self.train_time.time(),
                     self.log_time.time(),
                     self.validate_time.time(),
+                    self.save_time.time(),
                 )
             )
         else:
-            train_time, log_time, validate_time = (
+            train_time, log_time, validate_time, save_time = (
                 self.train_time.time(),
                 self.log_time.time(),
                 self.validate_time.time(),
+                self.save_time.time(),
             )
             self._total_epochs = self._preempted_epochs + (
                 num_workers() * world.get_total_epochs()
             )

-        return train_time, log_time, validate_time
+        return train_time, log_time, validate_time, save_time

     def log(self):
         """
@@ -810,7 +811,7 @@ def train_steps(self):
             self._last_log_steps += 1 / self.update_freq

             # the following additionally updates self._total_epochs
-            train_time, log_time, validate_time = self._get_time(world)
+            train_time, log_time, validate_time, save_time = self._get_time(world)
             # get the total training examples done, compute epochs
             exs_per_epoch = world.num_examples()
             self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))
@@ -859,11 +860,7 @@
                     break
                 # make sure metrics are clean before we log
                 world.reset_metrics()
-                if (
-                    self.save_time.time() > self.save_every_n_secs
-                    and opt.get('model_file')
-                    and is_primary_worker()
-                ):
+                if save_time > self.save_every_n_secs and opt.get('model_file'):
                     logging.info(
                         f"saving model checkpoint: {opt['model_file']}.checkpoint"
                     )
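Passing save_time through _get_time (and thus through sync_object, ParlAI's existing helper for broadcasting a value from the primary worker) matters because saving is no longer a primary-only action: every rank has to reach the same "checkpoint now?" decision. A small illustration with a faked timer value:

```python
from parlai.utils.distributed import sync_object

# Each rank's own Timer reading would drift slightly; adopting the primary worker's
# value guarantees the comparison below comes out identically everywhere, so no rank
# enters (or skips) the save path alone.
local_elapsed = 3601.2                  # pretend Timer.time() result on this rank
save_time = sync_object(local_elapsed)  # same float on every rank after the broadcast
should_save = save_time > 3600          # identical boolean on every rank
```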
@@ -872,7 +869,7 @@
                     self.save_model('.checkpoint')
                     self.save_time.reset()

-        if not self.saved and is_primary_worker():
+        if not sync_object(self.saved):
             # save agent
             self.save_model()
Review comment:
- nit: could make this a helper function too? like should_sync_gradnorm (not necessary of course)
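The helper the reviewer is suggesting might look roughly like this (the name comes from the comment; the body and the 'zero2'/'zero3' backend names are assumptions). The idea is that with sharded parameters each rank's locally computed gradient norm only covers its own shard and must be combined across workers, whereas plain DDP already leaves every rank holding the full gradients.

```python
DEFAULT_DDP_BACKEND = 'ddp'


def should_sync_gradnorm(opt) -> bool:
    # Sharded backends see only a slice of the parameters per rank, so gradient norms
    # need an all-reduce across workers before clipping or reporting.
    return opt.get('ddp_backend', DEFAULT_DDP_BACKEND) in ('zero2', 'zero3')
```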