diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py index 2d7d67ca..52bd70fc 100644 --- a/configs/7B_internlm2.py +++ b/configs/7B_internlm2.py @@ -43,11 +43,9 @@ TRAIN_FOLDER = None VALID_FOLDER = None # "/path/to/dataset" data = dict( - # tokenizer_path = "/mnt/petrelfs/share_data/caizheng/pretrained/internlm2-chat-7b", - # type="streaming", seq_len=SEQ_LEN, # micro_num means the number of micro_batch contained in one gradient update - micro_num=8, + micro_num=4, # packed_length = micro_bsz * SEQ_LEN micro_bsz=1, # defaults to the value of micro_num @@ -92,7 +90,7 @@ hybrid_zero_optimizer = dict( # Enable low_level_optimzer overlap_communication - overlap_sync_grad=False, + overlap_sync_grad=True, overlap_sync_param=False, # bucket size for nccl communication params reduce_bucket_size=512 * 1024 * 1024, @@ -182,9 +180,9 @@ """ parallel = dict( zero1=dict(size=-1), - tensor=dict(size=1, mode="mtp"), - pipeline=dict(size=4, interleaved_overlap=True, mode='zbv'), - weight=dict(size=1, overlap=True), + tensor=dict(size=2, mode="isp"), + pipeline=dict(size=1, interleaved_overlap=True), + weight=dict(size=2, overlap=True), ) cudnn_deterministic = False diff --git a/doc/code-docs/locales/en/LC_MESSAGES/parallel.po b/doc/code-docs/locales/en/LC_MESSAGES/parallel.po index b948e4f9..df6ff3f7 100644 --- a/doc/code-docs/locales/en/LC_MESSAGES/parallel.po +++ b/doc/code-docs/locales/en/LC_MESSAGES/parallel.po @@ -563,8 +563,8 @@ msgstr "" msgid "返回类型" msgstr "Return type" -#: internlm.core.scheduler.pipeline_scheduler.InterleavedPipelineScheduler.forward_backward_step:19 -#: internlm.core.scheduler.pipeline_scheduler.PipelineScheduler.forward_backward_step:19 +#: internlm.core.scheduler.pipeline_scheduler_1f1b.InterleavedPipelineScheduler.forward_backward_step:19 +#: internlm.core.scheduler.pipeline_scheduler_1f1b.PipelineScheduler.forward_backward_step:19 #: of msgid "Tuple[:class:`torch.Tensor`]" msgstr "" @@ -579,11 +579,11 @@ msgstr "" "To use interleaved pipeline scheduler, users need to set " "``model.num_chunks > 1`` in the config file." -#: internlm.core.scheduler.pipeline_scheduler.InterleavedPipelineScheduler:1 of +#: internlm.core.scheduler.pipeline_scheduler_1f1b.InterleavedPipelineScheduler:1 of msgid "Interleaved Pipeline Scheduler." msgstr "" -#: internlm.core.scheduler.pipeline_scheduler.InterleavedPipelineScheduler.forward_backward_step:1 +#: internlm.core.scheduler.pipeline_scheduler_1f1b.InterleavedPipelineScheduler.forward_backward_step:1 #: of msgid "" "Run interleaved 1F1B schedule (model split into model chunks), with " diff --git a/doc/code-docs/source/parallel.rst b/doc/code-docs/source/parallel.rst index 1ad9ff63..3f893af1 100644 --- a/doc/code-docs/source/parallel.rst +++ b/doc/code-docs/source/parallel.rst @@ -137,14 +137,14 @@ InternEvo 在流水线并行中使用 `1F1B 1`` 。 -.. autoclass:: internlm.core.scheduler.pipeline_scheduler.InterleavedPipelineScheduler +.. 
autoclass:: internlm.core.scheduler.pipeline_scheduler_1f1b.InterleavedPipelineScheduler :members: 值得注意的是,在使用交错式流水线调度器时可启用通信优化功能,即在 1F1B 阶段启用异步通信,以充分利用上行/下行带宽并实现通信与计算重叠。 diff --git a/doc/en/structure.md b/doc/en/structure.md index 7a37ef30..8d9d726c 100644 --- a/doc/en/structure.md +++ b/doc/en/structure.md @@ -12,7 +12,8 @@ The system code file structure is shown below: │ │ │ └── process_group_initializer.py │ │ ├── scheduler # Scheduling module, which manages schedulers for parallel training, including non-pipeline and pipeline parallel schedulers │ │ │ ├── no_pipeline_scheduler.py -│ │ │ └── pipeline_scheduler.py +│ │ │ ├── pipeline_scheduler_1f1b.py +│ │ │ └── pipeline_scheduler_zb.py │ │ ├── engine.py # Responsible for managing the training and evaluation process of the model │ │ └── trainer.py # Responsible for managing the training engine and scheduler │ ├── data # Data module, responsible for managing dataset generation and processing diff --git a/doc/structure.md b/doc/structure.md index 7c0c6f13..9de1016b 100644 --- a/doc/structure.md +++ b/doc/structure.md @@ -12,7 +12,8 @@ │ │ │ └── process_group_initializer.py │ │ ├── scheduler # 调度模块,管理并行训练的调度器,包括非流水线并行调度器和流水线并行调度器 │ │ │ ├── no_pipeline_scheduler.py -│ │ │ └── pipeline_scheduler.py +│ │ │ ├── pipeline_scheduler_1f1b.py +│ │ │ └── pipeline_scheduler_zb.py │ │ ├── engine.py # 负责管理模型的训练和评估过程 │ │ └── trainer.py # 负责管理训练引擎和调度器 │ ├── data # 数据模块,负责管理数据集生成和处理 diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 07a046ee..989b1c00 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -299,7 +299,7 @@ def is_rank_for_log(self): is_log_rank = is_log_rank and self.is_last_rank(ParallelMode.PIPELINE) else: is_log_rank = is_log_rank and self.is_first_rank(ParallelMode.PIPELINE) - + return is_log_rank def is_last_rank(self, parallel_mode: ParallelMode): diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py index 330688f0..0b8bc490 100644 --- a/internlm/core/parallel/comm/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -691,7 +691,7 @@ def after_backward(self, scheduler, inputs_grad) -> None: # pylint: disable=W06 # accumulate left gradients in last bucket after backward. 
if self._isp_communicator and self._isp_communicator.overlap: self._zero_optim.accumulate_left_grads_after_backward() - + if not self._zero_optim.skip_grad_reduce: self._zero_optim.reduce_left_grads_after_backward() diff --git a/internlm/core/parallel/shard.py b/internlm/core/parallel/shard.py index ea84e428..34feda73 100644 --- a/internlm/core/parallel/shard.py +++ b/internlm/core/parallel/shard.py @@ -185,12 +185,12 @@ def partition_uniform(num_items: int, pipeline_parallel_size: int, num_chunks: i left = pipeline_parallel_size - partition_items % pipeline_parallel_size if chunk_size == 0: raise ValueError("Some nodes in Pipeline have no requests") - + if gpc.config.parallel["pipeline"]["mode"] == "ZBV" and idx == 1: - for p in range(pipeline_parallel_size-1, -1, -1): + for p in range(pipeline_parallel_size - 1, -1, -1): st = base_idx base_idx += chunk_size + ((pipeline_parallel_size - p - 1) >= left) - parts[p].append((st, base_idx)) + parts[p].append((st, base_idx)) else: for p in range(pipeline_parallel_size): st = base_idx @@ -236,6 +236,7 @@ def pipeline_parallel_sharding_wrapper( kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0 kwargs["device"] = device kwargs["start_layer_idx"] = start + chunk = model_builder(**kwargs).to(device) setattr(chunk, "first_layer", start) setattr(chunk, "last_layer", end) diff --git a/internlm/core/scheduler/__init__.py b/internlm/core/scheduler/__init__.py index 5c17dab6..042fa596 100644 --- a/internlm/core/scheduler/__init__.py +++ b/internlm/core/scheduler/__init__.py @@ -1,9 +1,6 @@ from .base_scheduler import BaseScheduler from .no_pipeline_scheduler import NonPipelineScheduler -from .pipeline_scheduler_1f1b import ( - InterleavedPipelineScheduler, - PipelineScheduler, -) +from .pipeline_scheduler_1f1b import InterleavedPipelineScheduler, PipelineScheduler from .pipeline_scheduler_zb import ( ZeroBubblePipelineScheduler, ZeroBubblePipelineVShapeScheduler, diff --git a/internlm/core/scheduler/comm/__init__.py b/internlm/core/scheduler/comm/__init__.py index 8cda2b51..0037c097 100644 --- a/internlm/core/scheduler/comm/__init__.py +++ b/internlm/core/scheduler/comm/__init__.py @@ -1,5 +1,6 @@ from .p2p import ( AsynCommunicator, + fused_send_recv_tensor, recv_backward, recv_forward, send_backward, @@ -9,7 +10,6 @@ send_forward_backward_recv_forward_backward, send_forward_recv_backward, send_forward_recv_forward, - fused_send_recv_tensor, ) from .utils import recv_obj_meta, send_obj_meta diff --git a/internlm/core/scheduler/comm/p2p.py b/internlm/core/scheduler/comm/p2p.py index 7872c0e0..c71d746e 100644 --- a/internlm/core/scheduler/comm/p2p.py +++ b/internlm/core/scheduler/comm/p2p.py @@ -172,7 +172,7 @@ def _communicate( for req in reqs: req.wait() # To protect against race condition when using batch_isend_irecv(). - internlm_accelerator.synchronize() + # internlm_accelerator.synchronize() if recv_prev and recv_prev_split: if isinstance(tensor_recv_prev, torch.Tensor): @@ -279,15 +279,15 @@ def _communicate_async( if len(ops) > 0: reqs = dist.batch_isend_irecv(ops) - + # return and do other things yield - + if len(ops) > 0: - for req in reqs: + for req in reqs: # pylint: disable=E0601 req.wait() # To protect against race condition when using batch_isend_irecv(). 
- internlm_accelerator.synchronize() + # internlm_accelerator.synchronize() if recv_prev and recv_prev_split: if isinstance(tensor_recv_prev, torch.Tensor): @@ -306,14 +306,7 @@ def _communicate_async( tensor_recv_next[index] = ( gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_() ) - - # if tensor_recv_prev is not None and tensor_recv_next is None: - # yield tensor_recv_prev - # elif tensor_recv_next is not None and tensor_recv_prev is None: - # yield tensor_recv_next - # elif tensor_recv_next is None and tensor_recv_prev is None: - # yield None - # else: + yield tensor_recv_prev, tensor_recv_next @@ -546,6 +539,7 @@ def send_forward_backward_recv_forward_backward( ) return input_tensor, output_tensor_grad + def fused_send_recv_tensor( object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None, object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None, @@ -608,7 +602,7 @@ def __init__( dtype: torch.dtype = None, scatter_gather_tensors: bool = False, ) -> None: - self._need_receive = recv_prev_shape is not None or recv_next_shape is not None + self._need_receive = recv_prev_shape is not None or recv_next_shape is not None self._coroutine = _communicate_async( object_send_prev=object_send_prev, object_send_next=object_send_next, diff --git a/internlm/core/scheduler/pipeline_scheduler_1f1b.py b/internlm/core/scheduler/pipeline_scheduler_1f1b.py index a74e3ca6..4864c77f 100644 --- a/internlm/core/scheduler/pipeline_scheduler_1f1b.py +++ b/internlm/core/scheduler/pipeline_scheduler_1f1b.py @@ -180,7 +180,6 @@ def _call_engine(engine, data): # pylint: disable=W0237 elif isinstance(data, (list, tuple)): return engine(*data) elif isinstance(data, dict): - # print(f"data: {data}, {gpc.get_global_rank()}", flush=True) stage_output = data.pop("stage_output", None) if stage_output is None: return engine(**data) @@ -745,7 +744,6 @@ def __init__( assert ( isinstance(num_chunks, int) and num_chunks > 0 ), f"expected num_chunks to be an integer and larger than 0, but got {num_chunks}" - print(f"InterleavedPipelineScheduler", flush=True) super().__init__( num_microbatches, @@ -841,13 +839,10 @@ def _forward_step(self, engine, chunk_id, input_obj=None): if gpc.is_pipeline_first_stage() and len(self._input_objs[chunk_id]) == len(self._output_objs[chunk_id]): self._input_objs[chunk_id].append(None) - + if input_obj is None: input_obj = self._input_objs[chunk_id][-1] - - if input_obj is not None: - assert input_obj.requires_grad == True - + if not gpc.is_pipeline_first_stage(): assert input_obj is not None, f"{gpc.get_global_rank()} input is None" micro_batch_data = self.load_micro_batch(chunk_id) @@ -892,7 +887,7 @@ def _forward_step(self, engine, chunk_id, input_obj=None): self._output_objs[chunk_id].append(output_obj) self._moe_losses[chunk_id].append(moe_loss) - + assert output_obj is not None, f"{gpc.get_global_rank()} chunk{chunk_id} output is None" return output_obj @@ -929,7 +924,6 @@ def _get_chunk_by_microbatch(self, step_id: int, backward: bool = False) -> int: """Helper method to get the model chunk ID given the iteration number.""" microbatch_id_in_group = step_id % (self._pp_size * self._num_chunks) chunk_id = microbatch_id_in_group // self._pp_size - if backward: chunk_id = self._num_chunks - chunk_id - 1 @@ -942,7 +936,7 @@ def _get_current_microbatch_id(self, step_id: int) -> int: # microbatch_id: 1 2 3 4 1 2 3 4 5 6 7 8 5 6 7 8 num_microbatch_group = step_id // (self._pp_size * self._num_chunks) step_id_in_group = step_id % 
(self._pp_size * self._num_chunks) - + microbatch_id = num_microbatch_group * self._pp_size + step_id_in_group % self._pp_size return microbatch_id @@ -1414,7 +1408,6 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo output, label = pack_return_tensors(self._return_tensors) else: output, label = (None, None) - if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE)) @@ -1430,4 +1423,4 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo if hasattr(gpc.config.model, "num_experts"): return output, label, accum_loss, accum_moe_loss else: - return output, label, accum_loss \ No newline at end of file + return output, label, accum_loss diff --git a/internlm/core/scheduler/pipeline_scheduler_zb.py b/internlm/core/scheduler/pipeline_scheduler_zb.py index d9ef4db2..75cf1844 100644 --- a/internlm/core/scheduler/pipeline_scheduler_zb.py +++ b/internlm/core/scheduler/pipeline_scheduler_zb.py @@ -1,8 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine - import queue from typing import Callable, List, Optional, Tuple, Union @@ -14,14 +12,15 @@ from internlm.core.context import global_context as gpc from internlm.core.engine import Engine from internlm.core.scheduler import comm -from internlm.utils.common import ( - SchedulerHook, - get_current_device, -) +from internlm.utils.common import SchedulerHook, get_current_device from internlm.utils.logger import get_logger from internlm.utils.parallel import is_using_isp -from .pipeline_scheduler_1f1b import PipelineScheduler, InterleavedPipelineScheduler, pack_return_tensors +from .pipeline_scheduler_1f1b import ( + InterleavedPipelineScheduler, + PipelineScheduler, + pack_return_tensors, +) logger = get_logger(__file__) @@ -37,12 +36,11 @@ class WeightGradStore: pp_mode = None optim = None temp = [] - - + @classmethod def set_pp_mode(cls, mode): cls.pp_mode = mode - + @classmethod def set_optim(cls, optim): cls.optim = optim @@ -76,8 +74,6 @@ def pop(cls): for weight, bias, input_tensor, grad_output, has_d_bias, grad_compute_func, *args in stored_w_grad_computation: assert weight.requires_grad grad_weight, grad_bias = grad_compute_func(input_tensor, grad_output, has_d_bias) - - assert not torch.isnan(grad_weight).any(), f"before {gpc.get_global_rank()}, {getattr(grad_weight, 'debug_name', 'no_name')}, {getattr(grad_weight, 'debug_name2', 'no_name')}, {module}, {weight.shape}, {weight.name}, {weight}, {torch.isnan(input_tensor).any()}, {torch.isnan(grad_output).any()}" if is_using_isp(): isp_grad_hook = args[0] @@ -87,22 +83,20 @@ def pop(cls): if grad_bias is not None: grad_bias, handle_bias = isp_grad_hook(grad_bias, async_op=True, is_bias=True, module=module) handle_bias.wait() - - assert not torch.isnan(grad_weight).any(), f"after {gpc.get_global_rank()}, {getattr(grad_weight, 'debug_name', 'no_name')}, {getattr(grad_weight, 'debug_name2', 'no_name')}, {module}, {weight.shape}, {weight.name}, {weight}, {torch.isnan(input_tensor).any()}, {torch.isnan(grad_output).any()}" - + # Gradient Accumulation weight.grad = weight.grad.data + grad_weight if weight.grad is not None else grad_weight if has_d_bias: bias.grad = bias.grad.data + grad_bias if bias.grad is not None else grad_bias - if weight in cls._hooks: + # overlap hook + if weight in cls._hooks: for hook in cls._hooks[weight]: hook() if 
has_d_bias: - assert bias in cls._hooks for hook in cls._hooks[bias]: - hook() - + hook() + @classmethod def register_hook(cls, param, hooks): cls._hooks[param] = hooks @@ -368,6 +362,18 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T class ZeroBubblePipelineVShapeScheduler(InterleavedPipelineScheduler): """ ZB-V Scheduler. + + Args: + num_microbatches (int): The number of microbatches. + num_chunks (int): The number of model chunks. + dtype (torch.dtype, optional): The data type of the tensors. Default is torch.float. + data_process_func (Callable, optional): + The preprocessing function which receives a batch of data, and it will be executed in `load_batch`. + tensor_shape (torch.Size, optional): Specified shape in pipeline communication. + scatter_gather_tensors (bool, optional): + If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. + scheduler_hooks (List[SchedulerHook], optional): List of scheduler hooks. Default is None. + optimizer (Optimizer): The optimizer to do param update. """ def __init__( @@ -396,19 +402,17 @@ def __init__( If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. scheduler_hooks (List[SchedulerHook], optional): List of scheduler hooks. Default is None. """ - - print("debugg ZeroBubblePipelineVShapeScheduler", flush=True) - + assert ( isinstance(num_chunks, int) and num_chunks == 2 ), f"expect num_chunks to be an integer and equal to 2 for ZBV, but got {num_chunks}." - - assert ( - num_microbatches >= 2 * gpc.get_world_size(ParallelMode.PIPELINE) - ), f"For ZBV, num_microbatches must be greater than or equal to twice pp size." + + assert num_microbatches >= 2 * gpc.get_world_size( + ParallelMode.PIPELINE + ), "For ZBV, num_microbatches must be greater than or equal to twice pp size." assert gpc.v_shape - + super().__init__( num_microbatches, num_chunks=num_chunks, @@ -418,30 +422,23 @@ def __init__( scatter_gather_tensors=scatter_gather_tensors, scheduler_hooks=scheduler_hooks, ) - + del self._run_1f1b_loop - - self._special_chunk0_forward = True + WeightGradStore.set_pp_mode("ZBV") WeightGradStore.set_optim(optimizer) - - self.chunk1_need_recv_prev_chunk1_grad = True - - self._micro_step = [0, 0] - self.map_input_output = [{}, {}] + + self._special_chunk0_forward = True + self._chunk1_need_recv_prev_chunk1_grad = True self._backward_step_num = [0, 0] self._num_microbatches = num_microbatches - + def _clear_state(self) -> None: super()._clear_state() self._special_chunk0_forward = True - self.chunk1_need_recv_prev_chunk1_grad = True - - self._micro_step = [0, 0] - self.map_input_output = [{}, {}] + self._chunk1_need_recv_prev_chunk1_grad = True self._backward_step_num = [0, 0] - - + def _backward_step(self, engine, input_obj, output_obj, output_obj_grad, skip_grad_sync=True, moe_loss=None): """ Backward step through the passed-in output tensor. If it is the last stage, the @@ -462,7 +459,7 @@ def _backward_step(self, engine, input_obj, output_obj, output_obj_grad, skip_gr # Retain the grad on the input_obj. if input_obj is not None: - assert input_obj.requires_grad == True, f"{gpc.get_global_rank()}" + assert input_obj.requires_grad if isinstance(input_obj, torch.Tensor): input_obj.retain_grad() else: @@ -470,8 +467,6 @@ def _backward_step(self, engine, input_obj, output_obj, output_obj_grad, skip_gr if in_tensor is not None: in_tensor.retain_grad() - # Backward pass. 
- # engine.optimizer.skip_grad_reduce # Only the last microbatch does syncing grad. engine.optimizer.skip_grad_reduce = skip_grad_sync self._call_hooks("before_backward", output_obj, output_obj_grad) @@ -480,12 +475,7 @@ def _backward_step(self, engine, input_obj, output_obj, output_obj_grad, skip_gr if output_obj_grad is None: engine.backward(output_obj) else: - try: - engine.backward_by_grad(output_obj, output_obj_grad) - except Exception as e: - print("rank:", gpc.get_global_rank(), flush=True) - - raise e + engine.backward_by_grad(output_obj, output_obj_grad) else: if output_obj_grad is None: engine.backward(output_obj + moe_loss) @@ -496,22 +486,20 @@ def _backward_step(self, engine, input_obj, output_obj, output_obj_grad, skip_gr # [output_obj_grad, 1], Because moe_loss have no relation with subsequent # layer, we set it to None (will be ragarded as 1). engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None]) - + # Collect the grad of the input_obj. input_obj_grad = None if input_obj is not None: - assert input_obj.grad is not None, f"{gpc.get_global_rank()}" + assert input_obj.grad is not None if isinstance(input_obj, torch.Tensor): input_obj_grad = input_obj.grad else: input_obj_grad = [] for in_tensor in input_obj: input_obj_grad.append(in_tensor.grad) - else: - assert gpc.is_pipeline_first_stage(), f"{gpc.get_global_rank()}" return input_obj_grad - + def _schedule_backward(self, engine, chunk_id): """ Backward step for passed-in model. If it is the last stage, the input tensor @@ -527,7 +515,7 @@ def _schedule_backward(self, engine, chunk_id): Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: input tensor gradient. """ gpc.set_virtual_pipeline_parallel_rank(chunk_id) - + self._backward_step_num[chunk_id] += 1 if self._backward_step_num[chunk_id] == self._num_microbatches: skip_grad_sync = False @@ -536,66 +524,48 @@ def _schedule_backward(self, engine, chunk_id): if gpc.is_pipeline_last_stage() and len(self._output_obj_grads[chunk_id]) == 0: self._output_obj_grads[chunk_id].append(None) - - assert len(self._input_objs[chunk_id]) == len(self._output_objs[chunk_id]), f"{gpc.get_global_rank()} {chunk_id} {len(self._input_objs[chunk_id])} {len(self._output_objs[chunk_id])}" input_obj = self._input_objs[chunk_id].pop(0) output_obj = self._output_objs[chunk_id].pop(0) output_obj_grad = self._output_obj_grads[chunk_id].pop(0) moe_loss = self._moe_losses[chunk_id].pop(0) - - if input_obj is not None: - assert self.map_input_output[chunk_id][id(input_obj)] == id(output_obj), f"{gpc.get_global_rank()}" - assert input_obj.requires_grad == True - + if not gpc.is_pipeline_last_stage(): assert output_obj_grad is not None if not gpc.is_pipeline_first_stage(): assert input_obj is not None - # import pdb; pdb.set_trace() + input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad, skip_grad_sync, moe_loss) - if not gpc.is_pipeline_first_stage(): - assert input_obj_grad is not None, f"{gpc.get_global_rank()}" - + WeightGradStore.flush() return input_obj_grad - + def _schedule_1f1b_F(self, engine, chunk_id): - self._micro_step[chunk_id] += 1 output_obj = self._forward_step(engine, chunk_id) - if self._input_objs[chunk_id][-1] is not None: - self.map_input_output[chunk_id][id(self._input_objs[chunk_id][-1])] = id(output_obj) - + object_send_next = None - object_send_prev = None + object_send_prev = None recv_next_shape = None recv_prev_shape = None - + if chunk_id == 1: if not gpc.is_first_rank(ParallelMode.PIPELINE): object_send_prev = output_obj - 
if self.chunk1_need_recv_prev_chunk1_grad: - recv_prev_shape=self._output_obj_shapes[chunk_id] + if self._chunk1_need_recv_prev_chunk1_grad: + recv_prev_shape = self._output_obj_shapes[chunk_id] else: - self.chunk1_need_recv_prev_chunk1_grad = False + self._chunk1_need_recv_prev_chunk1_grad = False if gpc.is_last_rank(ParallelMode.PIPELINE): # For last rank, chunk0 output does not need to be sent but is directly used for chunk1; - # input_obj = output_obj.clone().detach() - # input_obj.requires_grad = True input_obj = output_obj.clone().detach() input_obj.requires_grad_() - assert input_obj.is_leaf - # assert self.dtype == output_obj.dtype - # input_obj = torch.empty_like(output_obj, requires_grad=True, device=get_current_device(), dtype=self.dtype) - # input_obj.copy_(output_obj.clone().detach()) - assert isinstance(input_obj, torch.Tensor) self._input_objs[1].append(input_obj) else: object_send_next = output_obj recv_next_shape = self._output_obj_shapes[chunk_id] - - # chunk1 send output prev, recv output_grad prev + + # chunk1 send output prev, recv output_grad prev # chunk0 send output next, recv output_grad next tensor_recv_prev, tensor_recv_next = comm.fused_send_recv_tensor( object_send_next=object_send_next, @@ -605,47 +575,41 @@ def _schedule_1f1b_F(self, engine, chunk_id): dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors, ) - - if chunk_id == 1 and not self.chunk1_need_recv_prev_chunk1_grad: + + if chunk_id == 1 and not self._chunk1_need_recv_prev_chunk1_grad: assert tensor_recv_prev is None - + if tensor_recv_prev is not None: self._output_obj_grads[1].append(tensor_recv_prev) - + if tensor_recv_next is not None: self._output_obj_grads[0].append(tensor_recv_next) - - + def _schedule_1f1b_B_W(self, engine, chunk_id, next_unit_chunk_id, need_recv_chunk0_output=True): - + # 1B input_obj_grad = self._schedule_backward(engine, chunk_id) - - + object_send_next = None object_send_prev = None recv_next_shape = None recv_prev_shape = [] chunk0_B_need_recv_prev_chunk0_output = need_recv_chunk0_output - + if chunk_id == 1: if gpc.is_last_rank(ParallelMode.PIPELINE): # For last rank, chunk1 input_grad does not need to be sent but is directly used for chunk0. - # output_grad = input_obj_grad.clone().detach() self._output_obj_grads[0].append(input_obj_grad) else: object_send_next = input_obj_grad - + if next_unit_chunk_id == 1: if gpc.is_last_rank(ParallelMode.PIPELINE): assert False, "The last pp rank can never have two consecutive unit1 of the same chunk." recv_next_shape = self._input_obj_shapes[next_unit_chunk_id] - # else: - # if not (gpc.is_first_rank(ParallelMode.PIPELINE) or self._special_chunk0_forward): - # recv_prev_shape = self._input_obj_shapes[next_unit_chunk_id] else: assert next_unit_chunk_id != 0, "There will never be two consecutive chunk0 unit1." 
- + if not gpc.is_first_rank(ParallelMode.PIPELINE): object_send_prev = input_obj_grad # pre receive chunk1 grad @@ -653,23 +617,17 @@ def _schedule_1f1b_B_W(self, engine, chunk_id, next_unit_chunk_id, need_recv_chu # pre receive chunk0 input if chunk0_B_need_recv_prev_chunk0_output: recv_prev_shape.append(self._input_obj_shapes[0]) - + if not gpc.is_last_rank(ParallelMode.PIPELINE): recv_next_shape = self._input_obj_shapes[next_unit_chunk_id] - - if recv_prev_shape == []: + + if len(recv_prev_shape) == 0: recv_prev_shape = None - - + # chunk1 send input_grad next, chunk0 send input_grad prev - # if next_unit_chunk_id == 1, recv input_obj next - # if next_unit_chunk_id == 0, recv input_obj prev - - # input_shape = recv_next_shape if recv_next_shape is not None else recv_prev_shape - # tensor_to_send = object_send_next if object_send_next is not None else object_send_prev - # send_next = True if object_send_next is not None else False - # recv_next = True if recv_next_shape is not None else False - + # if chunk_id == 1 and next_unit_chunk_id == 1, recv chunk1 input next + # if chunk_id == 0 and next_unit_chunk_id == 1, pre-recv chunk1 grad recv; + # pre-recv chunk0 input prev and recv chunk1 input next async_communicator = comm.AsynCommunicator( object_send_prev=object_send_prev, object_send_next=object_send_next, @@ -679,17 +637,15 @@ def _schedule_1f1b_B_W(self, engine, chunk_id, next_unit_chunk_id, need_recv_chu scatter_gather_tensors=self.scatter_gather_tensors, ) async_communicator.start() - - + # 1W WeightGradStore.pop() self._call_hooks("after_backward", input_obj_grad) - + tensor_recv_prev, tensor_recv_next = async_communicator.wait_and_receive() - + # for the special case, input_obj has already been received and appended at the end of warmup. 
if next_unit_chunk_id == 0 and self._special_chunk0_forward: - assert tensor_recv_prev is None and tensor_recv_next is None self._special_chunk0_forward = False else: if chunk_id == 0: @@ -700,92 +656,43 @@ def _schedule_1f1b_B_W(self, engine, chunk_id, next_unit_chunk_id, need_recv_chu if gpc.is_first_rank(ParallelMode.PIPELINE): # first_rank only receive chunk1 input from next rank - assert tensor_recv_prev is None and tensor_recv_next is not None - assert isinstance(tensor_recv_next, torch.Tensor) self._input_objs[1].append(tensor_recv_next) elif gpc.is_last_rank(ParallelMode.PIPELINE): # For last rank, chunk1 input does not need to be received - if chunk0_B_need_recv_prev_chunk0_output: - assert isinstance(tensor_recv_prev, List) and len(tensor_recv_prev) == 2 - else: - assert len(tensor_recv_prev) == 1 - - assert isinstance(tensor_recv_prev[0], torch.Tensor) self._output_obj_grads[1].append(tensor_recv_prev[0]) if chunk0_B_need_recv_prev_chunk0_output: - assert isinstance(tensor_recv_prev[1], torch.Tensor) self._input_objs[0].append(tensor_recv_prev[1]) else: - if chunk0_B_need_recv_prev_chunk0_output: - assert isinstance(tensor_recv_prev, List) and len(tensor_recv_prev) == 2 - else: - assert len(tensor_recv_prev) == 1 - - assert isinstance(tensor_recv_next, torch.Tensor) self._output_obj_grads[1].append(tensor_recv_prev[0]) if chunk0_B_need_recv_prev_chunk0_output: - assert isinstance(tensor_recv_prev[1], torch.Tensor) self._input_objs[0].append(tensor_recv_prev[1]) - assert isinstance(tensor_recv_next, torch.Tensor) self._input_objs[1].append(tensor_recv_next) - # elif chunk_id == 1: else: if next_unit_chunk_id == 1: - if gpc.is_last_rank(ParallelMode.PIPELINE): - assert False - assert isinstance(tensor_recv_next, torch.Tensor) and tensor_recv_prev is None self._input_objs[1].append(tensor_recv_next) - else: - assert tensor_recv_prev is None and tensor_recv_next is None - # else: - # # chunk0 and chunk0_B_need_recv_prev_chunk0_output==False - # # stage1 last chunk0 or stage2 - # assert next_unit_chunk_id == 1 - # if gpc.is_first_rank(ParallelMode.PIPELINE): - # assert isinstance(recv_tensor, torch.Tensor) - # self._input_objs[1].append(recv_tensor) - # elif gpc.is_last_rank(ParallelMode.PIPELINE): - # assert isinstance(recv_tensor, List) and len(recv_tensor) == 1 - # self._output_obj_grads[1].append(recv_tensor[0]) - # else: - # assert isinstance(recv_tensor, tuple), f"{gpc.get_global_rank()}, {self._micro_step[chunk_id]}, {recv_tensor}, {recv_next_shape}" - # tensor_recv_prev, tensor_recv_next = recv_tensor - # assert isinstance(tensor_recv_prev, List) and len(tensor_recv_prev) == 1 - # self._output_obj_grads[1].append(tensor_recv_prev[0]) - # self._input_objs[1].append(tensor_recv_next) - - - # if not (next_unit_chunk_id == 1 and gpc.is_last_rank(ParallelMode.PIPELINE)): - # if not (next_unit_chunk_id == 0 and gpc.is_first_rank(ParallelMode.PIPELINE)): - # assert input_obj is not None, f"{gpc.get_global_rank()} chunk{chunk_id} next_unit_chunk_id{next_unit_chunk_id} receive none input BW" - # else: - # assert input_obj is None - # self._input_objs[next_unit_chunk_id].append(input_obj) - - + def _1f1b_unit_1(self, engine, chunk_id, next_unit_chunk_id, need_recv_chunk0_output): """ unit1 consists of: 1F + 1B + 1W, all are chunk0 or chunk1 """ # 1F self._schedule_1f1b_F(engine, chunk_id) - + # 1B + 1W self._schedule_1f1b_B_W(engine, chunk_id, next_unit_chunk_id, need_recv_chunk0_output) - + def _1f1b_unit_2(self, engine, chunk_id): """ unit2 consists of: chunk1 (1F + 1B + 1W) + chunk0 
(1B + 1W) """ - assert chunk_id == 1 assert not gpc.is_last_rank(ParallelMode.PIPELINE) - + # 1F (chunk1) self._schedule_1f1b_F(engine, chunk_id) - + # 1B + 1W (chunk1) input_obj_grad = self._schedule_backward(engine, chunk_id) - + # chunk1 send input_grad next, chunk0 recv output_grad next async_communicator = comm.AsynCommunicator( object_send_next=input_obj_grad, @@ -794,54 +701,42 @@ def _1f1b_unit_2(self, engine, chunk_id): scatter_gather_tensors=self.scatter_gather_tensors, ) async_communicator.start() - + WeightGradStore.pop() self._call_hooks("after_backward", input_obj_grad) - + _, output_obj_grad = async_communicator.wait_and_receive() - assert isinstance(output_obj_grad, torch.Tensor) self._output_obj_grads[1 - chunk_id].append(output_obj_grad) - - + # 1B + 1W (chunk0) self._schedule_1f1b_B_W(engine, 1 - chunk_id, chunk_id, need_recv_chunk0_output=False) - - + def _schedule_warmup_F(self, engine, chunk_id, input_obj=None, forward_only=False): - self._micro_step[chunk_id] += 1 output_obj = self._forward_step(engine, chunk_id, input_obj) - if input_obj is not None: - self.map_input_output[chunk_id][id(input_obj)] = id(output_obj) - else: - if self._input_objs[chunk_id][-1] is not None: - self.map_input_output[chunk_id][id(self._input_objs[chunk_id][-1])] = id(output_obj) - - + if forward_only: # when forward-only, no need to save tensors for a backward pass self._input_objs[chunk_id].pop() self._output_objs[chunk_id].pop() self._moe_losses[chunk_id].pop() - + if not gpc.is_pipeline_last_stage(): if isinstance(output_obj, torch.Tensor): self._output_obj_shapes[chunk_id] = output_obj.shape else: self._output_obj_shapes[chunk_id] = [out_tensor.shape for out_tensor in output_obj] - + assert self._output_obj_shapes[chunk_id] == self._input_obj_shapes[chunk_id] if self._send_tensor_shape_flags[chunk_id]: comm.send_obj_meta(output_obj) self._send_tensor_shape_flags[chunk_id] = False # send only once for each chunk. - + if not gpc.is_pipeline_first_stage() and self._input_obj_shapes[chunk_id] is None: self._input_obj_shapes[chunk_id] = comm.recv_obj_meta() - - assert output_obj is None or output_obj.dtype == self.dtype - + return output_obj - + def _run_warmup_loop( self, engine: Engine, @@ -849,37 +744,23 @@ def _run_warmup_loop( forward_only: bool = False, ) -> None: """ - Run the warm-up loop and prepare data for the 1F1B stage. - - During the warm-up process, for each execution, it first performs a forward computation, - and then sends the computation result to the next stage. - It also receives data for the next forward computation. - Since the input for the first forward computation is not considered initially, - it needs to receive data once at the beginning. - - After the warm-up is completed, we need to prepare data for the 1F1B stage. - The data preparation process should be consistent with the communication method of the 1F1B stage. + Run the warm-up loop and prepare data for the steady stage. Args: engine (Engine): The engine to run the warm-up loop. - num_microsteps (int): The total number of microsteps. num_warmup_microsteps (int): The number of warm-up microsteps. - receive_extra_backward (bool, optional): Whether to receive extra backward input for the 1F1B stage. - Default is False. forward_only (bool, optional): Whether to only perform forward pass. Default is False. """ - + # For each rank, the warmup stage will be divided into two sub-phases for scheduling. 
num_warmup_microsteps_phase_1 = min(self.num_microbatches, (self._pp_size - self._pp_rank) * 2 - 1) num_warmup_microsteps_phase_2 = num_warmup_microsteps - num_warmup_microsteps_phase_1 - - - + if gpc.is_first_rank(ParallelMode.PIPELINE): assert num_warmup_microsteps_phase_2 == 0 if gpc.is_last_rank(ParallelMode.PIPELINE): assert num_warmup_microsteps_phase_1 == 1 - + # get first forward input chunk_id = 0 if not gpc.is_pipeline_first_stage(): @@ -894,8 +775,7 @@ def _run_warmup_loop( ) else: self._input_objs[chunk_id].append(None) - if not gpc.is_first_rank(ParallelMode.PIPELINE): - assert self._input_objs[chunk_id][-1] is not None, f"{gpc.get_global_rank()} chunk{chunk_id} receive none input warmup before" + # Phase1 will only do chunk0 forward for micro_step in range(num_warmup_microsteps_phase_1): # forward @@ -905,48 +785,22 @@ def _run_warmup_loop( recv_prev_shape = None recv_next_shape = None - # For stage1, the last chunk0 unit needs to do recv op to prevent the sender from being blocked. + # For stage1, the last chunk0 unit needs to do recv op to prevent the sender from being blocked. if not gpc.is_first_rank(ParallelMode.PIPELINE): recv_prev_shape = self._input_obj_shapes[0] # For last rank, chunk0 output does not need to be sent but is directly used for chunk1. - if not gpc.is_last_rank(ParallelMode.PIPELINE): + if not gpc.is_last_rank(ParallelMode.PIPELINE): object_send_next = output_obj else: input_obj = output_obj.clone().detach() input_obj.requires_grad_() - assert input_obj.is_leaf self._input_objs[1].append(input_obj) if micro_step == num_warmup_microsteps_phase_1 - 1: - if not gpc.is_last_rank(ParallelMode.PIPELINE): + if not gpc.is_last_rank(ParallelMode.PIPELINE): recv_next_shape = self._input_obj_shapes[1] - - - # if micro_step != num_warmup_microsteps_phase_1 - 1: - # object_send_next = output_obj - # # self._input_objs[chunk_id].append( - # # comm.send_forward_recv_forward( - # # output_obj, - # # input_shape, - # # dtype=self.dtype, - # # scatter_gather_tensors=self.scatter_gather_tensors, - # # ) - # # ) - # # if not gpc.is_pipeline_first_stage(): - # # assert self._input_objs[chunk_id][-1] is not None, f"{gpc.get_global_rank()} chunk{chunk_id} receive none input warmup1" - # else: - # if not gpc.is_last_rank(ParallelMode.PIPELINE): - # object_send_next = output_obj - # recv_next_shape = self._input_obj_shapes[1] - # else: - # # For last rank, chunk0 output does not need to be sent but is directly used for chunk1 - # input_obj = output_obj.clone().detach() - # input_obj.requires_grad_() - # assert input_obj.is_leaf - # self._input_objs[1].append(input_obj) - - + tensor_recv_prev, tensor_recv_next = comm.fused_send_recv_tensor( object_send_next=object_send_next, recv_prev_shape=recv_prev_shape, @@ -955,46 +809,40 @@ def _run_warmup_loop( scatter_gather_tensors=self.scatter_gather_tensors, ) - if gpc.is_first_rank(ParallelMode.PIPELINE): - assert tensor_recv_prev is None self._input_objs[0].append(tensor_recv_prev) - + if micro_step == num_warmup_microsteps_phase_1 - 1: if not gpc.is_last_rank(ParallelMode.PIPELINE): self._input_objs[1].append(tensor_recv_next) - else: - assert tensor_recv_next is None # Phase2 will execute chunk1 and chunk0 forward alternately for micro_step in range(num_warmup_microsteps_phase_2): - assert not gpc.is_first_rank(ParallelMode.PIPELINE) chunk_id = 1 - chunk_id next_chunk_id = 1 - chunk_id - + if chunk_id == 0: input_obj = self._input_objs[chunk_id][-2] else: input_obj = self._input_objs[chunk_id][-1] - + output_obj = 
self._schedule_warmup_F(engine, chunk_id, input_obj=input_obj, forward_only=forward_only) - - + object_send_next = None - object_send_prev = None + object_send_prev = None recv_next_shape = None recv_prev_shape = None - + if chunk_id == 1: assert micro_step < num_warmup_microsteps_phase_2 - 1 object_send_prev = output_obj recv_prev_shape = self._input_obj_shapes[next_chunk_id] else: - if not gpc.is_last_rank(ParallelMode.PIPELINE): + if not gpc.is_last_rank(ParallelMode.PIPELINE): object_send_next = output_obj recv_next_shape = self._input_obj_shapes[next_chunk_id] - - # chunk1 send output prev, chunk0 recv input prev - # chunk0 send output next, chunk1 recv input next + + # chunk1 send output prev, chunk0 recv input prev + # chunk0 send output next, chunk1 recv input next tensor_recv_prev, tensor_recv_next = comm.fused_send_recv_tensor( object_send_next=object_send_next, object_send_prev=object_send_prev, @@ -1003,25 +851,23 @@ def _run_warmup_loop( dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors, ) - + # For last rank, chunk0 output does not need to be sent but is directly used for chunk1 if chunk_id == 0 and gpc.is_last_rank(ParallelMode.PIPELINE): input_obj = output_obj.clone().detach() input_obj.requires_grad_() - assert input_obj.is_leaf else: input_obj = tensor_recv_prev if tensor_recv_prev is not None else tensor_recv_next - + self._input_objs[next_chunk_id].append(input_obj) - if not gpc.is_pipeline_first_stage(): - assert self._input_objs[next_chunk_id][-1] is not None, f"{gpc.get_global_rank()} chunk{chunk_id} next_chunk_id{next_chunk_id} receive none input warmup3" - - def _run_1f1b_loop(self, + + def _run_steady_loop( + self, engine: Engine, num_1f1b_units: int, ) -> None: """ - 1F1B unit schedule: + 1F1B unit schedule: stage1: (pp_size + 1 + pp_rank + 2 * (micro_num - 2 * pp_size)) * unit1 stage2: (pp_size - 1 - pp_rank) * unit2 stage3: 1 * special chunk1 unit1 @@ -1034,49 +880,49 @@ def _run_1f1b_loop(self, num_units_stage1 = 2 * self.num_microbatches - 3 * self._pp_size + 1 + self._pp_rank num_units_stage2 = self._pp_size - 1 - self._pp_rank assert num_units_stage1 + num_units_stage2 + 1 == num_1f1b_units - + # chunk schedule: stage1 + stage2 + stage1 # stage1: chunk1 # stage2: chunk0 and chunk1 alternately stage1_length = self._pp_size - self._pp_rank - stage2_length = 2 * self._pp_rank + 1 + 2 * (self.num_microbatches - 2 * self._pp_size ) - assert stage1_length * 2 + stage2_length == num_1f1b_units - stage2_list = [i for i in range(stage1_length, stage1_length + stage2_length)] + stage2_length = 2 * self._pp_rank + 1 + 2 * (self.num_microbatches - 2 * self._pp_size) + stage2_list = list(range(stage1_length, stage1_length + stage2_length)) chunk0_units = [stage2_list[i] for i in range(len(stage2_list)) if i % 2 == 0] - + # unit stage1 for unit_step in range(num_units_stage1): if unit_step in chunk0_units: chunk_id = 0 else: chunk_id = 1 - + if unit_step + 1 in chunk0_units: next_unit_chunk_id = 0 else: next_unit_chunk_id = 1 - + # import pdb; pdb.set_trace() if unit_step == num_units_stage1 - 1: chunk0_B_need_recv_prev_chunk0_output = False else: chunk0_B_need_recv_prev_chunk0_output = True - - self._1f1b_unit_1(engine, chunk_id, next_unit_chunk_id, need_recv_chunk0_output=chunk0_B_need_recv_prev_chunk0_output) - + + self._1f1b_unit_1( + engine, chunk_id, next_unit_chunk_id, need_recv_chunk0_output=chunk0_B_need_recv_prev_chunk0_output + ) + # unit stage2 for unit_step in range(num_units_stage2): assert unit_step + num_units_stage1 not in chunk0_units 
self._1f1b_unit_2(engine, 1) - + # unit stage3 assert num_1f1b_units - 1 not in chunk0_units self._schedule_1f1b_F(engine, 1) - origin = engine.optimizer.skip_grad_reduce + origin_skip = engine.optimizer.skip_grad_reduce input_obj_grad = self._schedule_backward(engine, 1) if gpc.is_last_rank(ParallelMode.PIPELINE): # For last rank, chunk1 input_grad does not need to be sent but is directly used for chunk0. - # output_grad = input_obj_grad.clone().detach() self._output_obj_grads[0].append(input_obj_grad) tensor_to_send = None recv_shape = None @@ -1092,21 +938,18 @@ def _run_1f1b_loop(self, scatter_gather_tensors=self.scatter_gather_tensors, ) async_communicator.start() - - + WeightGradStore.pop() self._call_hooks("after_backward", input_obj_grad) - engine.optimizer.skip_grad_reduce = origin - - + engine.optimizer.skip_grad_reduce = origin_skip + _, output_obj_grad = async_communicator.wait_and_receive() - assert output_obj_grad is None or isinstance(output_obj_grad, torch.Tensor) if not gpc.is_last_rank(ParallelMode.PIPELINE): self._output_obj_grads[0].append(output_obj_grad) - + def _run_cooldown_loop(self, engine): """ - Cooldown unit schedule: + Cooldown unit schedule: Unit: 1B + 1W Schedule unit chunk0 and unit chunk1 alternatively Each pp rank has pp_size chunk0, but only pp_rank chunk1 @@ -1114,44 +957,42 @@ def _run_cooldown_loop(self, engine): chunk0_length = self._pp_size chunk1_length = self._pp_rank num_cooldown_units = chunk0_length + chunk1_length - total_list = [i for i in range(chunk1_length * 2)] + total_list = list(range(chunk1_length * 2)) chunk1_units = [total_list[i] for i in range(chunk1_length * 2) if i % 2 != 0] - - + cool_down = [0, 0] - + for unit_step in range(num_cooldown_units): if unit_step in chunk1_units: chunk_id = 1 else: chunk_id = 0 - + cool_down[chunk_id] += 1 if unit_step + 1 in chunk1_units: next_unit_chunk_id = 1 else: next_unit_chunk_id = 0 - - origin = engine.optimizer.skip_grad_reduce + + origin_skip = engine.optimizer.skip_grad_reduce input_obj_grad = self._schedule_backward(engine, chunk_id) object_send_next = None object_send_prev = None recv_next_shape = None recv_prev_shape = None - + if chunk_id == 1: assert not gpc.is_first_rank(ParallelMode.PIPELINE) if gpc.is_last_rank(ParallelMode.PIPELINE): # For last rank, chunk1 input_grad does not need to be sent but is directly used for chunk0. 
- # output_grad = input_obj_grad.clone().detach() self._output_obj_grads[0].append(input_obj_grad) else: object_send_next = input_obj_grad # next unit should be chunk0 - recv_next_shape = self._output_obj_shapes[0] - else: + recv_next_shape = self._output_obj_shapes[0] + else: if not gpc.is_first_rank(ParallelMode.PIPELINE): object_send_prev = input_obj_grad @@ -1162,8 +1003,7 @@ def _run_cooldown_loop(self, engine): else: assert not gpc.is_last_rank(ParallelMode.PIPELINE) recv_next_shape = self._output_obj_shapes[next_unit_chunk_id] - - + # chunk1 send input_grad next, chunk0 send input_grad prev # if next_unit_chunk_id == 1, recv output_grad prev # if next_unit_chunk_id == 0, recv output_grad next @@ -1176,21 +1016,18 @@ def _run_cooldown_loop(self, engine): scatter_gather_tensors=self.scatter_gather_tensors, ) async_communicator.start() - + # 1W - WeightGradStore.pop() self._call_hooks("after_backward", input_obj_grad) - engine.optimizer.skip_grad_reduce = origin - + engine.optimizer.skip_grad_reduce = origin_skip + tensor_recv_prev, tensor_recv_next = async_communicator.wait_and_receive() output_obj_grad = tensor_recv_prev if tensor_recv_prev is not None else tensor_recv_next - assert output_obj_grad is None or isinstance(output_obj_grad, torch.Tensor) - - # if not(next_unit_chunk_id == 0 and gpc.is_last_rank(ParallelMode.PIPELINE)): + if output_obj_grad is not None: self._output_obj_grads[next_unit_chunk_id].append(output_obj_grad) - + def _forward_only_step(self, engine: Engine): num_warmup_steps = self.num_microbatches * self._num_chunks @@ -1199,37 +1036,27 @@ def _forward_only_step(self, engine: Engine): num_warmup_steps, forward_only=True, ) - - - + def _forward_backward_step(self, engine: Engine): assert self.num_microbatches > self._pp_size - + # Compute number of warmup microbatches. num_warmup_steps = self._pp_size * 2 - 1 - + # Compute number of 1F1B unit. num_1f1b_units = 2 * self.num_microbatches - num_warmup_steps - + # 1. Warmup self._run_warmup_loop( engine, num_warmup_steps, ) - + # 2. 1F1B - self._run_1f1b_loop( + self._run_steady_loop( engine, num_1f1b_units, ) - + # 3. 
cooldown self._run_cooldown_loop(engine) - # import pdb; pdb.set_trace() - assert len(self._input_objs[0]) == 0 and len(self._input_objs[1]) == 0, f"{gpc.get_global_rank()}" - assert len(self._output_objs[0]) == 0 and len(self._output_objs[1]) == 0, f"{gpc.get_global_rank()}" - assert len(self._output_obj_grads[0]) == 0 and len(self._output_obj_grads[1]) == 0, f"{gpc.get_global_rank()}" - - assert WeightGradStore.size() == 0 - # assert self._backward_step_num[0] == 8 and self._backward_step_num[1] == 8 - \ No newline at end of file diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py index e228d7c1..3b01d3af 100644 --- a/internlm/core/trainer.py +++ b/internlm/core/trainer.py @@ -9,10 +9,7 @@ from typing import Iterable, Optional from internlm.core.engine import Engine -from internlm.core.scheduler import ( - BaseScheduler, - NonPipelineScheduler, -) +from internlm.core.scheduler import BaseScheduler, NonPipelineScheduler class TrainState: diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index c8153930..d0ef284d 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -99,7 +99,7 @@ def __init__( self.current_time = self._setup_time_and_logging() # load config_lines config_lines = self._read_config(kwargs["config"]) - + # set tracking name for parameters set_param_unique_tracking_name(model) diff --git a/internlm/initialize/initialize_trainer.py b/internlm/initialize/initialize_trainer.py index ef593b5c..3ebe1749 100644 --- a/internlm/initialize/initialize_trainer.py +++ b/internlm/initialize/initialize_trainer.py @@ -99,7 +99,10 @@ def _data_preparation_func(_data, _label): gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num tensor_shape = get_tensor_shape() use_interleaved = ( - hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1 and gpc.config.parallel["pipeline"]["mode"] == "1F1B" + hasattr(gpc.config, "model") + and hasattr(gpc.config.model, "num_chunks") + and gpc.config.model.num_chunks > 1 + and gpc.config.parallel["pipeline"]["mode"] == "1F1B" ) scatter_gather = gpc.is_initialized(ParallelMode.TENSOR) if use_interleaved: diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 9e2ef8f9..fc63b8a2 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -85,8 +85,8 @@ def args_sanity_check(): gpc.config.parallel._add_item("zero1", dict(size=zero1_size, fsdp=False)) if "pipeline" not in gpc.config.parallel: - gpc.config.parallel._add_item("pipeline", dict(size=1, interleaved_overlap=False, mode='1F1B')) - + gpc.config.parallel._add_item("pipeline", dict(size=1, interleaved_overlap=False, mode="1F1B")) + if isinstance(gpc.config.parallel.pipeline, dict) and "mode" not in gpc.config.parallel.pipeline: gpc.config.parallel.pipeline._add_item("mode", "1F1B") @@ -106,13 +106,17 @@ def args_sanity_check(): pp = gpc.config.parallel.pipeline else: pp = gpc.config.parallel.pipeline.size - + if isinstance(gpc.config.parallel.pipeline, dict): gpc.config.parallel.pipeline["mode"] = gpc.config.parallel.pipeline["mode"].upper() - assert gpc.config.parallel.pipeline["mode"] in ["1F1B", "ZBH1", "ZBV"], f"unsupported pp mode {gpc.config.parallel.pipeline['mode']}" + assert gpc.config.parallel.pipeline["mode"] in [ + "1F1B", + "ZBH1", + "ZBV", + ], f"unsupported pp mode {gpc.config.parallel.pipeline['mode']}" if gpc.config.parallel.pipeline["mode"] == "ZBV": gpc.v_shape = True - + # check fsdp config if "fsdp" not in 
gpc.config.parallel.zero1: gpc.config.parallel.zero1._add_item("fsdp", False) @@ -450,7 +454,7 @@ def args_sanity_check(): assert ( gpc.config.parallel["pipeline"].get("interleaved_overlap", False) is True ), "only support interleaved pipeline scheduler with overlap" - + if gpc.config.parallel["pipeline"]["mode"] == "ZBV": gpc.config.model.num_chunks = 2 if gpc.is_rank_for_log(): @@ -537,8 +541,6 @@ def args_sanity_check(): assert ( gpc.config.parallel.expert.size <= 1 and gpc.config.parallel.expert_weight.size <= 1 ), "expert parallel is only supported in MoE setting" - - print(f"isp_overlap: {gpc.config.parallel['weight']['overlap']}", flush=True) # sequence_2D if "sequence_2D" not in gpc.config.parallel: diff --git a/internlm/model/modules/embedding.py b/internlm/model/modules/embedding.py index f16eb51d..93fcd6b2 100644 --- a/internlm/model/modules/embedding.py +++ b/internlm/model/modules/embedding.py @@ -66,7 +66,7 @@ def __init__( self.weight = nn.Parameter( torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), dtype=dtype) ) - + setattr(self.weight, "is_embedding_param", True) def forward(self, input_: Tensor) -> Tensor: diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 9b8fc967..64d79796 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -134,13 +134,16 @@ def backward(ctx, grad_output, *args): handle_x.wait() x = x.reshape(batch_dim, x.shape[-1]) - if ( - gpc.is_using_parallel_mode(ParallelMode.PIPELINE) - and ((gpc.config.parallel["pipeline"].get("mode", "1F1B") == "ZBH1" - and not gpc.is_first_rank(ParallelMode.PIPELINE)) - or gpc.config.parallel["pipeline"].get("mode", "1F1B") == "ZBV") + if gpc.is_using_parallel_mode(ParallelMode.PIPELINE) and ( + ( + gpc.config.parallel["pipeline"].get("mode", "1F1B") == "ZBH1" + and not gpc.is_first_rank(ParallelMode.PIPELINE) + ) + or gpc.config.parallel["pipeline"].get("mode", "1F1B") == "ZBV" ): - from internlm.core.scheduler.pipeline_scheduler_zb import WeightGradStore + from internlm.core.scheduler.pipeline_scheduler_zb import ( + WeightGradStore, + ) WeightGradStore.put(weight, bias, x, grad_output, ctx.needs_input_grad[2], linear_backward_op) grad_weight, grad_bias = None, None @@ -238,11 +241,12 @@ def backward(ctx, grad_output, *args): total_weight = communicator.weight_hook(weight, module=module) - is_using_ZB = ( - gpc.is_using_parallel_mode(ParallelMode.PIPELINE) - and ((gpc.config.parallel["pipeline"].get("mode", "1F1B") == "ZBH1" - and not gpc.is_first_rank(ParallelMode.PIPELINE)) - or gpc.config.parallel["pipeline"].get("mode", "1F1B") == "ZBV") + is_using_ZB = gpc.is_using_parallel_mode(ParallelMode.PIPELINE) and ( + ( + gpc.config.parallel["pipeline"].get("mode", "1F1B") == "ZBH1" + and not gpc.is_first_rank(ParallelMode.PIPELINE) + ) + or gpc.config.parallel["pipeline"].get("mode", "1F1B") == "ZBV" ) # compute weight grad @@ -250,10 +254,19 @@ def backward(ctx, grad_output, *args): assert ctx.compute_weight_gradient x = x.reshape(batch_dim, x.shape[-1]) if is_using_ZB: - from internlm.core.scheduler.pipeline_scheduler_zb import WeightGradStore + from internlm.core.scheduler.pipeline_scheduler_zb import ( + WeightGradStore, + ) WeightGradStore.put( - weight, bias, x, grad_output, ctx.needs_input_grad[2], linear_backward_op, communicator.grad_hook, module + weight, + bias, + x, + grad_output, + ctx.needs_input_grad[2], + linear_backward_op, + communicator.grad_hook, + module, ) grad_weight, grad_bias = None, None else: diff 
--git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py index adfc7ea7..e3248e6c 100644 --- a/internlm/model/modules/mlp.py +++ b/internlm/model/modules/mlp.py @@ -101,7 +101,7 @@ def __init__( ) self.w3 = new_linear( "w3", in_features, hidden_features, bias, device=device, dtype=dtype, is_expert=is_expert - ) + ) self.w2 = new_linear( "w2", hidden_features, out_features, bias, device=device, dtype=dtype, is_expert=is_expert ) diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index b88541b9..49f3fbcf 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -298,7 +298,8 @@ def _is_moe_group(self, param_group): # TODO check expert dp is correct when enable moe and overlap both def _attach_reduction_hook(self): - from internlm.core.scheduler.pipeline_scheduler import WeightGradStore + from internlm.core.scheduler.pipeline_scheduler_zb import WeightGradStore + is_using_ZB = gpc.config.parallel["pipeline"].get("mode", "1F1B") != "1F1B" # we iterate over the fp16 params # on each param, we register a hook to its AccumulateGrad object @@ -308,9 +309,8 @@ def _attach_reduction_hook(self): # we should not reduce the param in moe if not param.requires_grad: continue - - if is_using_ZB: - hooks = [] + + hooks = [] reduce_rank = None @@ -396,7 +396,7 @@ def extra_layernorm_reduce_grad_hook(*args): # pylint: disable=W0613 ) ): if is_using_ZB and not hasattr(param, "is_embedding_param"): - hooks.append(accum_grad_hook) + hooks.append(accum_grad_hook) # pylint: disable=W0640 else: if hasattr(param, "evo_tensor"): param.register_post_accumulate_grad_hook(accum_grad_hook) @@ -405,7 +405,7 @@ def extra_layernorm_reduce_grad_hook(*args): # pylint: disable=W0613 if self._overlap_sync_grad: if is_using_ZB and not hasattr(param, "is_embedding_param"): - hooks.append(reduce_grad_hook) + hooks.append(reduce_grad_hook) # pylint: disable=W0640 else: if hasattr(param, "evo_tensor"): param.register_post_accumulate_grad_hook(reduce_grad_hook) @@ -414,6 +414,7 @@ def extra_layernorm_reduce_grad_hook(*args): # pylint: disable=W0613 _define_and_attach(param, reduce_rank) if len(hooks) > 0: + assert is_using_ZB WeightGradStore.register_hook(param, hooks) def accumulate_left_grads_after_backward(self): @@ -422,8 +423,8 @@ def accumulate_left_grads_after_backward(self): for group_id in range(self.num_param_groups): self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id]) - - def reduce_left_grads_after_backward(self): + + def reduce_left_grads_after_backward(self): for group_id in range(self.num_param_groups): self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None) @@ -1025,4 +1026,3 @@ def reload_zero_fp32_buff(self): ) # param_group["params"] is fp32 flatten optimizer states of this zero rank. 
param_group["params"][0].data.copy_(fp16_flat_current_rank.float()) - \ No newline at end of file diff --git a/internlm/solver/optimizer/hybrid_zero_optim_v2.py b/internlm/solver/optimizer/hybrid_zero_optim_v2.py index dfac3b6a..36e5f073 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim_v2.py +++ b/internlm/solver/optimizer/hybrid_zero_optim_v2.py @@ -852,6 +852,7 @@ def reload_zero_fp32_buff(self): def _attach_reduction_hook(self): from internlm.core.scheduler.pipeline_scheduler_zb import WeightGradStore + is_using_ZB = gpc.config.parallel["pipeline"].get("mode", "1F1B") != "1F1B" # we iterate over the fp16 params # on each param, we register a hook to its AccumulateGrad object @@ -862,8 +863,7 @@ def _attach_reduction_hook(self): if not param.requires_grad: continue - if is_using_ZB: - hooks = [] + hooks = [] reduce_rank = None @@ -916,13 +916,13 @@ def extra_layernorm_reduce_grad_hook(*args): # pylint: disable=W0613 and gpc.config.parallel.weight.size > 1 ): if is_using_ZB and not hasattr(param, "is_embedding_param"): - hooks.append(accum_grad_hook) + hooks.append(accum_grad_hook) # pylint: disable=W0640 else: param.register_post_accumulate_grad_hook(accum_grad_hook) if self._overlap_sync_grad: if is_using_ZB and not hasattr(param, "is_embedding_param"): - hooks.append(partial(grad_handler, group_id)) + hooks.append(partial(grad_handler, group_id)) # pylint: disable=W0640 else: param.register_post_accumulate_grad_hook( partial(grad_handler, group_id)
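
For context on the mechanism this patch centres on: the zero-bubble schedules (ZBH1 and the new ZBV) rely on `WeightGradStore` to split a linear layer's backward pass into an input-gradient step (B), which the upstream pipeline stage needs immediately, and a weight-gradient step (W), which has no downstream dependency and can therefore be deferred to fill pipeline bubbles. The sketch below is only a minimal, self-contained illustration of that deferral pattern; `SimpleWeightGradStore` and `linear_backward` are hypothetical stand-ins, not the `WeightGradStore` / `linear_backward_op` implementation from this diff.

import queue

import torch


def linear_backward(x: torch.Tensor, grad_output: torch.Tensor) -> torch.Tensor:
    # dW for y = x @ W.T; it does not depend on dX, so it can be computed later.
    return grad_output.t() @ x


class SimpleWeightGradStore:
    """Toy queue of deferred weight-gradient computations (illustrative only)."""

    cache = []
    funcs_queue = queue.Queue()

    @classmethod
    def put(cls, weight, x, grad_output):
        # B step only records what the W step will need; no dW is computed yet.
        cls.cache.append((weight, x.detach(), grad_output.detach()))

    @classmethod
    def flush(cls):
        # Close out one microbatch's worth of deferred work.
        cls.funcs_queue.put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls):
        # W step: run the deferred dW computations and accumulate into .grad.
        if cls.funcs_queue.empty():
            return
        for weight, x, grad_output in cls.funcs_queue.get():
            grad_w = linear_backward(x, grad_output)
            weight.grad = grad_w if weight.grad is None else weight.grad + grad_w


if __name__ == "__main__":
    w = torch.nn.Parameter(torch.randn(4, 3))  # weight of y = x @ w.T
    x = torch.randn(2, 3)
    grad_out = torch.randn(2, 4)
    # During the B step the scheduler records the work and flushes once per
    # microbatch; later, pop() fills an otherwise idle bubble with the W work.
    SimpleWeightGradStore.put(w, x, grad_out)
    SimpleWeightGradStore.flush()
    SimpleWeightGradStore.pop()
    print(w.grad.shape)  # torch.Size([4, 3])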