Optimizing distributed Adam when running with one work queue (#5560)
* Dist Adam constructs a single param bucket for each GPT layer

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Synchronize dist Adam reduce-scatters before launching model-parallel all-reduces

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Configure per-layer dist Adam buckets for BERT and T5

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unused variables

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Configure GPT with one dist Adam bucket per virtual pipeline stage

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Configure BERT with one dist Adam bucket per virtual pipeline stage

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update Apex commit in Dockerfile

Need recent updates to Apex distributed Adam optimizer.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove logic for per-virtual-pipeline distopt buckets from T5

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
timmoon10 authored Jan 30, 2023
1 parent fd6bb1d commit c3eeae1
Showing 5 changed files with 97 additions and 12 deletions.
7 changes: 7 additions & 0 deletions Dockerfile
@@ -34,6 +34,13 @@ RUN apt-get update && \
 
 WORKDIR /tmp/
 
+# TODO: Remove once this Apex commit (1/19/23) is included in PyTorch
+# container
+RUN git clone https://github.com/NVIDIA/apex.git && \
+    cd apex && \
+    git checkout 75f401e088ef88e7c85a57ecf70fb232235f0334 && \
+    pip3 install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" --global-option="--distributed_adam" --global-option="--deprecated_fused_adam" ./
+
 # uninstall stuff from base container
 RUN pip3 uninstall -y sacrebleu torchtext
 
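A quick way to confirm that a rebuilt image actually picked up the distributed Adam extension is to import it. This check is not part of the commit, and the module path below is an assumption based on Apex's usual layout:

# Hedged sanity check (not part of this commit): the import path is assumed
# from Apex's layout; the --distributed_adam build flag above is what compiles
# this optimizer's fused kernels.
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
print(DistributedFusedAdam.__name__)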
Changed file 2 of 5 (filename not shown)
@@ -398,12 +398,24 @@ def configure_optimizers(self):
         # Configure distributed optimizer
         if self.with_distributed_adam:
 
-            # Initialize params so that main grads are available
+            # Initialize param buckets if explicitly provided
+            if hasattr(self, 'distributed_adam_buckets'):
+                for bucket in self.distributed_adam_buckets:
+                    self._optimizer.init_params_bucket(bucket)
+                del self.distributed_adam_buckets
+
+            # Make sure all params are initialized so main grads are
+            # available
             # Note: Consolidate grads without overlap
-            self._optimizer.init_params(
-                p for p in self.parameters() if getattr(p, '_disable_overlap_grad_sync', False)
-            )
-            self._optimizer.init_params(self.parameters())
+            overlap_params = []
+            no_overlap_params = []
+            for p in self.parameters():
+                if getattr(p, '_disable_overlap_grad_sync', False):
+                    no_overlap_params.append(p)
+                else:
+                    overlap_params.append(p)
+            self._optimizer.init_params(reversed(overlap_params))
+            self._optimizer.init_params(reversed(no_overlap_params))
 
         if self._scheduler is None:
             return self._optimizer
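To make the registration order above concrete, here is a minimal, self-contained sketch with a mock optimizer (hypothetical class, toy parameters). It only mirrors the call order in the hunk: explicitly provided buckets first, then every parameter in reversed order, overlap-enabled parameters before those that opt out of overlapped grad sync.

import torch

class MockDistOpt:
    """Stands in for the Apex distributed optimizer; records registration order only."""
    def __init__(self):
        self.order = []
    def init_params_bucket(self, bucket):
        self.order.append(('bucket', len(list(bucket))))
    def init_params(self, params):
        self.order.append(('params', len(list(params))))

params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(4)]
params[0]._disable_overlap_grad_sync = True  # e.g. a param whose grad sync is deferred

opt = MockDistOpt()
opt.init_params_bucket(params[1:3])  # explicitly provided bucket, e.g. one Transformer layer
overlap = [p for p in params if not getattr(p, '_disable_overlap_grad_sync', False)]
no_overlap = [p for p in params if getattr(p, '_disable_overlap_grad_sync', False)]
opt.init_params(reversed(overlap))     # all overlap-enabled params; presumably the real
                                       # optimizer skips ones already placed in a bucket
opt.init_params(reversed(no_overlap))  # params that opted out of overlapped grad sync
print(opt.order)  # [('bucket', 2), ('params', 3), ('params', 1)]

Registering in reversed order presumably matches the order in which gradients become ready during the backward pass, so bucket reductions can start as early as possible.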
@@ -428,7 +440,7 @@ def _extract_consumed_samples_from_ckpt(self, ckpt_path):
         return init_consumed_samples
 
     def _validate_and_override_config(self):
-        """ Certain configurations might be incompatible or discouraged.
+        """ Certain configurations might be incompatible or discouraged.
         We can check for them here and override if necessary.
         """
         app_state = AppState()
Changed file 3 of 5 (filename not shown)
@@ -319,8 +319,10 @@ def training_step(self, batch, batch_idx):
             self.allreduce_sequence_parallel_gradients()
 
         if self.with_distributed_adam:
-            # gradients are reduced internally in distributed optimizer
-            pass
+            # synchronize asynchronous grad reductions
+            # note: not necessary, but reduces performance degradation
+            # from multiple simultaneous NCCL calls
+            self._optimizer._finish_bucket_grad_sync()
         elif self.megatron_amp_o2:
             if self.cfg.get('pipeline_model_parallel_size', 1) > 1 or self.cfg.get('sequence_parallel', False):
                 # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously)
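The call to _finish_bucket_grad_sync() drains the optimizer's asynchronous gradient reduce-scatters before further collectives run, which matters when all NCCL calls share a single work queue (per the commit title). A generic illustration of the same wait-before-next-collective pattern using plain torch.distributed, not the NeMo/Apex implementation, looks like this:

import torch.distributed as dist

def finish_async_grad_sync(handles):
    """Block until all previously launched asynchronous reductions have completed."""
    for handle in handles:
        handle.wait()

# Usage sketch (assumes torch.distributed has been initialized elsewhere):
#   handles = [dist.all_reduce(p.grad, async_op=True) for p in model.parameters()]
#   finish_async_grad_sync(handles)          # analogous to _finish_bucket_grad_sync()
#   dist.all_reduce(shared_embedding_grad)   # e.g. a model-parallel all-reduce afterwards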
@@ -765,6 +767,36 @@ def configure_optimizers(self):
                 param._disable_greedy_grad_copy = not self.megatron_amp_o2
                 param._disable_overlap_grad_sync = True
 
+            # Initialize parameter buckets for overlapped grad and param syncs
+            buckets = []
+            if self.cfg.get('virtual_pipeline_model_parallel_size', None) is not None:
+                # Initialize a bucket for each virtual pipeline stage
+                for module in self.model:
+                    if isinstance(module, Float16Module):
+                        module = module.module
+                    stage_bucket = []
+                    for layer in module.language_model.encoder.layers:
+                        stage_bucket.extend(
+                            p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)
+                        )
+                    buckets.append(stage_bucket)
+            else:
+                # Initialize a bucket for each Transformer layer
+                modules = self.model if isinstance(self.model, list) else [self.model]
+                for module in modules:
+                    if isinstance(module, Float16Module):
+                        module = module.module
+                    for layer in module.language_model.encoder.layers:
+                        buckets.append(
+                            [p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)]
+                        )
+            buckets.reverse()
+            used_params = set()
+            for bucket in buckets:
+                used_params.update(bucket)
+            buckets.append([p for p in self.parameters() if p not in used_params])
+            self.distributed_adam_buckets = buckets
+
         return super().configure_optimizers()
 
     # Required for ONNX export
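The bucketing policy above can be read in isolation: with interleaved (virtual) pipelining, all parameters of a virtual pipeline stage share one bucket; otherwise each Transformer layer gets its own bucket, and a final catch-all bucket collects parameters not already covered. Below is a standalone sketch of the two grouping schemes with toy layers (hypothetical helper name, not NeMo code; the catch-all bucket is omitted since the toy model has no leftover parameters):

import torch.nn as nn

def make_buckets(stages, one_bucket_per_stage):
    """Group toy 'layers' into distributed-optimizer buckets (illustration only)."""
    buckets = []
    if one_bucket_per_stage:
        # one bucket per virtual pipeline stage
        for stage in stages:
            buckets.append([p for layer in stage for p in layer.parameters()])
    else:
        # one bucket per Transformer layer
        for stage in stages:
            for layer in stage:
                buckets.append(list(layer.parameters()))
    buckets.reverse()  # later layers first, mirroring buckets.reverse() above
    return buckets

stages = [[nn.Linear(4, 4) for _ in range(2)] for _ in range(3)]  # 3 stages x 2 layers
print([len(b) for b in make_buckets(stages, True)])   # 3 buckets of 4 params each
print([len(b) for b in make_buckets(stages, False)])  # 6 buckets of 2 params each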
Changed file 4 of 5 (filename not shown)
@@ -241,6 +241,36 @@ def configure_optimizers(self):
                 param._disable_greedy_grad_copy = not self.megatron_amp_o2
                 param._disable_overlap_grad_sync = True
 
+            # Initialize parameter buckets for overlapped grad and param syncs
+            buckets = []
+            if self.cfg.get('virtual_pipeline_model_parallel_size', None) is not None:
+                # Initialize a bucket for each virtual pipeline stage
+                for module in self.model:
+                    if isinstance(module, Float16Module):
+                        module = module.module
+                    stage_bucket = []
+                    for layer in module.language_model.encoder.layers:
+                        stage_bucket.extend(
+                            p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)
+                        )
+                    buckets.append(stage_bucket)
+            else:
+                # Initialize a bucket for each Transformer layer
+                modules = self.model if isinstance(self.model, list) else [self.model]
+                for module in modules:
+                    if isinstance(module, Float16Module):
+                        module = module.module
+                    for layer in module.language_model.encoder.layers:
+                        buckets.append(
+                            [p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)]
+                        )
+            buckets.reverse()
+            used_params = set()
+            for bucket in buckets:
+                used_params.update(bucket)
+            buckets.append([p for p in self.parameters() if p not in used_params])
+            self.distributed_adam_buckets = buckets
+
         return super().configure_optimizers()
 
     def forward(self, tokens, text_position_ids, attention_mask, labels):
@@ -336,8 +366,10 @@ def training_step(self, batch, batch_idx):
             self.allreduce_sequence_parallel_gradients()
 
         if self.with_distributed_adam:
-            # gradients are reduced internally in distributed optimizer
-            pass
+            # synchronize asynchronous grad reductions
+            # note: not necessary, but reduces performance degradation
+            # from multiple simultaneous NCCL calls
+            self._optimizer._finish_bucket_grad_sync()
         elif self.megatron_amp_o2:
             # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously)
             if self.cfg.get('pipeline_model_parallel_size', 1) > 1 or self.cfg.get('sequence_parallel', False):
Changed file 5 of 5 (filename not shown)
@@ -369,8 +369,10 @@ def training_step(self, batch, batch_idx):
             loss_mean = torch.tensor(0.0).cuda()
 
         if self.with_distributed_adam:
-            # gradients are reduced internally in distributed optimizer
-            pass
+            # synchronize asynchronous grad reductions
+            # note: not necessary, but reduces performance degradation
+            # from multiple simultaneous NCCL calls
+            self._optimizer._finish_bucket_grad_sync()
         elif self.megatron_amp_o2:
             # when using pipeline parallelism grads must be reduced after the pipeline (not asynchronously)
             if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
