Support Expert Parallelism #72

Merged
merged 7 commits on Feb 16, 2024
2 changes: 1 addition & 1 deletion src/nanotron/config/config.py
@@ -146,7 +146,7 @@ class GeneralArgs:
step: Optional[int] = None
consumed_train_samples: Optional[int] = None
benchmark_csv_path: Optional[Path] = None
- ignore_sanity_checks: bool = False
+ ignore_sanity_checks: bool = True

def __post_init__(self):
if self.seed is None:
3 changes: 3 additions & 0 deletions src/nanotron/config/parallelism_config.py
@@ -20,6 +20,7 @@ class ParallelismArgs:
dp: Number of DP replicas
pp: Number of PP stages
tp: Number of TP replicas
expert_parallel_size: Number of expert parallel replicas (used only for MoEs)
Member
Shouldn't expert_parallel_size be the number of experts per TP rank?

Member Author
Not quite: expert parallelism is orthogonal to TP. For example, you can have a single expert sharded across 2 TP ranks.
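To make that concrete, here is a minimal sketch (illustrative sizes only, not nanotron code) of how expert parallelism adds a fourth axis to the parallel grid instead of subdividing TP:

```python
# Hypothetical 8-GPU layout: EP multiplies the grid the same way PP/DP/TP do.
tp, pp, dp, ep = 2, 1, 2, 2
world_size = tp * pp * dp * ep
assert world_size == 8

# With ep=2 and tp=2 there are 2 expert groups, and every expert inside a
# group still has its weight matrices sharded across the 2 TP ranks,
# i.e. a single expert can be split over 2 TP ranks.
```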

pp_engine: Pipeline engine to use between "1f1b" and "afab"
tp_mode: TP mode to use between "all_reduce" and "reduce_scatter": all_reduce is normal, reduce_scatter activates sequence parallelism
recompute_granularity: Recompute granularity to use between "full" and "selective"
@@ -34,6 +35,8 @@ class ParallelismArgs:
recompute_granularity: Optional[RecomputeGranularity] = None
tp_linear_async_communication: Optional[bool] = None

expert_parallel_size: int = 1

def __post_init__(self):
# Conservative defaults
if self.pp_engine is None:
35 changes: 23 additions & 12 deletions src/nanotron/dataloader.py
@@ -44,7 +44,7 @@ def sanity_check_dataloader(
}

if not config.general.ignore_sanity_checks:
- # SANITY CHECK: Check input are not the same across DP
+ # SANITY CHECK: Check input are not the same across DP and expert_pg
for key, value in sorted(micro_batch.items(), key=lambda x: x[0]):
if isinstance(value, TensorPointer):
continue
Expand All @@ -58,6 +58,11 @@ def sanity_check_dataloader(
tensor=value, pg=parallel_context.dp_pg, msg=lambda err: f"{key} {err}"
)

with assert_fail_except_rank_with(AssertionError, rank_exception=0, pg=parallel_context.expert_pg):
assert_tensor_synced_across_pg(
tensor=value, pg=parallel_context.expert_pg, msg=lambda err: f"{key} {err}"
)

# SANITY CHECK: Check input are synchronized throughout TP
for key, value in sorted(micro_batch.items(), key=lambda x: x[0]):
if isinstance(value, TensorPointer):
@@ -393,8 +398,8 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Uni

# Adapted from https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/trainer.py#L763-L835
def _get_train_sampler(
- dp_size: int,
- dp_rank: int,
+ dl_ranks_size: int,
+ dl_rank: int,
train_dataset: "Dataset",
seed: int,
use_loop_to_round_batch_size: bool,
@@ -413,16 +418,18 @@ def _get_train_sampler(
sampler = DistributedSamplerWithLoop(
train_dataset,
batch_size=micro_batch_size,
- num_replicas=dp_size,
- rank=dp_rank,
+ num_replicas=dl_ranks_size,
+ rank=dl_rank,
seed=seed,
drop_last=drop_last,
)
else:
- sampler = DistributedSampler(train_dataset, num_replicas=dp_size, rank=dp_rank, seed=seed, drop_last=drop_last)
+ sampler = DistributedSampler(
+     train_dataset, num_replicas=dl_ranks_size, rank=dl_rank, seed=seed, drop_last=drop_last
+ )

if consumed_train_samples > 0:
- sampler = SkipBatchSampler(sampler, skip_batches=consumed_train_samples, dp_size=dp_size)
+ sampler = SkipBatchSampler(sampler, skip_batches=consumed_train_samples, dp_size=dl_ranks_size)

return sampler

@@ -476,12 +483,16 @@ def get_train_dataloader(
parallel_context=parallel_context,
)

# Compute size and rank of dataloader workers
dl_ranks_size = parallel_context.dp_pg.size() * parallel_context.expert_pg.size()
dl_rank = parallel_context.dp_pg.rank() * parallel_context.expert_pg.size() + parallel_context.expert_pg.rank()
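As a side note, here is a minimal sketch (sizes are assumed) of why this flattening gives every (dp_rank, expert_rank) pair its own dataloader shard:

```python
# Illustration with assumed sizes: flatten (dp_rank, expert_rank) into one
# dataloader rank; the mapping is a bijection onto range(dl_ranks_size).
dp_size, expert_size = 4, 2
dl_ranks_size = dp_size * expert_size
dl_ranks = {
    dp_rank * expert_size + expert_rank
    for dp_rank in range(dp_size)
    for expert_rank in range(expert_size)
}
assert dl_ranks == set(range(dl_ranks_size))  # 8 distinct shards, no overlap
```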

# TODO @nouamanetazi: Remove unused columns: https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/trainer.py#L852
# TODO @nouamanetazi: Support torch.utils.data.IterableDataset: https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/trainer.py#L855-L872

train_sampler = _get_train_sampler(
- dp_size=parallel_context.dp_pg.size(),
- dp_rank=dist.get_rank(parallel_context.dp_pg),
+ dl_rank=dl_rank,
+ dl_ranks_size=dl_ranks_size,
train_dataset=train_dataset,
seed=seed_worker,
use_loop_to_round_batch_size=use_loop_to_round_batch_size,
@@ -498,18 +509,18 @@
drop_last=dataloader_drop_last, # we also drop_last in `clm_process()`
num_workers=dataloader_num_workers,
pin_memory=dataloader_pin_memory,
- worker_init_fn=get_dataloader_worker_init(dp_rank=dist.get_rank(parallel_context.dp_pg)),
+ worker_init_fn=get_dataloader_worker_init(dl_rank=dl_rank),
# TODO @thomasw21: I'm not sure but this doesn't seem to work at all.
# pin_memory_device="cuda",
)


- def get_dataloader_worker_init(dp_rank: int):
+ def get_dataloader_worker_init(dl_rank: int):
"""Creates random states for each worker in order to get different state in each workers"""

def dataloader_worker_init(worker_id):
# Dataloader is TP/PP synced in random states
- seed = 2 ** (1 + worker_id) * 3 ** (1 + dp_rank) % (2**32)
+ seed = 2 ** (1 + worker_id) * 3 ** (1 + dl_rank) % (2**32)
set_random_seed(seed)

return dataloader_worker_init
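A quick sketch (worker and rank counts are assumed) of why this seeding scheme gives every (worker, dataloader rank) pair a distinct random state:

```python
# 2**(1+worker_id) * 3**(1+dl_rank) is unique per pair by unique factorization;
# for typical sizes the product stays well below 2**32, so the modulo is a no-op.
seeds = {
    (worker_id, dl_rank): (2 ** (1 + worker_id) * 3 ** (1 + dl_rank)) % (2**32)
    for worker_id in range(4)
    for dl_rank in range(8)
}
assert len(set(seeds.values())) == len(seeds)  # all distinct for these sizes
```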
5 changes: 4 additions & 1 deletion src/nanotron/logging.py
@@ -275,9 +275,12 @@ def add_scalars_from_list(self, log_entries: List[LogItem], iteration_step: int)

def set_logger_verbosity_format(logging_level: str, parallel_context: ParallelContext):
node_name = os.environ.get("SLURMD_NODENAME")
expert_parallel_log = (
f"|EXP={dist.get_rank(parallel_context.expert_pg)}" if parallel_context.expert_parallel_size > 1 else ""
)
formatter = Formatter(
fmt=f"%(asctime)s [%(levelname)s|DP={dist.get_rank(parallel_context.dp_pg)}|PP={dist.get_rank(parallel_context.pp_pg)}|"
f"TP={dist.get_rank(parallel_context.tp_pg)}{'|' + node_name if node_name else ''}]: %(message)s",
f"TP={dist.get_rank(parallel_context.tp_pg)}{expert_parallel_log}{'|' + node_name if node_name else ''}]: %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
)
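For illustration (the ranks here are made up), the prefix produced by the new format when expert_parallel_size > 1 versus the default of 1:

```python
# With expert_parallel_size == 1 the |EXP=... segment is omitted entirely.
level, dp_rank, pp_rank, tp_rank, exp_rank = "INFO", 0, 1, 2, 1
expert_parallel_log = f"|EXP={exp_rank}"  # "" when expert_parallel_size == 1
print(f"[{level}|DP={dp_rank}|PP={pp_rank}|TP={tp_rank}{expert_parallel_log}]: hello")
# [INFO|DP=0|PP=1|TP=2|EXP=1]: hello
```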
# TODO @thomasw21: `logging.log_levels` returns valid lg log levels
36 changes: 16 additions & 20 deletions src/nanotron/parallel/context.py
@@ -15,10 +15,11 @@ def __init__(
tensor_parallel_size: int,
pipeline_parallel_size: int,
data_parallel_size: int,
expert_parallel_size: int = 1,
backend: DistributedBackend = "nccl",
):
"""Initialize parallel context."""
- num_gpus_per_model = tensor_parallel_size * pipeline_parallel_size
+ num_gpus_per_model = tensor_parallel_size * pipeline_parallel_size * expert_parallel_size
world_size = int(os.environ["WORLD_SIZE"])

assert (
@@ -40,6 +41,7 @@ def __init__(
self.tensor_parallel_size = tensor_parallel_size
self.pipeline_parallel_size = pipeline_parallel_size
self.data_parallel_size = data_parallel_size
self.expert_parallel_size = expert_parallel_size

self._groups = {}

@@ -65,28 +67,24 @@ def _init_parallel_groups(self):
dist.barrier()
world_size = int(os.environ["WORLD_SIZE"])
ranks = np.arange(0, world_size).reshape(
- (self.pipeline_parallel_size, self.data_parallel_size, self.tensor_parallel_size)
+ (
+     self.expert_parallel_size,
+     self.pipeline_parallel_size,
+     self.data_parallel_size,
+     self.tensor_parallel_size,
+ )
)
self.world_ranks_to_pg = {}

# Relevant process groups containing the current rank
- self.tp_pg = self.create_new_group(
-     ranks.reshape((self.pipeline_parallel_size * self.data_parallel_size, self.tensor_parallel_size))
- )
- self.dp_pg = self.create_new_group(
-     ranks.transpose((0, 2, 1)).reshape(
-         (self.pipeline_parallel_size * self.tensor_parallel_size, self.data_parallel_size)
-     )
- )
- self.pp_pg = self.create_new_group(
-     ranks.transpose((2, 1, 0)).reshape(
-         (self.tensor_parallel_size * self.data_parallel_size, self.pipeline_parallel_size)
-     )
- )
+ self.tp_pg = self.create_new_group(ranks.transpose((0, 1, 2, 3)).reshape((-1, self.tensor_parallel_size)))
+ self.dp_pg = self.create_new_group(ranks.transpose((3, 0, 1, 2)).reshape((-1, self.data_parallel_size)))
+ self.pp_pg = self.create_new_group(ranks.transpose((2, 3, 0, 1)).reshape((-1, self.pipeline_parallel_size)))
+ self.expert_pg = self.create_new_group(ranks.transpose((1, 2, 3, 0)).reshape((-1, self.expert_parallel_size)))

# model parallel group = combination of tp and pp for a given dp rank
self.mp_pg = self.create_new_group(
- [ranks[:, dp_rank, :].reshape(-1) for dp_rank in range(self.data_parallel_size)]
+ [ranks[:, :, dp_rank, :].reshape(-1) for dp_rank in range(self.data_parallel_size)]
)

self.world_rank_matrix: np.ndarray = ranks
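A standalone sketch (small assumed sizes, plain NumPy, no torch.distributed) of the new 4D rank grid and how the transpose/reshape calls slice it into groups that vary along exactly one axis:

```python
import numpy as np

ep, pp, dp, tp = 2, 1, 2, 2  # assumed sizes; axes are (expert, pipeline, data, tensor)
ranks = np.arange(ep * pp * dp * tp).reshape((ep, pp, dp, tp))

# Each row below is one process group: ranks that differ only along that axis.
tp_groups = ranks.reshape((-1, tp))                          # [[0 1] [2 3] [4 5] [6 7]]
dp_groups = ranks.transpose((3, 0, 1, 2)).reshape((-1, dp))  # [[0 2] [4 6] [1 3] [5 7]]
ep_groups = ranks.transpose((1, 2, 3, 0)).reshape((-1, ep))  # [[0 4] [1 5] [2 6] [3 7]]
```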
@@ -120,10 +118,8 @@ def set_device(self):
torch.cuda.set_device(torch.cuda.device(device_id))

def get_3d_ranks(self, world_rank: int) -> Tuple[int, int, int]:
- pp_rank = (world_rank // (self.tp_pg.size() * self.dp_pg.size())) % self.pp_pg.size()
- dp_rank = (world_rank // self.tp_pg.size()) % self.dp_pg.size()
- tp_rank = world_rank % self.tp_pg.size()
- return (pp_rank, dp_rank, tp_rank)
+ # return coordinates in world_rank_matrix without expert_parallel_rank
+ return tuple(i.item() for i in np.where(self.world_rank_matrix == world_rank))[-3:]
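A matching sketch (same assumed grid) of the lookup-based get_3d_ranks: np.where returns the (expert, pp, dp, tp) coordinates, and the leading expert coordinate is dropped to keep the 3-tuple callers expect:

```python
import numpy as np

ep, pp, dp, tp = 2, 1, 2, 2  # assumed sizes
ranks = np.arange(ep * pp * dp * tp).reshape((ep, pp, dp, tp))
world_rank = 6
coords = tuple(i.item() for i in np.where(ranks == world_rank))  # (1, 0, 1, 0)
pp_rank, dp_rank, tp_rank = coords[-3:]  # expert coordinate is dropped
assert (pp_rank, dp_rank, tp_rank) == (0, 1, 0)
```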

def destroy(self):
if not dist.is_initialized():
9 changes: 6 additions & 3 deletions src/nanotron/serialize/optimizer.py
@@ -25,9 +25,9 @@
# TODO(xrsrke): take rank instead of parallel_context
def optimizer_filename(parallel_context: ParallelContext, is_zero: bool):
if is_zero is True:
return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}.pt"
return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
else:
return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}.pt"
return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"


def lr_scheduler_filename():
@@ -58,13 +58,15 @@ def save_optimizer(
tp_size = parallel_context.tp_pg.size()
pp_size = parallel_context.pp_pg.size()
dp_size = parallel_context.dp_pg.size()
expert_parallel_size = parallel_context.expert_parallel_size

config = {
"type": str(optimizer.__class__.__name__),
"parallelism": {
"tp_size": str(tp_size),
"dp_size": str(dp_size),
"pp_size": str(pp_size),
"expert_parallel_size": str(expert_parallel_size),
},
"configs": {},
}
@@ -140,6 +142,7 @@ def load_optimizer(
ckp_pp_size = ckp_optimizer_config["parallelism"]["pp_size"]
ckp_tp_size = ckp_optimizer_config["parallelism"]["tp_size"]
ckp_dp_size = ckp_optimizer_config["parallelism"]["dp_size"]
ckpt_expert_parallel_size = ckp_optimizer_config["parallelism"]["expert_parallel_size"]

if int(ckp_tp_size) != int(parallel_context.tp_pg.size()):
assert (
@@ -159,7 +162,7 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
# across data parallel dimension, before merging the shards across tensor parallel dimension
shard_paths = list(
root_folder.glob(
f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_dp-*-of-{ckp_dp_size}_tp-*-of-{ckp_tp_size}.pt"
f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_dp-*-of-{ckp_dp_size}_tp-*-of-{ckp_tp_size}-exp-*-of-{ckpt_expert_parallel_size}.pt"
)
)
ckp_sharded_optim_states = merge_dp_shard_in_zero1_optimizer(
4 changes: 3 additions & 1 deletion src/nanotron/trainer.py
@@ -120,6 +120,7 @@ def __init__(
tensor_parallel_size=self.config.parallelism.tp,
pipeline_parallel_size=self.config.parallelism.pp,
data_parallel_size=self.config.parallelism.dp,
expert_parallel_size=self.config.parallelism.expert_parallel_size,
)

self.pre_init()
@@ -198,7 +199,7 @@ def __init__(

# Setup tensorboard write and log writers on output rank
self.logger_ranks = self.parallel_context.world_rank_matrix[
- self.unwrapped_model.output_pp_rank, 0, 0
+ 0, self.unwrapped_model.output_pp_rank, 0, 0
].flatten()
self.loggerwriter = self.setup_log_writers()

@@ -744,6 +745,7 @@ def mark_tied_parameters(
target,
(
parallel_context.world_rank_matrix[
dist.get_rank(parallel_context.expert_pg),
get_pp_rank_of(target, module=model),
dist.get_rank(parallel_context.dp_pg),
dist.get_rank(parallel_context.tp_pg),
1 change: 1 addition & 0 deletions tests/test_parameters_accumulate_gradient_in_fp32.py
@@ -344,6 +344,7 @@ def _test_tied_weights_sync_with_grad_accum_in_fp32(
target,
(
parallel_context.world_rank_matrix[
dist.get_rank(parallel_context.expert_pg),
get_pp_rank_of(target, module=mdl),
dist.get_rank(parallel_context.dp_pg),
dist.get_rank(parallel_context.tp_pg),