Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to disable distributed parameters in distributed Adam optimizer #5685

Merged
merged 6 commits into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/clip_grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ def clip_grad_norm_distributed_optimizer(optimizer, max_norm, norm_type=2):
# Compute grad norm
# Note: Compute norm of local grads and sum over all procs
grad_norm_sq = optimizer._local_grad_norm(parameters=params_for_norm, norm_type=norm_type)
if optimizer.redundant_size > 1:
grad_norm_sq /= optimizer.redundant_size
Comment on lines +193 to +194
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do all optimizers have this attribute? Can we do a safer check here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This grad clipping function guarantees that it's just dealing with the distributed optimizer:

assert isinstance(optimizer, DistributedFusedAdam)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've also been thinking about moving this logic to the distributed optimizer wrapper at:

class MegatronDistributedFusedAdam(DistributedFusedAdam):

I've been working on supporting BF16 grad reductions, so perhaps this work can be folded into that PR.

torch.distributed.all_reduce(
grad_norm_sq, op=torch.distributed.ReduceOp.SUM,
)
Expand Down
9 changes: 8 additions & 1 deletion nemo/core/optim/distributed_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
from apex.transformer import parallel_state


# Wrapper class that supports main_grad buffer
# Note: main_grad buffer is used for O2-style optimizations
class MegatronDistributedFusedAdam(DistributedFusedAdam):
def __init__(self, *args, **kwargs):
def __init__(self, *args, disable_distributed_parameters=False, **kwargs):
if 'process_group' not in kwargs and not parallel_state.is_unitialized():
kwargs['process_group'] = parallel_state.get_data_parallel_group()
if disable_distributed_parameters:
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
self_groups = [torch.distributed.new_group(ranks=[i]) for i in range(world_size)]
kwargs['distributed_process_group'] = self_groups[rank]
kwargs['redundant_process_group'] = kwargs['process_group']
super().__init__(*args, **kwargs)

def _make_post_backward_hook(self, param, param_group_id, param_id):
Expand Down