Update config_moe_args.py #1104

Merged 1 commit on Apr 10, 2024
29 changes: 1 addition & 28 deletions llmfoundry/models/utils/config_moe_args.py
@@ -12,33 +12,6 @@
from llmfoundry.models.layers.ffn import resolve_ffn_hidden_size


-def create_process_group_ranks(ranks: tuple[int]):
-    """Creates a new distributed group.
-
-    Used in create_set_process_group and create_mod_process_group methods below.
-
-    This function is an alternative to `distributed.new_group(ranks)`.
-    When working with FSDP in torch1.13.1, using `distributed.new_group(ranks)`
-    resulted in an error but this method worked.
-
-    TODO(GRT-2416): When composer no longer has support for torch1.13.1, we should
-    consider using `distributed.new_group(ranks)` here and in composer's FSDP
-    custom process group init.
-
-    Args:
-        ranks (tuple[int]): Tuple of ranks of group members.
-
-    Returns:
-        A handle of distributed group that can be given to collective calls.
-    """
-    ranks_gather_list = [None for _ in range(distributed.get_world_size())]
-    distributed.all_gather_object(ranks_gather_list, ranks)
-    ranks_per_subgroup = list(set(ranks_gather_list))
-    group, _ = distributed.distributed_c10d.new_subgroups_by_enumeration(
-        ranks_per_subgroup)
-    return group
-
-
def create_set_process_group(k: int):
"""Creates a new distributed group using sets of k GPUs.

@@ -60,7 +33,7 @@ def create_set_process_group(k: int):
        raise RuntimeError(f'{world_size=} must be divisible by {k=}.')
    start = distributed.get_rank() // k * k
    ranks = tuple(range(start, start + k))
-    return create_process_group_ranks(ranks)
+    return distributed.new_group(ranks)


def config_megablocks_moe_args(
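
For reference, the change drops the all-gather based workaround and lets create_set_process_group call torch.distributed.new_group(ranks) directly. Below is a minimal sketch of the set-partitioning arithmetic, runnable without a distributed backend; the helper name set_ranks and the stubbed rank loop are illustrative and not part of the PR.

# Sketch of the set-partitioning logic in create_set_process_group after this PR.
# The arithmetic mirrors the diff above; the distributed call itself is shown only
# as a comment because it requires an initialized process group.

def set_ranks(rank: int, world_size: int, k: int) -> tuple[int, ...]:
    """Return the k consecutive ranks that `rank` would be grouped with."""
    if world_size % k != 0:
        raise RuntimeError(f'{world_size=} must be divisible by {k=}.')
    start = rank // k * k
    return tuple(range(start, start + k))

# With world_size=8 and k=4, ranks 0-3 form one set and ranks 4-7 another.
for fake_rank in range(8):
    print(fake_rank, set_ranks(fake_rank, world_size=8, k=4))

# After this PR, each rank turns its set into a process group with
#     group = distributed.new_group(ranks)
# whereas the removed helper all-gathered every rank's tuple and passed the
# deduplicated list to distributed.distributed_c10d.new_subgroups_by_enumeration.

Per the removed docstring, the gather-based helper existed only as a workaround for an FSDP issue in torch 1.13.1 (TODO GRT-2416), so once that version no longer needed to be supported the standard distributed.new_group API could be used directly.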