From 0d4ea37fabd2f539964eb0238fe4e3086ca97a6e Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 22 Aug 2023 15:35:52 +0800 Subject: [PATCH 01/25] fix is_first_layer --- bmtrain/block_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 8e52c68a..18438b8c 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -528,8 +528,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) module._mode = "ZERO" module.set_pre_module(pre_module) pre_module = module - self._is_first_layer = False - self._is_last_layer = False + module._is_first_layer = False + module._is_last_layer = False self._modules[str(i)] = module self.add_module(str(i), module) From 3063afb129f3a97b3f13ed84b2ca983886cba61e Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 23 Aug 2023 19:36:55 +0800 Subject: [PATCH 02/25] tensor parallel --- bmtrain/block_layer.py | 49 ++++++--- bmtrain/checkpointing.py | 12 +-- bmtrain/init.py | 63 ++++++++--- bmtrain/loss/cross_entropy.py | 38 ++++--- bmtrain/nn/__init__.py | 4 + bmtrain/nn/column_parallel_linear.py | 35 ++++++ bmtrain/nn/cross_entropy.py | 135 ++++++++++++++++++++++++ bmtrain/nn/embedding.py | 87 +++++++++++++++ bmtrain/nn/parallel_linear_hook_func.py | 91 ++++++++++++++++ bmtrain/nn/row_parallel_linear.py | 37 +++++++ bmtrain/param_init.py | 19 +++- bmtrain/parameter.py | 71 ++++++++++--- example/layers/attention.py | 52 ++++++--- example/layers/embedding.py | 9 +- example/layers/feedforward.py | 15 ++- example/models/gpt.py | 18 +++- example/train.py | 15 ++- 17 files changed, 659 insertions(+), 91 deletions(-) create mode 100644 bmtrain/nn/column_parallel_linear.py create mode 100644 bmtrain/nn/cross_entropy.py create mode 100644 bmtrain/nn/embedding.py create mode 100644 bmtrain/nn/parallel_linear_hook_func.py create mode 100644 bmtrain/nn/row_parallel_linear.py diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 18438b8c..4bc1cc2d 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -94,7 +94,8 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): "total": 0, "storage_type": storage_type, "requires_grad": param.requires_grad, - "group": param.group + "group": param.group, + "zero_comm" : param._zero_comm } param_shape = param._original_shape @@ -108,11 +109,14 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): offsets = {} # intialize storage buffers for kw, val in self._storage_info.items(): - val["world_size"] = config["world_size"] + comm = val['zero_comm'] + world_size = nccl.commCount(comm) + rank = nccl.commRank(comm) + val["world_size"] = world_size #config["dp_size"] partition_size = round_up(val["total"], val["world_size"]) // val["world_size"] val["partition_size"] = partition_size - val["begin"] = config['rank'] * partition_size - val["end"] = (config['rank'] + 1) * partition_size + val["begin"] = rank * partition_size #config['zero_rank'] * partition_size + val["end"] = (rank+1) * partition_size #(config['zero_rank'] + 1) * partition_size offsets[kw] = 0 @@ -301,13 +305,16 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, if key in state_dict: # load here input_param = state_dict[key] + param = it['parameter'] + tp_mode = param._tp_mode if input_param.__class__.__name__ == "DistributedTensorWrapper": input_param = input_param.broadcast() - if input_param.shape != it["shape"]: + if not tp_mode and input_param.shape != it["shape"]: 
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' 'the shape in current model is {}.' .format(key, input_param.shape, it["shape"])) continue + param_st = it["offset"] param_end = it["offset"] + it["size"] kw_name = it["kw_name"] @@ -321,8 +328,18 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, continue # copy to buffer - assert input_param.numel() == it["size"] + if not tp_mode: + assert input_param.numel() == it["size"] + contiguous_param = input_param.to(it["parameter"].dtype).cuda().contiguous() + + if tp_mode: + tp_split_dim = param._tp_split_dim + if tp_split_dim >= 0: + param_list = contiguous_param.chunk(config['tp_size'], dim=tp_split_dim) + sub_tensor = param_list[config['topology'].tp_id] + contiguous_param = torch.empty(sub_tensor.shape, device=sub_tensor.device, dtype=sub_tensor.dtype) + contiguous_param.copy_(sub_tensor) offset_st = max(storage_st - param_st, 0) offset_end = min(storage_end - param_st, contiguous_param.numel()) @@ -330,7 +347,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, to_offset_st = offset_st + param_st - storage_st to_offset_end = offset_end + param_st - storage_st - + # copy to buffer # PyTorch 1.11 changed the API of storage.__getitem__ d_dtype = self._storage_params[kw_name].dtype @@ -397,7 +414,7 @@ def init_parameters(self): param = it["parameter"] if isinstance(param, DistributedParameter) and param._init_method is not None: # initialzie here - tmp_tensor = torch.empty(it["shape"], device=param.device, dtype=param.dtype) + tmp_tensor = torch.empty(param._tp_original_shape, device=param.device, dtype=param.dtype) param._init_method(tmp_tensor) param_st = it["offset"] param_end = it["offset"] + it["size"] @@ -411,16 +428,18 @@ def init_parameters(self): if param_end <= storage_st: continue + if param._tp_mode and param._tp_split_dim >= 0: + tensor_list = tmp_tensor.chunk(config['tp_size'], dim=param._tp_split_dim) + sub_tensor = tensor_list[config['topology'].tp_id].contiguous() + tmp_tensor = torch.empty(sub_tensor.shape, device=param.device, dtype=sub_tensor.dtype) + tmp_tensor.copy_(sub_tensor) # copy to buffer assert tmp_tensor.is_contiguous() and it["size"] == tmp_tensor.numel() - offset_st = max(storage_st - param_st, 0) - offset_end = min(storage_end - param_st, tmp_tensor.numel()) + offset_st = max(storage_st - param_st, 0) + offset_end = min(storage_end - param_st, tmp_tensor.numel()) assert offset_st < offset_end - to_offset_st = offset_st + param_st - storage_st - to_offset_end = offset_end + param_st - storage_st - # copy to buffer # PyTorch 1.11 changed the API of storage.__getitem__ d_dtype = self._storage_params[kw_name].dtype @@ -528,8 +547,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) module._mode = "ZERO" module.set_pre_module(pre_module) pre_module = module - module._is_first_layer = False - module._is_last_layer = False + self._is_first_layer = False + self._is_last_layer = False self._modules[str(i)] = module self.add_module(str(i), module) diff --git a/bmtrain/checkpointing.py b/bmtrain/checkpointing.py index b2c9ec07..3adbc105 100644 --- a/bmtrain/checkpointing.py +++ b/bmtrain/checkpointing.py @@ -39,10 +39,8 @@ def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, pipe = Fal self._param_tensor = {} self._grad_tensor = {} self._need_release = False - if pipe: - self.comm = config["zero_comm"] - else: - self.comm = config["comm"] + #self.comm = config["zero_comm"] + def 
enter(self, flag=0, requires_grad=False): """ gather parameters @@ -74,7 +72,8 @@ def enter(self, flag=0, requires_grad=False): nccl.allGather( self.block._storage_params[kw].storage(), self._param_buffer[kw], - self.comm + #self.comm + val['zero_comm'] ) nccl.groupEnd() @@ -144,7 +143,8 @@ def exit(self, flag=0, backward=False): self._grad_buffer[kw], local_param.grad.storage(), "sum", - self.comm + #self.comm + val['zero_comm'] ) nccl.groupEnd() diff --git a/bmtrain/init.py b/bmtrain/init.py index 1fa0712d..a403caf7 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -10,12 +10,12 @@ from . import nccl from .synchronize import synchronize - def init_distributed( init_method : str = "env://", seed : int = 0, zero_level: int = 3, pipe_size: int = -1, + tp_size = 1, num_micro_batches: int = None, ): """Initialize distributed training. @@ -75,8 +75,10 @@ def init_distributed( config['barrier_stream'] = torch.cuda.Stream() config["load_event"] = torch.cuda.Event() config["zero_level"] = zero_level + config["tp_size"] = tp_size if tp_size > 0 else 1 config["topology"] = topology(config) - config["zero_rank"] = config["topology"].get_group_rank("zero") if pipe_size > 1 else config['rank'] + config["zero_rank"] = config['topology'].get_group_rank("zero") + config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero") cpus_this_worker = None all_available_cpus = sorted(list(os.sched_getaffinity(0))) @@ -105,6 +107,8 @@ def init_distributed( unique_id = bytes.fromhex(store.get("BMTRAIN_UNIQUE_ID").decode()) config['comm'] = nccl.commInitRank(unique_id, world_size, rank) + config['zero_comm'] = config['comm'] + if config['pipe_enabled']: config["micros"] = num_micro_batches if num_micro_batches else config["pipe_size"] topo = config['topology'] @@ -113,13 +117,24 @@ def init_distributed( store.set(f"PIPE_UNIQUE_ID{topo.pipe_idx}", unique_id.hex()) unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode()) config ['pipe_comm'] = nccl.commInitRank(unique_id, pipe_size, topo.stage_id) - if topo.zero_id == 0: + + if config['tp_size'] > 1: + topo = config['topology'] + if topo.tp_id == 0: unique_id = nccl.getUniqueId() - store.set(f"ZERO_UNIQUE_ID{topo.zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"ZERO_UNIQUE_ID{topo.zero_idx}").decode()) - config ['zero_comm'] = nccl.commInitRank(unique_id, world_size//pipe_size, topo.zero_id) + store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex()) + unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) + config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) + + if not config['pipe_enabled'] and config['tp_size'] <= 1: + config['tp_zero_comm'] = config['comm'] else: - config['zero_comm'] = config['comm'] + if topo.tp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) + config ['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.tp_zero_id) + for i in range(world_size): if i == rank: print_dict("Initialization", { @@ -132,24 +147,40 @@ def init_distributed( "cpus": cpus_this_worker }) synchronize() + class topology: def __init__(self,config): # pipe_idx is the idx of the pipeline in the group self.rank = config['rank'] pp_size = config["pipe_size"] + tp_size = config["tp_size"] world_size = config["world_size"] - assert world_size % pp_size == 0, "The nums of GPUs must be 
divisible by the pipeline parallel size" + assert world_size % (pp_size * tp_size) == 0, "The nums of GPUs must be divisible by the pipeline parallel size * tensor parallel size" - dp_size = world_size // pp_size - topo=torch.tensor(range(dp_size*pp_size),dtype=torch.int,device='cuda') - topo=topo.view(pp_size,dp_size) + dp_size = world_size // (pp_size * tp_size) + config['tp_zero_size'] = dp_size + config['zero_size'] = world_size // pp_size + topo=torch.tensor(range(dp_size*tp_size*pp_size),dtype=torch.int,device='cuda') + topo=topo.view(pp_size,dp_size*tp_size) self.pp_group=topo.transpose(0,1).reshape(-1,pp_size) - self.dp_group=topo self.stage_id = (self.pp_group == self.rank).nonzero()[0,-1].item() self.stages = config['pipe_size'] self.pipe_idx = (self.pp_group == self.rank).nonzero()[0, 0].item() # x axes self.zero_id = self.pipe_idx self.zero_idx = self.stage_id + + self.tp_group = topo.reshape(pp_size, dp_size, tp_size) + self.tp_id = (self.tp_group == self.rank).nonzero()[0,2].item() + self.tp_idx = (self.tp_group == self.rank).nonzero()[0,1 if dp_size > 1 else 0].item() + + if pp_size == 1 and tp_size == 1: + self.tp_zero_id = self.rank + self.tp_zero_idx = 0 + else: + self.dp_group = self.tp_group.permute(0,2,1) + self.tp_zero_id = (self.dp_group == self.rank).nonzero()[0,2 if tp_size > 1 else 0].item() + self.tp_zero_idx = (self.dp_group == self.rank).nonzero()[0,1 if tp_size > 1 else 2].item() + self.next_rank = self.stage_id+1 if self.stage_id < config['pipe_size'] - 1 else -1 self.prev_rank = self.stage_id-1 if self.stage_id > 0 else -1 self.tails = self.pp_group[self.pipe_idx, self.stage_id:].tolist() @@ -160,12 +191,20 @@ def get_group_id(self,group_name): return self.pipe_idx elif group_name == "zero": return self.zero_idx + elif group_name == "tp_zero": + return self.tp_zero_idx + elif group_name == "tp": + return self.tp_idx def get_group_rank(self,group_name): if group_name == "pipe": return self.stage_id elif group_name == "zero": return self.zero_id + elif group_name == "tp_zero": + return self.tp_zero_id + elif group_name == "tp": + return self.tp_id def is_initialized() -> bool: return config["initialized"] diff --git a/bmtrain/loss/cross_entropy.py b/bmtrain/loss/cross_entropy.py index 31223640..a5d43a55 100644 --- a/bmtrain/loss/cross_entropy.py +++ b/bmtrain/loss/cross_entropy.py @@ -1,6 +1,10 @@ from typing import Optional import torch from . 
import _function as F +from bmtrain.nn import fused_cross_entropy +from bmtrain.global_var import config +from bmtrain.distributed import all_gather + class OpFusedCrossEntropy(torch.autograd.Function): """ CrossEntropy dim = 1 @@ -185,19 +189,23 @@ def __init__(self, self.inplace = inplace def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - if input.dtype == torch.float32: - return torch.nn.functional.cross_entropy( - input, - target.long(), - weight=self.weight, - ignore_index=self.ignore_index, - reduction=self.reduction, - label_smoothing=self.label_smoothing) - - if self.inplace: - ret = OpFusedCrossEntropyInplace.apply(input, target.int(), self.ignore_index) # return float tensor + if config['tp_size'] > 1: + target = all_gather(target, comm=config['tp_comm']).flatten(0,1) + ret = fused_cross_entropy(input, target.long(), self.ignore_index) else: - ret = OpFusedCrossEntropy.apply(input, target.int(), self.ignore_index) # return float tensor + if input.dtype == torch.float32: + return torch.nn.functional.cross_entropy( + input, + target.long(), + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction, + label_smoothing=self.label_smoothing) + + if self.inplace: + ret = OpFusedCrossEntropyInplace.apply(input, target.int(), self.ignore_index) # return float tensor + else: + ret = OpFusedCrossEntropy.apply(input, target.int(), self.ignore_index) # return float tensor if self.weight is not None: if self.weight.dim() != 1 or self.weight.size(0) != input.size(1): @@ -208,6 +216,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: w = (target != self.ignore_index).int() ret = w * ret + + if config['tp_size'] > 1: + ret_list = ret.chunk(config['tp_size'], dim=0) + ret = ret_list[config['topology'].tp_id] + w_list = w.chunk(config['tp_size'], dim=0) + w = w_list[config['topology'].tp_id] if self.reduction == "none": return ret diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py index 67f9fdee..12540b57 100644 --- a/bmtrain/nn/__init__.py +++ b/bmtrain/nn/__init__.py @@ -1 +1,5 @@ from .linear import Linear +from .column_parallel_linear import ColumnParallelLinear +from .row_parallel_linear import RowParallelLinear +from .embedding import Embedding +from .cross_entropy import fused_cross_entropy diff --git a/bmtrain/nn/column_parallel_linear.py b/bmtrain/nn/column_parallel_linear.py new file mode 100644 index 00000000..9d6444b4 --- /dev/null +++ b/bmtrain/nn/column_parallel_linear.py @@ -0,0 +1,35 @@ +import torch +from torch.nn.parameter import Parameter + +import bmtrain as bmt +from bmtrain.global_var import config +from .parallel_linear_hook_func import ( + LinearHookFunc, + ReduceType) + +class ColumnParallelLinear(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None, gather_output=False) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + tp_size = config['tp_size'] + assert out_features % tp_size == 0 + self.out_features_per_partition = out_features // tp_size + self.weight = bmt.DistributedParameter(torch.empty(self.out_features_per_partition, in_features, dtype=dtype, device="cuda"), init_method=torch.nn.init.xavier_normal_, tp_split_dim=0, tp_mode=True) + if bias: + self.bias = bmt.DistributedParameter(torch.empty(self.out_features_per_partition, dtype=dtype, device="cuda"), init_method=torch.nn.init.zeros_, tp_split_dim=0, tp_mode=True) + else: + 
self.register_parameter('bias', None)
+
+    def forward(self, input):
+        gather_input = True
+        split_input = False
+        reduce_output_type = None
+        return LinearHookFunc.apply(input, self.weight, self.bias, gather_input, self.gather_output, split_input, reduce_output_type)
+
+    def extra_repr(self) -> str:
+        return 'in_features={}, out_features={}, bias={}'.format(
+            self.in_features, self.out_features_per_partition, self.bias is not None
+        )
+
diff --git a/bmtrain/nn/cross_entropy.py b/bmtrain/nn/cross_entropy.py
new file mode 100644
index 00000000..8643c03b
--- /dev/null
+++ b/bmtrain/nn/cross_entropy.py
@@ -0,0 +1,135 @@
+import torch
+import bmtrain as bmt
+from bmtrain.global_var import config
+from bmtrain.distributed import all_reduce, all_gather
+
+class FusedCrossEntropyFunc(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0):
+        comm = config['tp_comm']
+        rank = config['topology'].tp_id
+        world_size = config['tp_size']
+        # Maximum value along vocab dimension across all GPUs.
+        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
+        logits_max = all_reduce(logits_max, op="max", comm=comm)
+        # Subtract the maximum value.
+        vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)
+
+        # Get the partition's vocab indices
+        #get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
+        partition_vocab_size = vocab_parallel_logits.size()[-1]
+        vocab_start_index = rank * partition_vocab_size
+        vocab_end_index = (rank + 1) * partition_vocab_size
+
+        # Create a mask of valid vocab ids (1 means it needs to be masked).
+        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+        masked_target = target.clone() - vocab_start_index
+        masked_target[target_mask] = 0
+
+        # Get predicted-logits = logits[target].
+        # For Simplicity, we convert logits to a 2-D tensor with size
+        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
+        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
+        masked_target_1d = masked_target.view(-1)
+        arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
+                                 device=logits_2d.device)
+        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
+        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
+        predicted_logits = predicted_logits_1d.view_as(target)
+        predicted_logits[target_mask] = 0.0
+        # All reduce is needed to get the chunks from other GPUs.
+        predicted_logits = all_reduce(predicted_logits, op="sum", comm=comm)
+
+        # Sum of exponential of logits along vocab dimension across all GPUs.
+        exp_logits = vocab_parallel_logits
+        torch.exp(vocab_parallel_logits, out=exp_logits)
+        sum_exp_logits = exp_logits.sum(dim=-1)
+        #if config['rank'] == 0:
+        #print("before", sum_exp_logits.shape, predicted_logits.shape)
+        sum_exp_logits = all_reduce(sum_exp_logits, op="sum", comm=comm)
+        #if config['rank'] == 0:
+        #    print(sum_exp_logits.shape)
+
+        # Loss = log(sum(exp(logits))) - predicted-logit.
+        loss = torch.log(sum_exp_logits.view(predicted_logits.shape)) - predicted_logits
+
+        # Normalize and optionally smooth logits
+        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+
+        vocab_size = exp_logits.size(-1)
+        if label_smoothing > 0:
+            """
+            We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth.
+            = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt})
+            = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
+            = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
+            = ((K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i
+            = (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K
+            From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py
+            """
+            assert 1.0 > label_smoothing > 0.0
+            smoothing = label_smoothing * vocab_size / (vocab_size - 1)
+
+            # Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs.
+            log_probs = torch.log(exp_logits)
+            mean_log_probs = log_probs.mean(dim=-1)
+            loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs
+
+        ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size
+        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
+
+        # Store softmax, target-mask and masked-target for backward pass.
+        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
+
+        return loss
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # Retrieve tensors from the forward path.
+        softmax, target_mask, masked_target_1d = ctx.saved_tensors
+        label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size
+
+        # All the inputs have softmax as their gradient.
+        grad_input = softmax
+        # For simplicity, work with the 2D gradient.
+        partition_vocab_size = softmax.size()[-1]
+        grad_2d = grad_input.view(-1, partition_vocab_size)
+
+        # Add the gradient from matching classes.
+        arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
+                                 device=grad_2d.device)
+
+        softmax_update = 1.0 - target_mask.view(-1).float()
+
+        if label_smoothing > 0:
+            smoothing = label_smoothing * vocab_size / (vocab_size - 1)
+            grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update
+            average_grad = 1 / vocab_size
+            grad_2d[arange_1d, :] -= smoothing * average_grad
+        else:
+            grad_2d[arange_1d, masked_target_1d] -= softmax_update
+
+        # Finally elementwise multiplication with the output gradients.
+        #grad_input.mul_(grad_output.unsqueeze(dim=-1))
+        grad_input.mul_(grad_output.flatten(0,1).unsqueeze(dim=-1))
+
+        return grad_input, None, None
+
+
+def fused_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0):
+    """
+    Performs cross entropy loss when logits are split across tensor parallel ranks
+
+    Arguments:
+        vocab_parallel_logits: logits split across tensor parallel ranks
+                               dimension is [sequence_length, batch_size, hidden_size]
+
+        target: correct vocab ids of dimension [sequence_length, micro_batch_size]
+
+        label_smoothing: smoothing factor, must be in range [0.0, 1.0)
+                         default is no smoothing (=0.0)
+    """
+    out = FusedCrossEntropyFunc.apply(vocab_parallel_logits.to(torch.float32), target, label_smoothing)
+    return out
+
diff --git a/bmtrain/nn/embedding.py b/bmtrain/nn/embedding.py
new file mode 100644
index 00000000..52be1d72
--- /dev/null
+++ b/bmtrain/nn/embedding.py
@@ -0,0 +1,87 @@
+import torch
+from torch.nn.parameter import Parameter
+import torch.nn.functional as F
+import math
+
+import bmtrain as bmt
+from bmtrain.global_var import config
+from bmtrain.distributed import all_reduce, all_gather
+from .parallel_linear_hook_func import LinearHookFunc
+
+class Embedding(bmt.DistributedModule):
+    def __init__(
+        self,
+        vocab_size: int,
+        embedding_size: int,
+        dtype: torch.dtype = torch.half,
+        scale: bool = True,
+        init_mean: float = 0.0,
+        init_std: float = 1,
+    ):
+        super().__init__()
+
+        self.dim_model = embedding_size
+        assert vocab_size % config['tp_size'] == 0
+        self.vocab_size_per_partition = vocab_size // config['tp_size']
+        self.start_index = config['topology'].tp_id * self.vocab_size_per_partition
+        self.end_index = (config['topology'].tp_id+1) * self.vocab_size_per_partition
+        self.weight = bmt.DistributedParameter(
+            torch.empty(self.vocab_size_per_partition, embedding_size, dtype=dtype),
+            init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
+            tp_mode=True,
+            tp_split_dim=0,
+        )
+        self.scale = scale
+
+    def forward(self, ids: torch.Tensor):
+        """
+        Args:
+            ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens.
+        Return:
+            :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output.
+        """  # noqa: E501
+
+        if config['tp_size'] > 1:
+            ids = all_gather(ids, comm=config['tp_comm'])
+            input_mask = (ids < self.start_index) | (ids >= self.end_index)
+            ids = ids.clone() - self.start_index
+            ids[input_mask] = 0
+
+        #if self.scale:
+            #embeds = F.embedding(ids, self.weight) / math.sqrt(self.dim_model)
+        #else:
+        embeds = F.embedding(ids, self.weight)
+
+        if config['tp_size'] > 1:
+            embeds[input_mask, :] = 0.0
+            embeds = all_reduce(embeds, op="sum", comm=config['tp_comm'])
+            embed_list = embeds.chunk(config['tp_size'], dim=0)
+            embeds = embed_list[config['topology'].tp_id].flatten(0,1)
+            #print(embeds.sum())
+
+        return embeds.clone()
+
+    def projection(self, x: torch.Tensor):
+        """
+        Projection based on the embedding's weight. For example, the embedding maps vocab_size to embed_size, then the projection maps embed_size back to vocab_size.
+        Args:
+            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
+        Returns:
+            :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output.
+ """ # noqa: E501 + gather_input = True + split_input = False + reduce_output_type = None + gather_output = False + if self.scale: + #out = LinearHookFunc.apply(x / math.sqrt(self.dim_model), self.weight, None, gather_input, gather_output, split_input, reduce_output_type) + #print(x.sum()) + #print(x.shape, self.weight.shape) + out = LinearHookFunc.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) / math.sqrt(self.dim_model) + #out_list = out.chunk(config['tp_size'], dim=0) + #out = out_list[config['topology'].tp_id] + #print(out.sum()) + return out + else: + return LinearHookFunc.apply(x, self.weight, None, gather_input, gather_output, split_input, reduce_output_type) + return logits diff --git a/bmtrain/nn/parallel_linear_hook_func.py b/bmtrain/nn/parallel_linear_hook_func.py new file mode 100644 index 00000000..af7bea0a --- /dev/null +++ b/bmtrain/nn/parallel_linear_hook_func.py @@ -0,0 +1,91 @@ +import torch +import torch.nn.functional as F +from bmtrain.global_var import config +from ..distributed import all_gather, all_reduce +from .. import nccl +import bmtrain as bmt +from enum import Enum + +class ReduceType(Enum): + ALL_REDUCE = 1 + REDUCE_SCATTER = 2 + +def preprocess_input(input, gather_input, split_input): + if gather_input: + input = all_gather(input, config['tp_comm']) + input = input.flatten(0, 1) + + if split_input: + all_input_list = input.chunk(config['tp_size'], dim=1) + input = all_input_list[config['topology'].tp_id] + return input + +class LinearHookFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=False, split_input=False, reduce_output_type=None): + ctx.save_for_backward(input, weight, bias) + ctx.gather_output = gather_output + ctx.split_input = split_input + ctx.gather_input = gather_input + ctx.reduce_output_type = reduce_output_type + all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) + out = F.linear(all_input, weight, bias) + if gather_output: + all_output_list = all_gather(out, config['tp_comm']) + all_output_list = all_output_list.chunk(config['tp_size'], dim=0) + out = torch.cat(all_output_list, dim=all_output_list[0].dim()-1).flatten(0,1) + + if reduce_output_type is None: + return out + + if reduce_output_type == ReduceType.ALL_REDUCE: + nccl.allReduce(out.storage(), out.storage(), "sum", config['tp_comm']) + return out + elif reduce_output_type == ReduceType.REDUCE_SCATTER: + shape = list(out.shape) + shape[0] = shape[0] // config['tp_size'] + reduce_out = torch.empty(shape, dtype=out.dtype, device=out.device) + nccl.reduceScatter(out.storage(), reduce_out.storage(), "sum", config['tp_comm']) + return reduce_out + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + gather_output = ctx.gather_output + + if ctx.reduce_output_type == ReduceType.REDUCE_SCATTER: + grad_output = all_gather(grad_output, config['tp_comm']) + grad_output = grad_output.flatten(0, 1) + + if gather_output: + tp_size = config['tp_size'] + tp_id = config['topology'].tp_id + grad_output_list = grad_output.chunk(tp_size, dim=1) + grad_output = grad_output_list[tp_id] + + grad_input = grad_weight = grad_bias = None + + if input.requires_grad or weight.requires_grad: + all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) + + if input.requires_grad: + #gather can async with grad_out.matmul(weight) + #TODO: gather on load_stream + grad_all_input = grad_output.matmul(weight) + grad_input = 
torch.empty_like(input) + if ctx.gather_input: + nccl.reduceScatter(grad_all_input.storage(), grad_input.storage(), "sum", config['tp_comm']) + else: + grad_input = grad_all_input + + if ctx.split_input: + grad_input = all_gather(grad_input, config['tp_comm']) + + if weight.requires_grad: + dim = grad_output.dim() + grad_weight = grad_output.reshape(-1, + grad_output.shape[-1]).t().matmul(all_input.reshape(-1, all_input.shape[-1])) + + if bias is not None and bias.requires_grad: + grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0) + return grad_input, grad_weight, grad_bias, None, None, None, None diff --git a/bmtrain/nn/row_parallel_linear.py b/bmtrain/nn/row_parallel_linear.py new file mode 100644 index 00000000..a6877b7c --- /dev/null +++ b/bmtrain/nn/row_parallel_linear.py @@ -0,0 +1,37 @@ +import torch +from torch.nn.parameter import Parameter + +import bmtrain as bmt +from bmtrain.global_var import config +from .parallel_linear_hook_func import ( + LinearHookFunc, + ReduceType) + +class RowParallelLinear(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None, split_input=False, all_reduce_output=False) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.split_input = split_input + self.all_reduce_output = all_reduce_output + tp_size = config['tp_size'] + assert in_features % tp_size == 0 + self.in_features_per_partition = in_features // tp_size + self.weight = bmt.DistributedParameter(torch.empty(self.out_features, self.in_features_per_partition, dtype=dtype, device="cuda"), init_method=torch.nn.init.xavier_normal_, tp_split_dim=1, tp_mode=True) + if bias: + self.bias = bmt.DistributedParameter(torch.empty(self.out_features, dtype=dtype, device="cuda"), init_method=torch.nn.init.zeros_, tp_split_dim=-1, tp_mode=True) + else: + self.register_parameter('bias', None) + + def forward(self, input): + gather_input = self.split_input + gather_output = False + reduce_output_type = ReduceType.ALL_REDUCE if self.all_reduce_output else ReduceType.REDUCE_SCATTER + out = LinearHookFunc.apply(input, self.weight, None, gather_input, gather_output, self.split_input, reduce_output_type) + out = out + self.bias + return out + + def extra_repr(self) -> str: + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features_per_partition, self.out_features, self.bias is not None + ) diff --git a/bmtrain/param_init.py b/bmtrain/param_init.py index 8b74c580..5e165fe1 100644 --- a/bmtrain/param_init.py +++ b/bmtrain/param_init.py @@ -3,6 +3,7 @@ from .block_layer import CheckpointBlock from .parameter import DistributedParameter from .global_var import config +from . 
import nccl def init_distributed_parameter(params : Iterable[torch.nn.Parameter]): @@ -13,17 +14,27 @@ def init_distributed_parameter(params : Iterable[torch.nn.Parameter]): continue with torch.no_grad(): partition_size = param.storage().size() - global_size = partition_size * config['world_size'] - + global_size = partition_size * config['tp_zero_size'] * config['tp_size'] tmp_storage = param.storage_type()(global_size) tmp_tensor = torch.tensor([], dtype=param.dtype, device="cuda") - tmp_tensor.set_(tmp_storage, 0, param._original_shape) + tmp_tensor.set_(tmp_storage, 0, param._tp_original_shape) param._init_method(tmp_tensor) + if param._tp_mode and param._tp_split_dim >= 0: + tensor_list = tmp_tensor.chunk(config['tp_size'], dim=param._tp_split_dim) + sub_tensor = tensor_list[config['topology'].tp_id].contiguous() + tmp_tensor = torch.empty(sub_tensor.shape, device=param.device, dtype=sub_tensor.dtype) + tmp_tensor.copy_(sub_tensor) + + if param._tp_mode: + begin = config['tp_zero_rank'] + else: + begin = config['zero_rank'] + end = begin + 1 # Pytorch 1.11 changed the API of storage.__getitem__ torch.tensor([], dtype=param.dtype, device=param.device).set_(param.storage())[:] = \ - torch.tensor([], dtype=param.dtype, device=param.device).set_(tmp_storage)[partition_size * config['rank'] : partition_size * (config['rank'] + 1)] + torch.tensor([], dtype=param.dtype, device=param.device).set_(tmp_tensor.storage())[partition_size * begin : partition_size * end] # param.storage().copy_(tmp_storage[partition_size * config['rank'] : partition_size * (config['rank'] + 1)]) def iterate_parameters(model : torch.nn.Module): diff --git a/bmtrain/parameter.py b/bmtrain/parameter.py index f965cdac..f68d53ac 100644 --- a/bmtrain/parameter.py +++ b/bmtrain/parameter.py @@ -3,6 +3,7 @@ from .utils import round_up from .global_var import config from . 
import nccl +from .distributed import all_gather class DistributedParameter(torch.nn.Parameter): r""" @@ -31,7 +32,9 @@ def __new__(cls, data : torch.Tensor, requires_grad : bool = True, init_method : Optional[Callable[['DistributedParameter'], None]] = None, - group : Optional[str] = None + group : Optional[str] = None, + tp_mode : bool = False, + tp_split_dim : int = -1, ): if not config["initialized"]: raise RuntimeError("BMTrain is not initialized") @@ -39,14 +42,25 @@ def __new__(cls, num_of_elements = data.numel() cuda_tensor = torch.tensor([], dtype=data.dtype, device="cuda") - cuda_storage_size = round_up(num_of_elements, config["world_size"]) // config["world_size"] + if tp_mode: + comm = config['tp_zero_comm'] + else: + comm = config['zero_comm'] + world_size = nccl.commCount(comm) + rank = nccl.commRank(comm) + cuda_storage_size = round_up(num_of_elements, world_size) // world_size original_shape = data.size() + tp_original_shape = original_shape + if tp_mode and tp_split_dim >= 0: + list_shape = list(original_shape) + list_shape[tp_split_dim] *= config['tp_size'] + tp_original_shape = list_shape cuda_storage = cuda_tensor.storage_type()(cuda_storage_size) - start_of_partition = cuda_storage_size * config["rank"] - end_of_partition = min(num_of_elements, cuda_storage_size * (config["rank"] + 1)) + start_of_partition = cuda_storage_size * rank + end_of_partition = min(num_of_elements, cuda_storage_size * (rank + 1)) # FX: cuda_tensor_size < 0 if num_of_elements is too small cuda_tensor_size = max(end_of_partition - start_of_partition, 0) @@ -60,7 +74,12 @@ def __new__(cls, setattr(ret, "_end_partition", end_of_partition) setattr(ret, "_init_method", init_method) setattr(ret, "_in_checkpoint_block", False) - setattr(ret, "_group", group) + setattr(ret, "_group", group if not tp_mode else "tp") + + setattr(ret, "_tp_mode", tp_mode) + setattr(ret, "_zero_comm", comm) + setattr(ret, "_tp_split_dim", tp_split_dim) + setattr(ret, "_tp_original_shape", tp_original_shape) return ret @property @@ -83,24 +102,52 @@ def gather(self) -> torch.Tensor: current_stream.wait_stream(config['load_stream']) return output_tensor + def gather_all(self) -> torch.tensor: + zero_param = self.gather() + if config['tp_size'] > 1 and self._tp_split_dim >= 0: + output_tensor = all_gather(zero_param, config['tp_comm']) + if self._tp_split_dim == 1: + output_list = output_tensor.chunk(config['tp_size'], dim=0) + output = torch.cat(output_list, dim=output_list[0].dim()-1).flatten(0,1) + return output + else: + return output_tensor.flatten(0,1) + else: + return zero_param + + def tp_gather(self) -> torch.tensor: + if config['tp_size'] > 1 and self._tp_split_dim >= 0: + output_tensor = all_gather(self, config['tp_comm']) + if self._tp_split_dim == 1: + output_list = output_tensor.chunk(config['tp_size'], dim=0) + output = torch.cat(output_list, dim=output_list[0].dim()-1).flatten(0,1) + return output + else: + return output_tensor.flatten(0,1) + else: + return self + def _copy_data(self, data : torch.Tensor): self.data.copy_(data.view(-1)[self._start_partition : self._end_partition]) - class OpAllGather(torch.autograd.Function): @staticmethod def forward(ctx, value : DistributedParameter): assert isinstance(value, DistributedParameter) + comm = value._zero_comm #config['zero_comm'] + world_size = nccl.commCount(comm) + ctx.comm = comm + ctx.world_size = world_size partition_size = value.storage().size() - global_size = partition_size * config['world_size'] + global_size = partition_size * world_size storage = 
value.storage_type()(global_size) nccl.allGather( value.storage(), storage, - config['comm'] + comm ) output_tensor = torch.tensor([], dtype=value.dtype, device="cuda") @@ -117,19 +164,19 @@ def backward(ctx, grad_output : torch.Tensor): grad_storage = grad_output.storage_type()(ctx.partition_size) grad_output_storage = grad_output.storage() - if grad_output_storage.size() == ctx.partition_size * config['world_size']: + if grad_output_storage.size() == ctx.partition_size * ctx.world_size: pass else: - grad_output_storage.resize_(ctx.partition_size * config['world_size']) + grad_output_storage.resize_(ctx.partition_size * ctx.world_size) nccl.reduceScatter( grad_output_storage, grad_storage, 'sum', - config['comm'] + ctx.comm ) grad_tensor = torch.tensor([], dtype=grad_output.dtype, device="cuda") grad_tensor.set_(grad_storage, 0, (ctx.tensor_size,)) - return grad_tensor + return grad_tensor, None class ParameterInitializer: """ diff --git a/example/layers/attention.py b/example/layers/attention.py index 243df3ea..b2ddebb9 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -1,8 +1,14 @@ from typing import Optional import torch import bmtrain as bmt -from bmtrain.nn import Linear +from bmtrain.nn import ( + Linear, + ColumnParallelLinear, + RowParallelLinear, +) import math +from bmtrain.global_var import config +from bmtrain.distributed import all_gather class Attention(bmt.DistributedModule): def __init__(self, @@ -12,11 +18,17 @@ def __init__(self, ) -> None: super().__init__() - self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + if config['tp_size'] <= 1: + self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) + else: + self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) + self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) self.softmax = torch.nn.Softmax(dim=-1) self.num_heads = num_heads @@ -35,29 +47,35 @@ def forward(self, h_q : torch.Tensor = self.project_q(hidden_q) h_k : torch.Tensor = self.project_k(hidden_kv) h_v : torch.Tensor = self.project_v(hidden_kv) + if config['tp_size'] > 1: + #batch_size will changed in TensorParallel + batch_size = h_v.shape[0] - h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head) - h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head) - h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head) + h_q = h_q.view(batch_size, seq_q, -1, self.dim_head) + h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head) + h_v = h_v.view(batch_size, seq_kv, -1, self.dim_head) h_q = h_q.permute(0, 2, 1, 3).contiguous() h_k = h_k.permute(0, 2, 1, 3).contiguous() h_v = h_v.permute(0, 2, 1, 3).contiguous() - h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head) - h_k = 
h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head) - h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head) + h_q = h_q.view(-1, seq_q, self.dim_head) + h_k = h_k.view(-1, seq_kv, self.dim_head) + h_v = h_v.view(-1, seq_kv, self.dim_head) score = torch.bmm( h_q, h_k.transpose(1, 2) ) score = score / math.sqrt(self.dim_head) - score = score.view(batch_size, self.num_heads, seq_q, seq_kv) + score = score.view(batch_size, -1, seq_q, seq_kv) if position_bias is not None: - score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv) - + score = score + position_bias.view(batch_size, -1, seq_q, seq_kv) + + if config['tp_size'] > 1: + mask = all_gather(mask, config['tp_comm']) + score = torch.where( mask.view(batch_size, 1, seq_q, seq_kv), score, @@ -70,14 +88,14 @@ def forward(self, torch.scalar_tensor(0, device=score.device, dtype=score.dtype) ) - score = score.view(batch_size * self.num_heads, seq_q, seq_kv) + score = score.view(-1, seq_q, seq_kv) h_out = torch.bmm( score, h_v ) - h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head) + h_out = h_out.view(batch_size, -1, seq_q, self.dim_head) h_out = h_out.permute(0, 2, 1, 3).contiguous() - h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head) + h_out = h_out.view(batch_size, seq_q, -1) attn_out = self.project_out(h_out) return attn_out diff --git a/example/layers/embedding.py b/example/layers/embedding.py index 13c47384..4faaa133 100644 --- a/example/layers/embedding.py +++ b/example/layers/embedding.py @@ -77,11 +77,14 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tensor: if not projection: - return F.embedding( + out = F.embedding( input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse) + return out else: - return F.linear(input, self.weight) / math.sqrt(self.embedding_dim) + out = F.linear(input, self.weight) / math.sqrt(self.embedding_dim) + #print(out.sum()) + return out def extra_repr(self) -> str: s = '{num_embeddings}, {embedding_dim}' @@ -97,4 +100,4 @@ def extra_repr(self) -> str: s += ', sparse=True' return s.format(**self.__dict__) - \ No newline at end of file + diff --git a/example/layers/feedforward.py b/example/layers/feedforward.py index 99d2dc3b..e88d2495 100644 --- a/example/layers/feedforward.py +++ b/example/layers/feedforward.py @@ -1,16 +1,23 @@ import torch import bmtrain as bmt -from bmtrain.nn import Linear +from bmtrain.nn import ( + Linear, + ColumnParallelLinear, + RowParallelLinear) +from bmtrain.global_var import config class Feedforward(bmt.DistributedModule): def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = None) -> None: super().__init__() - self.w_in = Linear(dim_model, dim_ff, bias = bias, dtype=dtype) - self.w_out = Linear(dim_ff, dim_model, bias = bias, dtype=dtype) + if config['tp_size'] > 1: + self.w_in = ColumnParallelLinear(dim_model, dim_ff, bias = bias, dtype=dtype) + self.w_out = RowParallelLinear(dim_ff, dim_model, bias = bias, dtype=dtype) + else: + self.w_in = Linear(dim_model, dim_ff, bias=bias, dtype=dtype) + self.w_out = Linear(dim_ff, dim_model, bias=bias, dtype=dtype) self.relu = torch.nn.ReLU() def forward(self, input : torch.Tensor) -> torch.Tensor: - return self.w_out(self.relu(self.w_in(input))) diff --git a/example/models/gpt.py b/example/models/gpt.py index 78d77a7d..d218b391 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py 
@@ -1,6 +1,7 @@ import torch import bmtrain as bmt from layers import TransformerEncoder, Layernorm, Embedding, TransformerEncoder +from bmtrain.global_var import config class GPT(bmt.DistributedModule): def __init__(self, @@ -13,14 +14,17 @@ def __init__(self, self.max_distance = max_distance - self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) + if config['tp_size'] > 1: + self.word_emb = bmt.nn.Embedding(vocab_size, dim_model, dtype=dtype) + else: + self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype) self.transformers = bmt.TransformerBlockList([ bmt.CheckpointBlock( TransformerEncoder( dim_model, dim_head, num_heads, dim_ff, bias, dtype - ) + ), use_checkpoint=False ) for _ in range(num_layers) ]) @@ -37,12 +41,18 @@ def forward(self, mask_2d = mask_2d & (pos[:, None, :] >= pos[:, :, None]) out = self.pos_emb(pos) + self.word_emb(input) + bmt.synchronize() # for layer in self.transformers: out = self.transformers(out, mask_2d, None) out = self.layernorm(out) - logits = self.word_emb(out, projection=True) + bmt.synchronize() + if config['tp_size'] > 1: + logits = self.word_emb.projection(out)#self.word_emb(out, projection=True) + else: + logits = self.word_emb(out, projection=True) + bmt.synchronize() bmt.inspect.record_tensor(logits, "logits") - return logits \ No newline at end of file + return logits diff --git a/example/train.py b/example/train.py index 1a744e20..50036675 100644 --- a/example/train.py +++ b/example/train.py @@ -3,12 +3,14 @@ from models import GPT import time from bmtrain import optim +from bmtrain.global_var import config from bmtrain import inspect def main(): bmt.init_distributed( seed=0, zero_level=2, + tp_size=4, ) model = GPT( @@ -24,11 +26,13 @@ def main(): ) bmt.init_parameters(model) + #bmt.load(model, "example_model.pt") # print_inspect(model, "*") bmt.print_rank("Model memory") bmt.print_rank(torch.cuda.memory_summary()) bmt.synchronize() + #bmt.save(model, "example_model.pt") # data # generate dummy data for each rank @@ -52,7 +56,11 @@ def main(): if i == bmt.rank(): break - loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) + if config['tp_size'] > 1: + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100) + else: + loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) + optimizer = optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) @@ -77,7 +85,10 @@ def main(): ) batch, seq_len, vocab_out_size = logits.size() - loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) + if config['tp_size'] > 1: + loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets) + else: + loss = loss_func(logits.float().view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) global_loss = bmt.sum_loss(loss).item() From 8648f5baea9b599e668a92ba1fd788c2942363cf Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 23 Aug 2023 20:20:56 +0800 Subject: [PATCH 03/25] rm unused code --- bmtrain/block_layer.py | 6 +++--- bmtrain/checkpointing.py | 3 --- bmtrain/nn/cross_entropy.py | 5 ----- bmtrain/nn/embedding.py | 20 ++------------------ example/layers/attention.py | 12 ++++++------ example/layers/embedding.py | 3 +-- example/models/gpt.py | 7 ++----- example/train.py | 5 +---- 8 files changed, 15 insertions(+), 46 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 
f30b13f6..d01a01cb 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -112,11 +112,11 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): comm = val['zero_comm'] world_size = nccl.commCount(comm) rank = nccl.commRank(comm) - val["world_size"] = world_size #config["dp_size"] + val["world_size"] = world_size partition_size = round_up(val["total"], val["world_size"]) // val["world_size"] val["partition_size"] = partition_size - val["begin"] = rank * partition_size #config['zero_rank'] * partition_size - val["end"] = (rank+1) * partition_size #(config['zero_rank'] + 1) * partition_size + val["begin"] = rank * partition_size + val["end"] = (rank+1) * partition_size offsets[kw] = 0 diff --git a/bmtrain/checkpointing.py b/bmtrain/checkpointing.py index 3adbc105..550225be 100644 --- a/bmtrain/checkpointing.py +++ b/bmtrain/checkpointing.py @@ -39,7 +39,6 @@ def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, pipe = Fal self._param_tensor = {} self._grad_tensor = {} self._need_release = False - #self.comm = config["zero_comm"] def enter(self, flag=0, requires_grad=False): """ @@ -72,7 +71,6 @@ def enter(self, flag=0, requires_grad=False): nccl.allGather( self.block._storage_params[kw].storage(), self._param_buffer[kw], - #self.comm val['zero_comm'] ) nccl.groupEnd() @@ -143,7 +141,6 @@ def exit(self, flag=0, backward=False): self._grad_buffer[kw], local_param.grad.storage(), "sum", - #self.comm val['zero_comm'] ) nccl.groupEnd() diff --git a/bmtrain/nn/cross_entropy.py b/bmtrain/nn/cross_entropy.py index 8643c03b..49460b5b 100644 --- a/bmtrain/nn/cross_entropy.py +++ b/bmtrain/nn/cross_entropy.py @@ -45,11 +45,7 @@ def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0): exp_logits = vocab_parallel_logits torch.exp(vocab_parallel_logits, out=exp_logits) sum_exp_logits = exp_logits.sum(dim=-1) - #if config['rank'] == 0: - #print("before", sum_exp_logits.shape, predicted_logits.shape) sum_exp_logits = all_reduce(sum_exp_logits, op="sum", comm=comm) - #if config['rank'] == 0: - # print(sum_exp_logits.shape) # Loss = log(sum(exp(logits))) - predicted-logit. loss = torch.log(sum_exp_logits.view(predicted_logits.shape)) - predicted_logits @@ -111,7 +107,6 @@ def backward(ctx, grad_output): grad_2d[arange_1d, masked_target_1d] -= softmax_update # Finally elementwise multiplication with the output gradients. 
- #grad_input.mul_(grad_output.unsqueeze(dim=-1)) grad_input.mul_(grad_output.flatten(0,1).unsqueeze(dim=-1)) return grad_input, None, None diff --git a/bmtrain/nn/embedding.py b/bmtrain/nn/embedding.py index 52be1d72..c8acafae 100644 --- a/bmtrain/nn/embedding.py +++ b/bmtrain/nn/embedding.py @@ -14,7 +14,6 @@ def __init__( vocab_size: int, embedding_size: int, dtype: torch.dtype = torch.half, - scale: bool = True, init_mean: float = 0.0, init_std: float = 1, ): @@ -31,7 +30,6 @@ def __init__( tp_mode=True, tp_split_dim=0, ) - self.scale = scale def forward(self, ids: torch.Tensor): """ @@ -47,9 +45,6 @@ def forward(self, ids: torch.Tensor): ids = ids.clone() - self.start_index ids[input_mask] = 0 - #if self.scale: - #embeds = F.embedding(ids, self.weight) / math.sqrt(self.dim_model) - #else: embeds = F.embedding(ids, self.weight) if config['tp_size'] > 1: @@ -57,7 +52,6 @@ def forward(self, ids: torch.Tensor): embeds = all_reduce(embeds, op="sum", comm=config['tp_comm']) embed_list = embeds.chunk(config['tp_size'], dim=0) embeds = embed_list[config['topology'].tp_id].flatten(0,1) - #print(embeds.sum()) return embeds.clone() @@ -73,15 +67,5 @@ def projection(self, x: torch.Tensor): split_input = False reduce_output_type = None gather_output = False - if self.scale: - #out = LinearHookFunc.apply(x / math.sqrt(self.dim_model), self.weight, None, gather_input, gather_output, split_input, reduce_output_type) - #print(x.sum()) - #print(x.shape, self.weight.shape) - out = LinearHookFunc.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) / math.sqrt(self.dim_model) - #out_list = out.chunk(config['tp_size'], dim=0) - #out = out_list[config['topology'].tp_id] - #print(out.sum()) - return out - else: - return LinearHookFunc.apply(x, self.weight, None, gather_input, gather_output, split_input, reduce_output_type) - return logits + out = LinearHookFunc.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) + return out diff --git a/example/layers/attention.py b/example/layers/attention.py index b2ddebb9..a49edabb 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -18,17 +18,17 @@ def __init__(self, ) -> None: super().__init__() - if config['tp_size'] <= 1: + if config['tp_size'] > 1: + self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) + else: self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) - else: - self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) self.softmax = torch.nn.Softmax(dim=-1) self.num_heads = num_heads diff --git a/example/layers/embedding.py 
b/example/layers/embedding.py index 4faaa133..f62151c4 100644 --- a/example/layers/embedding.py +++ b/example/layers/embedding.py @@ -82,8 +82,7 @@ def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tenso self.norm_type, self.scale_grad_by_freq, self.sparse) return out else: - out = F.linear(input, self.weight) / math.sqrt(self.embedding_dim) - #print(out.sum()) + out = F.linear(input, self.weight) return out def extra_repr(self) -> str: diff --git a/example/models/gpt.py b/example/models/gpt.py index d218b391..d2dff467 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -24,7 +24,7 @@ def __init__(self, bmt.CheckpointBlock( TransformerEncoder( dim_model, dim_head, num_heads, dim_ff, bias, dtype - ), use_checkpoint=False + ) ) for _ in range(num_layers) ]) @@ -41,18 +41,15 @@ def forward(self, mask_2d = mask_2d & (pos[:, None, :] >= pos[:, :, None]) out = self.pos_emb(pos) + self.word_emb(input) - bmt.synchronize() # for layer in self.transformers: out = self.transformers(out, mask_2d, None) out = self.layernorm(out) - bmt.synchronize() if config['tp_size'] > 1: - logits = self.word_emb.projection(out)#self.word_emb(out, projection=True) + logits = self.word_emb.projection(out) else: logits = self.word_emb(out, projection=True) - bmt.synchronize() bmt.inspect.record_tensor(logits, "logits") return logits diff --git a/example/train.py b/example/train.py index 50036675..44a7d5d2 100644 --- a/example/train.py +++ b/example/train.py @@ -10,7 +10,7 @@ def main(): bmt.init_distributed( seed=0, zero_level=2, - tp_size=4, + tp_size=2, ) model = GPT( @@ -26,13 +26,10 @@ def main(): ) bmt.init_parameters(model) - #bmt.load(model, "example_model.pt") - # print_inspect(model, "*") bmt.print_rank("Model memory") bmt.print_rank(torch.cuda.memory_summary()) bmt.synchronize() - #bmt.save(model, "example_model.pt") # data # generate dummy data for each rank From 763b4080816fb6b9e8d541ca352ea12ad12dd88a Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 24 Aug 2023 09:46:38 +0800 Subject: [PATCH 04/25] refactor nccl group; remove partition_modules in pipe_layer.py --- bmtrain/init.py | 42 ++++++++++------------- bmtrain/pipe_layer.py | 77 +------------------------------------------ 2 files changed, 19 insertions(+), 100 deletions(-) diff --git a/bmtrain/init.py b/bmtrain/init.py index a403caf7..915e30c7 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -107,7 +107,6 @@ def init_distributed( unique_id = bytes.fromhex(store.get("BMTRAIN_UNIQUE_ID").decode()) config['comm'] = nccl.commInitRank(unique_id, world_size, rank) - config['zero_comm'] = config['comm'] if config['pipe_enabled']: config["micros"] = num_micro_batches if num_micro_batches else config["pipe_size"] @@ -126,15 +125,18 @@ def init_distributed( unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) - if not config['pipe_enabled'] and config['tp_size'] <= 1: - config['tp_zero_comm'] = config['comm'] - else: if topo.tp_zero_id == 0: unique_id = nccl.getUniqueId() - store.set(f"ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) + store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) config ['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.tp_zero_id) + if topo.zero_id == 
0: + unique_id = nccl.getUniqueId() + store.set(f"ZERO_UNIQUE_ID{topo.zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"ZERO_UNIQUE_ID{topo.zero_idx}").decode()) + config ['zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size']), topo.zero_id) + for i in range(world_size): if i == rank: print_dict("Initialization", { @@ -162,29 +164,21 @@ def __init__(self,config): config['zero_size'] = world_size // pp_size topo=torch.tensor(range(dp_size*tp_size*pp_size),dtype=torch.int,device='cuda') topo=topo.view(pp_size,dp_size*tp_size) - self.pp_group=topo.transpose(0,1).reshape(-1,pp_size) - self.stage_id = (self.pp_group == self.rank).nonzero()[0,-1].item() self.stages = config['pipe_size'] - self.pipe_idx = (self.pp_group == self.rank).nonzero()[0, 0].item() # x axes - self.zero_id = self.pipe_idx - self.zero_idx = self.stage_id - self.tp_group = topo.reshape(pp_size, dp_size, tp_size) - self.tp_id = (self.tp_group == self.rank).nonzero()[0,2].item() - self.tp_idx = (self.tp_group == self.rank).nonzero()[0,1 if dp_size > 1 else 0].item() - - if pp_size == 1 and tp_size == 1: - self.tp_zero_id = self.rank - self.tp_zero_idx = 0 - else: - self.dp_group = self.tp_group.permute(0,2,1) - self.tp_zero_id = (self.dp_group == self.rank).nonzero()[0,2 if tp_size > 1 else 0].item() - self.tp_zero_idx = (self.dp_group == self.rank).nonzero()[0,1 if tp_size > 1 else 2].item() + for i in range(world_size): + self.pipe_idx = self.rank % pp_size + self.stage_id = self.rank // pp_size + self.tp_id = self.rank % tp_size + self.tp_idx = self.rank // tp_size + self.zero_idx = self.stage_id if pp_size > 1 else 0 + self.zero_id = self.pipe_idx if pp_size > 1 else self.rank + self.tp_zero_idx = self.tp_id + self.tp_zero_id = self.tp_idx if dp_size > 1 else 0 self.next_rank = self.stage_id+1 if self.stage_id < config['pipe_size'] - 1 else -1 self.prev_rank = self.stage_id-1 if self.stage_id > 0 else -1 - self.tails = self.pp_group[self.pipe_idx, self.stage_id:].tolist() - self.heads = self.pp_group[self.pipe_idx, :self.stage_id + 1].tolist() + def get_group_id(self,group_name): if group_name == "pipe": diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 0a34ac46..efce4049 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -189,7 +189,6 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: rank = config['rank'] topo = config['topology'] self.layer_ids = [] - pipe_group = topo.pp_group self.stages = topo.stages self.stage_id = topo.stage_id self.pipe_idx = topo.pipe_idx @@ -218,11 +217,6 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: self._modules[str(self.layer_ids[0])]._is_first_layer = True self._modules[str(self.layer_ids[-1])]._is_last_layer = True - self.partition_modules(self.layer_ids) - self.next_rank = pipe_group[self.pipe_idx, self.stage_id + 1] if self.stage_id < config['pipe_size'] - 1 else -1 - self.prev_rank = pipe_group[self.pipe_idx, self.stage_id - 1] if self.stage_id > 0 else -1 - # self.micro_batches = config['num_micro_batches'] - self.save_list = [(i, i) for i in range(len(self.layer_ids))] def __len__(self) -> int: @@ -295,77 +289,8 @@ def get_stage_by_layer_id(self, layer_id : int) -> int: else: return rest + (layer_id - rest * (part_len+1)) // part_len - def partition_modules(self, idxs) -> None: - for i in range(len(self)): - contiguous_params = {} - for kw, val in self[i]._storage_info.items(): - storage_type = val["storage_type"] - contiguous_params[kw] = 
storage_type(round_up(val["total"], config["world_size"] // config["pipe_size"])) - nccl.allGather( - self[i]._storage_params[kw].storage(), - contiguous_params[kw], - config["comm"] - ) - - if i not in idxs: - for name, param in self[i]._module.named_parameters(): - param.data = torch.tensor([], dtype = param.dtype, device = param.device) - for kw, val in self[i]._storage_info.items(): - val["begin"] = self.stage_id - val["end"] = self.stage_id + 1 - val["partition_size"] = 1 - val["total"] = val["world_size"] - dtype = self[i]._storage_params[kw].dtype - device = self[i]._storage_params[kw].device - self[i]._storage_params[kw] = \ - torch.nn.Parameter(torch.tensor([0], dtype = dtype, device=device)) - else: - for kw, val in self[i]._storage_info.items(): - storage_type = val["storage_type"] - val["world_size"] = config["world_size"] // config["pipe_size"] - partition_size = round_up(val["total"], val["world_size"]) // val["world_size"] - val["partition_size"] = partition_size - val["begin"] = config['zero_rank'] * partition_size - val["end"] = (config['zero_rank'] + 1) * partition_size - storage_param_buffer = storage_type(partition_size) - dtype = storage_param_buffer.dtype - device = storage_param_buffer.device - self[i]._storage_params[kw] = torch.nn.Parameter( - torch.tensor([], dtype=dtype, device=device).set_(storage_param_buffer) - ) - if val["requires_grad"]: - self[i]._storage_params[kw].requires_grad_(True) - else: - self[i]._storage_params[kw].requires_grad_(False) - ordered_parameters = list(self[i]._module.named_parameters()) - for idx, named_param in enumerate(ordered_parameters): - name, param = named_param - param_info = self[i]._param_info[idx] - kw_name = _get_param_kw(param) - storage_info = self[i]._storage_info[kw_name] - storage_st = storage_info["begin"] - storage_end = storage_info["end"] - param_st = param_info["offset"] - param_end = param_st + param_info["size"] - if not (param_st >= storage_end or param_end <= storage_st): - # copy offset in parameter storage - offset_st = max(storage_st - param_st, 0) - offset_end = min(storage_end - param_st, param_info["size"]) - assert offset_st < offset_end - to_offset_st = offset_st + param_st - storage_st - to_offset_end = offset_end + param_st - storage_st - d_dtype = self[i]._storage_params[kw_name].dtype - d_device = self[i]._storage_params[kw_name].device - param.data = torch.tensor([], dtype=param.dtype, device=param.device).set_(self[i]._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,)) - param_info["begin"] = to_offset_st - param_info["end"] = (to_offset_end - to_offset_st,) - param.data[:] = \ - torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_params[kw], storage_st+to_offset_st, (to_offset_end - to_offset_st,))[:] - else: - param.data = torch.tensor([], dtype=param.dtype, device=param.device) - del contiguous_params - def _save_to_state_dict(self, destination, prefix, keep_vars): + print('call _save_to_state_dict') for name, module in self._modules.items(): idx = int(name) name = prefix + name + '.' 
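For reference, a minimal standalone sketch of the grouping convention the refactored topology in this patch relies on: each communicator kind is derived from a (group index, rank-in-group) pair, ranks that share the same *_idx join the same NCCL group (the index keys the shared unique id in the store), and the paired id (stage_id, tp_id, zero_id, tp_zero_id) is the rank passed to nccl.commInitRank for that group. The index arithmetic below is illustrative only, following the final form reached after the later "fix topology" commits; the world_size/pipe_size/tp_size values are made up for the example.

    def group_indices(rank, world_size, pipe_size, tp_size):
        stage_size = world_size // pipe_size
        pipe_idx, stage_id = rank % stage_size, rank // stage_size   # pipeline group and the stage within it
        tp_idx, tp_id = rank // tp_size, rank % tp_size              # tensor-parallel group of size tp_size
        zero_idx, zero_id = stage_id, pipe_idx                       # ZeRO group: all ranks of one pipeline stage
        tp_zero_idx = stage_id * tp_size + tp_id                     # ZeRO group restricted to one tp slice
        tp_zero_id = pipe_idx // tp_size
        return dict(pipe_idx=pipe_idx, stage_id=stage_id, tp_idx=tp_idx, tp_id=tp_id,
                    zero_idx=zero_idx, zero_id=zero_id, tp_zero_idx=tp_zero_idx, tp_zero_id=tp_zero_id)

    # e.g. world_size=8, pipe_size=2, tp_size=2 (so dp_size=2): ranks 0-3 form stage 0, ranks 4-7 form stage 1
    for r in range(8):
        print(r, group_indices(r, world_size=8, pipe_size=2, tp_size=2))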
From 4c50567a94b7143cceecfe2f90977a659e0e8dc0 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 24 Aug 2023 10:43:51 +0800 Subject: [PATCH 05/25] fix by review comment --- bmtrain/block_layer.py | 26 +++++++------------ bmtrain/init.py | 2 +- bmtrain/loss/cross_entropy.py | 4 +-- bmtrain/nn/__init__.py | 4 +-- bmtrain/nn/column_parallel_linear.py | 11 ++++---- ...ropy.py => parallel_cross_entropy_func.py} | 6 ++--- .../{embedding.py => parallel_embedding.py} | 12 +++++---- ...r_hook_func.py => parallel_linear_func.py} | 2 +- bmtrain/nn/row_parallel_linear.py | 6 ++--- bmtrain/param_init.py | 1 - bmtrain/parameter.py | 7 +++-- bmtrain/pipe_layer.py | 1 - bmtrain/utils.py | 7 +++++ 13 files changed, 45 insertions(+), 44 deletions(-) rename bmtrain/nn/{cross_entropy.py => parallel_cross_entropy_func.py} (95%) rename bmtrain/nn/{embedding.py => parallel_embedding.py} (84%) rename bmtrain/nn/{parallel_linear_hook_func.py => parallel_linear_func.py} (98%) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index d01a01cb..c479a7e5 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, Iterator, Union, List -from .utils import round_up +from .utils import (round_up, tp_split_tensor) from .global_var import config import torch from . import nccl @@ -309,10 +309,12 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, tp_mode = param._tp_mode if input_param.__class__.__name__ == "DistributedTensorWrapper": input_param = input_param.broadcast() - if not tp_mode and input_param.shape != it["shape"]: + + verify_shape = it["shape"] if not tp_mode else param._tp_original_shape + if input_param.shape != verify_shape: error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' 'the shape in current model is {}.' 
- .format(key, input_param.shape, it["shape"])) + .format(key, input_param.shape, verify_shape)) continue param_st = it["offset"] @@ -328,18 +330,13 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, continue # copy to buffer - if not tp_mode: - assert input_param.numel() == it["size"] + verify_size = verify_shape.numel() + assert input_param.numel() == verify_size contiguous_param = input_param.to(it["parameter"].dtype).cuda().contiguous() - if tp_mode: - tp_split_dim = param._tp_split_dim - if tp_split_dim >= 0: - param_list = contiguous_param.chunk(config['tp_size'], dim=tp_split_dim) - sub_tensor = param_list[config['topology'].tp_id] - contiguous_param = torch.empty(sub_tensor.shape, device=sub_tensor.device, dtype=sub_tensor.dtype) - contiguous_param.copy_(sub_tensor) + if tp_mode and tp_split_dim >= 0: + contiguous_param = tp_split_tensor(contiguous_param, tp_split_dim) offset_st = max(storage_st - param_st, 0) offset_end = min(storage_end - param_st, contiguous_param.numel()) @@ -429,10 +426,7 @@ def init_parameters(self): continue if param._tp_mode and param._tp_split_dim >= 0: - tensor_list = tmp_tensor.chunk(config['tp_size'], dim=param._tp_split_dim) - sub_tensor = tensor_list[config['topology'].tp_id].contiguous() - tmp_tensor = torch.empty(sub_tensor.shape, device=param.device, dtype=sub_tensor.dtype) - tmp_tensor.copy_(sub_tensor) + tmp_tensor = tp_split_tensor(tmp_tensor, param._tp_split_dim) # copy to buffer assert tmp_tensor.is_contiguous() and it["size"] == tmp_tensor.numel() diff --git a/bmtrain/init.py b/bmtrain/init.py index 915e30c7..23999e39 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -15,7 +15,7 @@ def init_distributed( seed : int = 0, zero_level: int = 3, pipe_size: int = -1, - tp_size = 1, + tp_size : int = 1, num_micro_batches: int = None, ): """Initialize distributed training. diff --git a/bmtrain/loss/cross_entropy.py b/bmtrain/loss/cross_entropy.py index a5d43a55..962cfdeb 100644 --- a/bmtrain/loss/cross_entropy.py +++ b/bmtrain/loss/cross_entropy.py @@ -1,7 +1,7 @@ from typing import Optional import torch from . 
import _function as F -from bmtrain.nn import fused_cross_entropy +from bmtrain.nn import parallel_cross_entropy_func from bmtrain.global_var import config from bmtrain.distributed import all_gather @@ -191,7 +191,7 @@ def __init__(self, def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if config['tp_size'] > 1: target = all_gather(target, comm=config['tp_comm']).flatten(0,1) - ret = fused_cross_entropy(input, target.long(), self.ignore_index) + ret = parallel_cross_entropy_func(input, target.long(), self.ignore_index) else: if input.dtype == torch.float32: return torch.nn.functional.cross_entropy( diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py index 12540b57..05026738 100644 --- a/bmtrain/nn/__init__.py +++ b/bmtrain/nn/__init__.py @@ -1,5 +1,5 @@ from .linear import Linear from .column_parallel_linear import ColumnParallelLinear from .row_parallel_linear import RowParallelLinear -from .embedding import Embedding -from .cross_entropy import fused_cross_entropy +from .parallel_embedding import ParallelEmbedding +from .parallel_cross_entropy_func import parallel_cross_entropy_func diff --git a/bmtrain/nn/column_parallel_linear.py b/bmtrain/nn/column_parallel_linear.py index 9d6444b4..2baf79c2 100644 --- a/bmtrain/nn/column_parallel_linear.py +++ b/bmtrain/nn/column_parallel_linear.py @@ -3,16 +3,17 @@ import bmtrain as bmt from bmtrain.global_var import config -from .parallel_linear_hook_func import ( - LinearHookFunc, +from .parallel_linear_func import ( + ParallelLinearFunc, ReduceType) class ColumnParallelLinear(bmt.DistributedModule): - def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None, gather_output=False) -> None: + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None, gather_output=False, gather_input=True) -> None: super().__init__() self.in_features = in_features self.out_features = out_features self.gather_output = gather_output + self.gather_input = gather_input tp_size = config['tp_size'] assert out_features % tp_size == 0 self.out_features_per_partition = out_features // tp_size @@ -23,10 +24,10 @@ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtyp self.register_parameter('bias', None) def forward(self, input): - gather_input = True + gather_input = self.gather_input split_input = False reduce_output_type = None - return LinearHookFunc.apply(input, self.weight, self.bias, gather_input, self.gather_output, split_input, reduce_output_type) + return ParallelLinearFunc.apply(input, self.weight, self.bias, gather_input, self.gather_output, split_input, reduce_output_type) def extra_repr(self) -> str: return 'in_features={}, out_features={}, bias={}'.format( diff --git a/bmtrain/nn/cross_entropy.py b/bmtrain/nn/parallel_cross_entropy_func.py similarity index 95% rename from bmtrain/nn/cross_entropy.py rename to bmtrain/nn/parallel_cross_entropy_func.py index 49460b5b..55aa78cb 100644 --- a/bmtrain/nn/cross_entropy.py +++ b/bmtrain/nn/parallel_cross_entropy_func.py @@ -3,7 +3,7 @@ from bmtrain.global_var import config from bmtrain.distributed import all_reduce, all_gather -class FusedCrossEntropyFunc(torch.autograd.Function): +class ParallelCrossEntropyFunc(torch.autograd.Function): @staticmethod def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0): @@ -112,7 +112,7 @@ def backward(ctx, grad_output): return grad_input, None, None -def fused_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0): +def 
parallel_cross_entropy_func(vocab_parallel_logits, target, label_smoothing=0.0): """ Performs cross entropy loss when logits are split across tensor parallel ranks @@ -125,6 +125,6 @@ def fused_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0): lobal_smoothing: smoothing factor, must be in range [0.0, 1.0) default is no smoothing (=0.0) """ - out = FusedCrossEntropyFunc.apply(vocab_parallel_logits.to(torch.float32), target, label_smoothing) + out = ParallelCrossEntropyFunc.apply(vocab_parallel_logits.to(torch.float32), target, label_smoothing) return out diff --git a/bmtrain/nn/embedding.py b/bmtrain/nn/parallel_embedding.py similarity index 84% rename from bmtrain/nn/embedding.py rename to bmtrain/nn/parallel_embedding.py index c8acafae..1d7a1330 100644 --- a/bmtrain/nn/embedding.py +++ b/bmtrain/nn/parallel_embedding.py @@ -6,9 +6,9 @@ import bmtrain as bmt from bmtrain.global_var import config from bmtrain.distributed import all_reduce, all_gather -from .parallel_linear_hook_func import LinearHookFunc +from .parallel_linear_func import ParallelLinearFunc -class Embedding(bmt.DistributedModule): +class ParallelEmbedding(bmt.DistributedModule): def __init__( self, vocab_size: int, @@ -31,16 +31,18 @@ def __init__( tp_split_dim=0, ) - def forward(self, ids: torch.Tensor): + def forward(self, ids: torch.Tensor, gather_input=True): """ Args: ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens. + gather_input (bool) : whether gather input is required between tensor parallel group) Return: :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output. """ # noqa: E501 if config['tp_size'] > 1: - ids = all_gather(ids, comm=config['tp_comm']) + if gather_input: + ids = all_gather(ids, comm=config['tp_comm']) input_mask = (ids < self.start_index) | (ids >= self.end_index) ids = ids.clone() - self.start_index ids[input_mask] = 0 @@ -67,5 +69,5 @@ def projection(self, x: torch.Tensor): split_input = False reduce_output_type = None gather_output = False - out = LinearHookFunc.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) + out = ParallelLineakFunc.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) return out diff --git a/bmtrain/nn/parallel_linear_hook_func.py b/bmtrain/nn/parallel_linear_func.py similarity index 98% rename from bmtrain/nn/parallel_linear_hook_func.py rename to bmtrain/nn/parallel_linear_func.py index af7bea0a..5c7a30bc 100644 --- a/bmtrain/nn/parallel_linear_hook_func.py +++ b/bmtrain/nn/parallel_linear_func.py @@ -20,7 +20,7 @@ def preprocess_input(input, gather_input, split_input): input = all_input_list[config['topology'].tp_id] return input -class LinearHookFunc(torch.autograd.Function): +class ParallelLinearFunc(torch.autograd.Function): @staticmethod def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=False, split_input=False, reduce_output_type=None): ctx.save_for_backward(input, weight, bias) diff --git a/bmtrain/nn/row_parallel_linear.py b/bmtrain/nn/row_parallel_linear.py index a6877b7c..acc9378d 100644 --- a/bmtrain/nn/row_parallel_linear.py +++ b/bmtrain/nn/row_parallel_linear.py @@ -3,8 +3,8 @@ import bmtrain as bmt from bmtrain.global_var import config -from .parallel_linear_hook_func import ( - LinearHookFunc, +from .parallel_linear_func import ( + ParallelLinearFunc, ReduceType) class RowParallelLinear(bmt.DistributedModule): @@ -27,7 +27,7 @@ def forward(self, input): 
gather_input = self.split_input gather_output = False reduce_output_type = ReduceType.ALL_REDUCE if self.all_reduce_output else ReduceType.REDUCE_SCATTER - out = LinearHookFunc.apply(input, self.weight, None, gather_input, gather_output, self.split_input, reduce_output_type) + out = ParallelLinearFunc.apply(input, self.weight, None, gather_input, gather_output, self.split_input, reduce_output_type) out = out + self.bias return out diff --git a/bmtrain/param_init.py b/bmtrain/param_init.py index 5e165fe1..d5c86225 100644 --- a/bmtrain/param_init.py +++ b/bmtrain/param_init.py @@ -3,7 +3,6 @@ from .block_layer import CheckpointBlock from .parameter import DistributedParameter from .global_var import config -from . import nccl def init_distributed_parameter(params : Iterable[torch.nn.Parameter]): diff --git a/bmtrain/parameter.py b/bmtrain/parameter.py index f68d53ac..b5e5e9ae 100644 --- a/bmtrain/parameter.py +++ b/bmtrain/parameter.py @@ -53,9 +53,8 @@ def __new__(cls, original_shape = data.size() tp_original_shape = original_shape if tp_mode and tp_split_dim >= 0: - list_shape = list(original_shape) - list_shape[tp_split_dim] *= config['tp_size'] - tp_original_shape = list_shape + tp_original_shape = list(original_shape) + tp_original_shape[tp_split_dim] *= config['tp_size'] cuda_storage = cuda_tensor.storage_type()(cuda_storage_size) @@ -176,7 +175,7 @@ def backward(ctx, grad_output : torch.Tensor): ) grad_tensor = torch.tensor([], dtype=grad_output.dtype, device="cuda") grad_tensor.set_(grad_storage, 0, (ctx.tensor_size,)) - return grad_tensor, None + return grad_tensor class ParameterInitializer: """ diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index efce4049..a21bb9af 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -290,7 +290,6 @@ def get_stage_by_layer_id(self, layer_id : int) -> int: return rest + (layer_id - rest * (part_len+1)) // part_len def _save_to_state_dict(self, destination, prefix, keep_vars): - print('call _save_to_state_dict') for name, module in self._modules.items(): idx = int(name) name = prefix + name + '.' diff --git a/bmtrain/utils.py b/bmtrain/utils.py index a5687c7d..8cb87808 100644 --- a/bmtrain/utils.py +++ b/bmtrain/utils.py @@ -107,6 +107,13 @@ def see_memory(message, detail=False): """) torch.cuda.reset_peak_memory_stats() +def tp_split_tensor(tensor, split_dim): + tensor_list = tensor.chunk(config['tp_size'], dim=split_dim) + sub_tensor = tensor_list[config['topology'].tp_id].contiguous() + tmp_tensor = torch.empty(sub_tensor.shape, device=sub_tensor.device, dtype=sub_tensor.dtype) + tmp_tensor.copy_(sub_tensor) + return tmp_tensor + class AverageRecorder: """A utility class to record the average value of a quantity over time. From 825139c5f6c48afca9ac24e09dcfb6b426522f4e Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 24 Aug 2023 11:26:56 +0800 Subject: [PATCH 06/25] fix topology --- bmtrain/init.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/bmtrain/init.py b/bmtrain/init.py index 23999e39..17d217b1 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -15,8 +15,8 @@ def init_distributed( seed : int = 0, zero_level: int = 3, pipe_size: int = -1, - tp_size : int = 1, num_micro_batches: int = None, + tp_size : int = 1, ): """Initialize distributed training. This function will initialize the distributed training, set the random seed and global configurations. @@ -25,6 +25,9 @@ def init_distributed( Args: seed (int): The random seed. 
zero_level (int): The ZeRO optimization level. 2 for stage-2, 3 for stage-3. + pipe_size (int) : pipe_size means that all processes will be divided into pipe_size groups + num_micro_batches (int) : means that the input batchs will be divided into num_micro_batches small batches. used in pipeline mode. + tp_size (int) : tp_size means the size of each of tensor parallel group **init_distributed** reads the following environment variables: @@ -107,24 +110,23 @@ def init_distributed( unique_id = bytes.fromhex(store.get("BMTRAIN_UNIQUE_ID").decode()) config['comm'] = nccl.commInitRank(unique_id, world_size, rank) + topo = config['topology'] if config['pipe_enabled']: config["micros"] = num_micro_batches if num_micro_batches else config["pipe_size"] - topo = config['topology'] if topo.stage_id == 0: unique_id = nccl.getUniqueId() store.set(f"PIPE_UNIQUE_ID{topo.pipe_idx}", unique_id.hex()) unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode()) config ['pipe_comm'] = nccl.commInitRank(unique_id, pipe_size, topo.stage_id) - if config['tp_size'] > 1: - topo = config['topology'] - if topo.tp_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex()) - unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) - config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) + if topo.tp_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex()) + unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) + config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) + if config['tp_size'] > 1: if topo.tp_zero_id == 0: unique_id = nccl.getUniqueId() store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) @@ -166,13 +168,14 @@ def __init__(self,config): topo=topo.view(pp_size,dp_size*tp_size) self.stages = config['pipe_size'] + stage_size = world_size // pp_size for i in range(world_size): - self.pipe_idx = self.rank % pp_size - self.stage_id = self.rank // pp_size + self.pipe_idx = self.rank % stage_size + self.stage_id = self.rank // stage_size self.tp_id = self.rank % tp_size self.tp_idx = self.rank // tp_size - self.zero_idx = self.stage_id if pp_size > 1 else 0 - self.zero_id = self.pipe_idx if pp_size > 1 else self.rank + self.zero_idx = self.stage_id + self.zero_id = self.pipe_idx self.tp_zero_idx = self.tp_id self.tp_zero_id = self.tp_idx if dp_size > 1 else 0 From 4ff0f419f5428ff38d9863b2afed103babe8baac Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 24 Aug 2023 13:58:52 +0800 Subject: [PATCH 07/25] fix topology --- bmtrain/init.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bmtrain/init.py b/bmtrain/init.py index 17d217b1..8f4f7063 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -176,8 +176,8 @@ def __init__(self,config): self.tp_idx = self.rank // tp_size self.zero_idx = self.stage_id self.zero_id = self.pipe_idx - self.tp_zero_idx = self.tp_id - self.tp_zero_id = self.tp_idx if dp_size > 1 else 0 + self.tp_zero_idx = self.stage_id * tp_size + self.tp_id + self.tp_zero_id = self.pipe_idx // tp_size self.next_rank = self.stage_id+1 if self.stage_id < config['pipe_size'] - 1 else -1 self.prev_rank = self.stage_id-1 if self.stage_id > 0 else -1 From a5d7ba63d90009576deaf299e64af20f944bf9b0 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 24 Aug 2023 15:43:56 +0800 Subject: [PATCH 08/25] fix --- bmtrain/nn/parallel_cross_entropy_func.py | 46 ++++++++++------------- 
bmtrain/nn/parallel_embedding.py | 2 +- 2 files changed, 20 insertions(+), 28 deletions(-) diff --git a/bmtrain/nn/parallel_cross_entropy_func.py b/bmtrain/nn/parallel_cross_entropy_func.py index 55aa78cb..cd1f63bf 100644 --- a/bmtrain/nn/parallel_cross_entropy_func.py +++ b/bmtrain/nn/parallel_cross_entropy_func.py @@ -6,19 +6,19 @@ class ParallelCrossEntropyFunc(torch.autograd.Function): @staticmethod - def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0): + def forward(ctx, logits, target, label_smoothing=0.0): comm = config['tp_comm'] rank = config['topology'].tp_id world_size = config['tp_size'] - # Maximum value along vocab dimension across all GPUs. - logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] - logits_max = all_reduce(logits_max, op="max", comm=comm) - # Subtract the maximum value. - vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1) - - # Get the partition's vocab indecies - #get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size - partition_vocab_size = vocab_parallel_logits.size()[-1] + + # local max + max_logits = torch.max(logits, dim=-1)[0] + # global max + max_logits = all_reduce(max_logits, op="max", comm=comm) + + logits = logits - max_logits.unsqueeze(dim=-1) + + partition_vocab_size = logits.size()[-1] vocab_start_index = rank * partition_vocab_size vocab_end_index = (rank + 1) * partition_vocab_size @@ -27,10 +27,7 @@ def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0): masked_target = target.clone() - vocab_start_index masked_target[target_mask] = 0 - # Get predicted-logits = logits[target]. - # For Simplicity, we convert logits to a 2-D tensor with size - # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. - logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) + logits_2d = logits.view(-1, partition_vocab_size) masked_target_1d = masked_target.view(-1) arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) @@ -38,12 +35,12 @@ def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0): predicted_logits_1d = predicted_logits_1d.clone().contiguous() predicted_logits = predicted_logits_1d.view_as(target) predicted_logits[target_mask] = 0.0 + # All reduce is needed to get the chunks from other GPUs. predicted_logits = all_reduce(predicted_logits, op="sum", comm=comm) - # Sum of exponential of logits along vocab dimension across all GPUs. - exp_logits = vocab_parallel_logits - torch.exp(vocab_parallel_logits, out=exp_logits) + exp_logits = logits + torch.exp(logits, out=exp_logits) sum_exp_logits = exp_logits.sum(dim=-1) sum_exp_logits = all_reduce(sum_exp_logits, op="sum", comm=comm) @@ -106,25 +103,20 @@ def backward(ctx, grad_output): else: grad_2d[arange_1d, masked_target_1d] -= softmax_update - # Finally elementwise multiplication with the output gradients. 
grad_input.mul_(grad_output.flatten(0,1).unsqueeze(dim=-1)) return grad_input, None, None -def parallel_cross_entropy_func(vocab_parallel_logits, target, label_smoothing=0.0): +def parallel_cross_entropy_func(logits, target, label_smoothing=0.0): """ Performs cross entropy loss when logits are split across tensor parallel ranks Arguments: - vocab_parallel_logits: logits split across tensor parallel ranks - dimension is [sequence_length, batch_size, hidden_size] - - target: correct vocab ids of dimseion [sequence_length, micro_batch_size] - - lobal_smoothing: smoothing factor, must be in range [0.0, 1.0) - default is no smoothing (=0.0) + logits: logits split across tensor parallel ranks; dimension is [batch * seq_len, hidden_size]. + target: correct vocab ids of dimension [batch * seq_len]. + label_smoothing: smoothing factor, must be in range [0.0, 1.0). default is 0.0. """ - out = ParallelCrossEntropyFunc.apply(vocab_parallel_logits.to(torch.float32), target, label_smoothing) + out = ParallelCrossEntropyFunc.apply(logits.to(torch.float32), target, label_smoothing) return out diff --git a/bmtrain/nn/parallel_embedding.py b/bmtrain/nn/parallel_embedding.py index 1d7a1330..7ffb74ef 100644 --- a/bmtrain/nn/parallel_embedding.py +++ b/bmtrain/nn/parallel_embedding.py @@ -69,5 +69,5 @@ def projection(self, x: torch.Tensor): split_input = False reduce_output_type = None gather_output = False - out = ParallelLineakFunc.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) + out = ParallelLinearFunc.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) return out From 2951d70a2f430d5b54b2f6ad5a6a6b73d2c17846 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 24 Aug 2023 16:01:32 +0800 Subject: [PATCH 09/25] use ParallelEmbedding --- example/models/gpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/models/gpt.py b/example/models/gpt.py index d2dff467..64474ba8 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -15,7 +15,7 @@ def __init__(self, self.max_distance = max_distance if config['tp_size'] > 1: - self.word_emb = bmt.nn.Embedding(vocab_size, dim_model, dtype=dtype) + self.word_emb = bmt.nn.ParallelEmbedding(vocab_size, dim_model, dtype=dtype) else: self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype) From 39319e117e7b17aece40330a056956331cf40856 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 24 Aug 2023 19:55:42 +0800 Subject: [PATCH 10/25] overlap parallel linear backward --- bmtrain/nn/parallel_linear_func.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py index 5c7a30bc..252162b2 100644 --- a/bmtrain/nn/parallel_linear_func.py +++ b/bmtrain/nn/parallel_linear_func.py @@ -73,13 +73,17 @@ def backward(ctx, grad_output): #TODO: gather on load_stream grad_all_input = grad_output.matmul(weight) grad_input = torch.empty_like(input) + current_stream = torch.cuda.current_stream() + config['tp_comm_stream'].wait_stream(current_stream) if ctx.gather_input: - nccl.reduceScatter(grad_all_input.storage(), grad_input.storage(), "sum", config['tp_comm']) + with torch.cuda.stream(config['tp_comm_stream']): + nccl.reduceScatter(grad_all_input.storage(), grad_input.storage(), "sum", config['tp_comm']) else: grad_input = grad_all_input if ctx.split_input: - grad_input = all_gather(grad_input, config['tp_comm'])
+ with torch.cuda.stream(config['tp_comm_stream']): + grad_input = all_gather(grad_input, config['tp_comm']) if weight.requires_grad: dim = grad_output.dim() @@ -88,4 +92,7 @@ def backward(ctx, grad_output): if bias is not None and bias.requires_grad: grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0) + + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(config['tp_comm_stream']) return grad_input, grad_weight, grad_bias, None, None, None, None From df3fd8f8452c8f3eea982d375aa5847c64bd0551 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 24 Aug 2023 20:28:46 +0800 Subject: [PATCH 11/25] add tp_comm_stream --- bmtrain/init.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/bmtrain/init.py b/bmtrain/init.py index 8f4f7063..9126f8d8 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -75,6 +75,7 @@ def init_distributed( config["world_size"] = world_size config["calc_stream"] = torch.cuda.current_stream() config["load_stream"] = torch.cuda.Stream(priority=-1) + config["tp_comm_stream"] = torch.cuda.Stream(priority=-1) config['barrier_stream'] = torch.cuda.Stream() config["load_event"] = torch.cuda.Event() config["zero_level"] = zero_level @@ -126,12 +127,12 @@ def init_distributed( unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) - if config['tp_size'] > 1: - if topo.tp_zero_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) - config ['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.tp_zero_id) + #if config['tp_size'] > 1: + if topo.tp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) + config ['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.tp_zero_id) if topo.zero_id == 0: unique_id = nccl.getUniqueId() From 99efba3b307122f381db31ab243f154ff46124f9 Mon Sep 17 00:00:00 2001 From: Achazwl <323163497@qq.com> Date: Thu, 24 Aug 2023 21:25:08 +0800 Subject: [PATCH 12/25] fix tp --- bmtrain/block_layer.py | 3 ++- bmtrain/init.py | 12 ++++++------ bmtrain/loss/cross_entropy.py | 4 +++- bmtrain/nn/__init__.py | 3 ++- bmtrain/nn/column_parallel_linear.py | 4 ++-- bmtrain/nn/linear.py | 4 ++-- bmtrain/nn/parallel_embedding.py | 5 ++--- bmtrain/nn/parallel_linear_func.py | 5 ++++- bmtrain/nn/row_parallel_linear.py | 4 ++-- setup.py | 2 +- 10 files changed, 26 insertions(+), 20 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index c479a7e5..e0a2c3e6 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -310,7 +310,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, if input_param.__class__.__name__ == "DistributedTensorWrapper": input_param = input_param.broadcast() - verify_shape = it["shape"] if not tp_mode else param._tp_original_shape + verify_shape = torch.Size(it["shape"] if not tp_mode else param._tp_original_shape) if input_param.shape != verify_shape: error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' 'the shape in current model is {}.' 
@@ -335,6 +335,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, contiguous_param = input_param.to(it["parameter"].dtype).cuda().contiguous() + tp_split_dim = param._tp_split_dim if tp_mode and tp_split_dim >= 0: contiguous_param = tp_split_tensor(contiguous_param, tp_split_dim) diff --git a/bmtrain/init.py b/bmtrain/init.py index 8f4f7063..eb176ad4 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -81,6 +81,7 @@ def init_distributed( config["tp_size"] = tp_size if tp_size > 0 else 1 config["topology"] = topology(config) config["zero_rank"] = config['topology'].get_group_rank("zero") + config["tp_rank"] = config['topology'].get_group_rank("tp") config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero") cpus_this_worker = None @@ -126,12 +127,11 @@ def init_distributed( unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) - if config['tp_size'] > 1: - if topo.tp_zero_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) - config ['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.tp_zero_id) + if topo.tp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) + config['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.tp_zero_id) if topo.zero_id == 0: unique_id = nccl.getUniqueId() diff --git a/bmtrain/loss/cross_entropy.py b/bmtrain/loss/cross_entropy.py index 962cfdeb..c4b36a7f 100644 --- a/bmtrain/loss/cross_entropy.py +++ b/bmtrain/loss/cross_entropy.py @@ -180,6 +180,7 @@ def __init__(self, reduction: str = 'mean', label_smoothing: float = 0.0, # TODO not supported yet inplace: bool = False, + parallel: bool = False, ) -> None: super().__init__() self.weight = weight @@ -187,9 +188,10 @@ def __init__(self, self.reduction = reduction self.label_smoothing = label_smoothing self.inplace = inplace + self.parallel = parallel def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - if config['tp_size'] > 1: + if self.parallel: target = all_gather(target, comm=config['tp_comm']).flatten(0,1) ret = parallel_cross_entropy_func(input, target.long(), self.ignore_index) else: diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py index 05026738..e22d8c55 100644 --- a/bmtrain/nn/__init__.py +++ b/bmtrain/nn/__init__.py @@ -1,5 +1,6 @@ -from .linear import Linear +from .linear import Linear, OpLinear from .column_parallel_linear import ColumnParallelLinear from .row_parallel_linear import RowParallelLinear from .parallel_embedding import ParallelEmbedding from .parallel_cross_entropy_func import parallel_cross_entropy_func +from .parallel_linear_func import OpParallelLinear \ No newline at end of file diff --git a/bmtrain/nn/column_parallel_linear.py b/bmtrain/nn/column_parallel_linear.py index 2baf79c2..e8f554c8 100644 --- a/bmtrain/nn/column_parallel_linear.py +++ b/bmtrain/nn/column_parallel_linear.py @@ -4,7 +4,7 @@ import bmtrain as bmt from bmtrain.global_var import config from .parallel_linear_func import ( - ParallelLinearFunc, + OpParallelLinear, ReduceType) class ColumnParallelLinear(bmt.DistributedModule): @@ -27,7 +27,7 @@ 
def forward(self, input): gather_input = self.gather_input split_input = False reduce_output_type = None - return ParallelLinearFunc.apply(input, self.weight, self.bias, gather_input, self.gather_output, split_input, reduce_output_type) + return OpParallelLinear.apply(input, self.weight, self.bias, gather_input, self.gather_output, split_input, reduce_output_type) def extra_repr(self) -> str: return 'in_features={}, out_features={}, bias={}'.format( diff --git a/bmtrain/nn/linear.py b/bmtrain/nn/linear.py index faf0770e..cb04863a 100644 --- a/bmtrain/nn/linear.py +++ b/bmtrain/nn/linear.py @@ -2,7 +2,7 @@ import torch.nn.functional as F import bmtrain as bmt -class CustomLinear(torch.autograd.Function): +class OpLinear(torch.autograd.Function): @staticmethod def forward(ctx, x, weight, bias=None): ctx.save_for_backward(x, weight, bias) @@ -35,7 +35,7 @@ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtyp self.register_parameter('bias', None) def forward(self, input): - return CustomLinear.apply(input, self.weight, self.bias) + return OpLinear.apply(input, self.weight, self.bias) def extra_repr(self) -> str: return 'in_features={}, out_features={}, bias={}'.format( diff --git a/bmtrain/nn/parallel_embedding.py b/bmtrain/nn/parallel_embedding.py index 7ffb74ef..cd567b4e 100644 --- a/bmtrain/nn/parallel_embedding.py +++ b/bmtrain/nn/parallel_embedding.py @@ -6,7 +6,7 @@ import bmtrain as bmt from bmtrain.global_var import config from bmtrain.distributed import all_reduce, all_gather -from .parallel_linear_func import ParallelLinearFunc +from .parallel_linear_func import OpParallelLinear class ParallelEmbedding(bmt.DistributedModule): def __init__( @@ -35,7 +35,6 @@ def forward(self, ids: torch.Tensor, gather_input=True): """ Args: ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens. - gather_input (bool) : whether gather input is required between tensor parallel group) Return: :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output. 
""" # noqa: E501 @@ -69,5 +68,5 @@ def projection(self, x: torch.Tensor): split_input = False reduce_output_type = None gather_output = False - out = ParallelLinearFunc.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) + out = OpParallelLinear.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) return out diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py index 5c7a30bc..dc2d3816 100644 --- a/bmtrain/nn/parallel_linear_func.py +++ b/bmtrain/nn/parallel_linear_func.py @@ -20,9 +20,12 @@ def preprocess_input(input, gather_input, split_input): input = all_input_list[config['topology'].tp_id] return input -class ParallelLinearFunc(torch.autograd.Function): +class OpParallelLinear(torch.autograd.Function): @staticmethod def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=False, split_input=False, reduce_output_type=None): + if reduce_output_type is not None: + reduce_output_type = ReduceType(reduce_output_type) + ctx.save_for_backward(input, weight, bias) ctx.gather_output = gather_output ctx.split_input = split_input diff --git a/bmtrain/nn/row_parallel_linear.py b/bmtrain/nn/row_parallel_linear.py index acc9378d..71a4297d 100644 --- a/bmtrain/nn/row_parallel_linear.py +++ b/bmtrain/nn/row_parallel_linear.py @@ -4,7 +4,7 @@ import bmtrain as bmt from bmtrain.global_var import config from .parallel_linear_func import ( - ParallelLinearFunc, + OpParallelLinear, ReduceType) class RowParallelLinear(bmt.DistributedModule): @@ -27,7 +27,7 @@ def forward(self, input): gather_input = self.split_input gather_output = False reduce_output_type = ReduceType.ALL_REDUCE if self.all_reduce_output else ReduceType.REDUCE_SCATTER - out = ParallelLinearFunc.apply(input, self.weight, None, gather_input, gather_output, self.split_input, reduce_output_type) + out = OpParallelLinear.apply(input, self.weight, None, gather_input, gather_output, self.split_input, reduce_output_type) out = out + self.bias return out diff --git a/setup.py b/setup.py index 2bbb55d8..ad1c8905 100644 --- a/setup.py +++ b/setup.py @@ -92,7 +92,7 @@ def build_extension(self, ext): ] setup( name='bmtrain', - version='0.2.3.post2', + version='0.2.3.post3', author="Guoyang Zeng", author_email="qbjooo@qq.com", description="A toolkit for training big models", From f1b4fd7127124eb98d9ec32f16424f81312e32b5 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 25 Aug 2023 12:50:32 +0800 Subject: [PATCH 13/25] fix load_state_dict --- bmtrain/layer.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/bmtrain/layer.py b/bmtrain/layer.py index ebbef815..7de83e5e 100644 --- a/bmtrain/layer.py +++ b/bmtrain/layer.py @@ -1,6 +1,8 @@ import torch from .parameter import DistributedParameter +from .global_var import config import itertools +from .utils import tp_split_tensor class DistributedModule(torch.nn.Module): """ @@ -11,7 +13,7 @@ class DistributedModule(torch.nn.Module): def __getattr__(self, name: str): ret = super().__getattr__(name) # gather distributed parameters if not in CheckpointBlock - if isinstance(ret, DistributedParameter) and not ret._in_checkpoint_block: + if isinstance(ret, DistributedParameter) and not ret._in_checkpoint_block: return ret.gather() return ret @@ -30,8 +32,11 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): """ for name, param in self._parameters.items(): if param is not None: - if isinstance(param, DistributedParameter) and not 
param._in_checkpoint_block: - destination[prefix + name] = param.gather().detach().cpu() # sync operation + if isinstance(param, DistributedParameter):#and not param._in_checkpoint_block: + if param._in_checkpoint_block: + destination[prefix + name] = param.tp_gather().detach().cpu() # sync operation + else: + destination[prefix + name] = param.gather_all().detach().cpu() # sync operation else: destination[prefix + name] = param if keep_vars else param.detach().cpu() for name, buf in self._buffers.items(): @@ -81,6 +86,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, for name, param in local_state.items(): key = prefix + name if key in state_dict: + tp_mode = param._tp_mode input_param = state_dict[key] if input_param.__class__.__name__ == "DistributedTensorWrapper": input_param = input_param.broadcast() @@ -98,13 +104,17 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 'the shape in current model is {}.' .format(key, input_param.shape, param.shape)) continue - if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != param._original_shape: + verify_shape = torch.Size(param._original_shape if not tp_mode else param._tp_original_shape) + if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != verify_shape: error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' 'the shape in current model is {}.' - .format(key, input_param.shape, param.shape)) + .format(key, input_param.shape, verify_shape)) try: with torch.no_grad(): if isinstance(param, DistributedParameter): + tp_split_dim = param._tp_split_dim + if tp_mode and tp_split_dim >= 0: + input_param = tp_split_tensor(input_param, tp_split_dim) param._copy_data(input_param) else: param.copy_(input_param) From 677a316228fd623d5b7e4e71a01676f2927ffc0e Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 25 Aug 2023 15:23:08 +0800 Subject: [PATCH 14/25] test parallel linear --- tests/test_all.py | 2 + tests/test_column_parallel_linear.py | 55 ++++++++++++++++++++++++++++ tests/test_row_parallel_linear.py | 54 +++++++++++++++++++++++++++ 3 files changed, 111 insertions(+) create mode 100644 tests/test_column_parallel_linear.py create mode 100644 tests/test_row_parallel_linear.py diff --git a/tests/test_all.py b/tests/test_all.py index 6682aa93..aa382676 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -25,6 +25,8 @@ ("send_recv", 4), ("nccl_backward", 4), ("no_grad", 1), + ("column_parallel_linear", 2), + ("row_parallel_linear", 2), ("training", 4), ]) diff --git a/tests/test_column_parallel_linear.py b/tests/test_column_parallel_linear.py new file mode 100644 index 00000000..1c49570a --- /dev/null +++ b/tests/test_column_parallel_linear.py @@ -0,0 +1,55 @@ +import torch +import bmtrain as bmt +from bmtrain.global_var import config +import numpy as np + +def run_bmt(x, gather_output, ckp_path, tp_size=2): + linear = bmt.nn.ColumnParallelLinear(8,8, gather_output=gather_output) + linear = bmt.CheckpointBlock(linear) + bmt.init_parameters(linear) + y = linear(x[config['topology'].tp_id]) + y.sum().backward() + bmt.save(linear, ckp_path) + bmt.synchronize() + return y, linear._parameters['weight'].grad, linear._parameters['bias'].grad + +def run_torch(x, ckp_path): + linear = torch.nn.Linear(8, 8) + linear_dict = torch.load(ckp_path) + linear.load_state_dict(linear_dict) + linear = linear.cuda() + linear.weight.requires_grad_() + y = linear(x) + y.sum().backward() + return y, 
linear.weight.grad, linear.bias.grad + +def run(gather_output, ckp_path): + tp_size = bmt.config['tp_size'] + torch.cuda.manual_seed(100) + x = torch.randn(tp_size, 8,8, device='cuda').requires_grad_() + y1, weight_grad1, bias_grad1 = run_bmt(x, gather_output, ckp_path) + y2, weight_grad2, bias_grad2 = run_torch(x, ckp_path) + tp_rank = config['topology'].tp_id + if gather_output: + assert np.allclose(y1.detach().cpu().numpy(), y2.flatten(0,1).detach().cpu().numpy()) + else: + torch_out_list = torch.split(y2, y2.size()[-1] // tp_size, dim=y2.dim()-1) + assert np.allclose(y1.detach().cpu().numpy(), torch_out_list[tp_rank].flatten(0,1).detach().cpu().numpy()) + + weight_grad_list = weight_grad2.chunk(tp_size, dim=0) + assert np.allclose(weight_grad1.reshape(weight_grad_list[tp_rank].shape).cpu().numpy(), weight_grad_list[tp_rank].cpu().numpy()) + + bias_grad_list = bias_grad2.chunk(tp_size, dim=0) + assert np.allclose(bias_grad1.reshape(bias_grad_list[tp_rank].shape).cpu().numpy(), bias_grad_list[tp_rank].cpu().numpy()) + +def test_gather_output(): + run(True, 'linear.ckp') + +def test_no_gather_output(): + run(False, 'linear_no_gather.ckp') + +if __name__ == "__main__": + bmt.init_distributed(tp_size=2) + test_gather_output() + test_no_gather_output() + diff --git a/tests/test_row_parallel_linear.py b/tests/test_row_parallel_linear.py new file mode 100644 index 00000000..f89b6dc5 --- /dev/null +++ b/tests/test_row_parallel_linear.py @@ -0,0 +1,54 @@ +import torch +import bmtrain as bmt +from bmtrain.global_var import config +import numpy as np + +def run_bmt(x, ckp_path, split_input=True, use_checkpoint_block=True): + linear = bmt.nn.RowParallelLinear(8,8, split_input=split_input, all_reduce_output=True) + if use_checkpoint_block: + linear = bmt.CheckpointBlock(linear) + bmt.init_parameters(linear) + y = linear(x) + y.sum().backward() + bmt.save(linear, ckp_path) + bmt.synchronize() + return y, linear._parameters['weight'].grad, linear._parameters['bias'].grad + +def run_torch(x, ckp_path): + linear = torch.nn.Linear(8, 8) + linear_dict = torch.load(ckp_path) + linear.load_state_dict(linear_dict) + linear = linear.cuda() + linear.weight.requires_grad_() + y = linear(x) + y.sum().backward() + return y, linear.weight.grad, linear.bias.grad + +def run(split_input, use_checkpoint_block, ckp_path): + tp_size = bmt.config['tp_size'] + torch.cuda.manual_seed(100) + tp_rank = config['topology'].tp_id + x = torch.randn(8,8, device='cuda').requires_grad_() + rank_x = x.chunk(tp_size, dim=0 if split_input else 1)[tp_rank] + y1, weight_grad1, bias_grad1 = run_bmt(rank_x, ckp_path, split_input, use_checkpoint_block) + y2, weight_grad2, bias_grad2 = run_torch(x, ckp_path) + np.testing.assert_allclose(y1.detach().cpu().numpy(), y2.detach().cpu().numpy(), atol=1e-5) + + weight_grad_list = weight_grad2.chunk(tp_size, dim=1) + assert np.allclose(weight_grad1.reshape(weight_grad_list[tp_rank].shape).cpu().numpy(), weight_grad_list[tp_rank].cpu().numpy()) + + assert np.allclose(bias_grad1.cpu().numpy(), bias_grad2.cpu().numpy()) + +def test_split_input(): + run(True, False, 'row_parallel_linear.ckp') + run(True, True, 'row_parallel_linear.ckp') + +def test_no_split_input(): + run(False, False, 'row_parallel_linear_no_split.ckp') + run(False, True, 'row_parallel_linear_no_split.ckp') + +if __name__ == "__main__": + bmt.init_distributed(tp_size=2) + test_no_split_input() + test_split_input() + From 743253e8f31745000c6f3c73c0cb599cc6837166 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 25 Aug 2023 15:39:10 
+0800 Subject: [PATCH 15/25] mv zero_level to CheckpointBlock --- bmtrain/block_layer.py | 3 ++- bmtrain/hook_func.py | 11 +++++++---- bmtrain/init.py | 3 --- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 18438b8c..4e9a6c01 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -66,7 +66,7 @@ class CheckpointBlock(torch.nn.Module): >>> y2, ... = transformer_block(x) >>> assert torch.allclose(y1, y2) """ - def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): + def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_level=3): super().__init__() self._module = inner_module self._inputs = None @@ -200,6 +200,7 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): self._mode = "BLOCK" #BLOCK or ZERO or PIPE self.all_input_no_grad = False self.all_param_no_grad = False + self._zero_level = zero_level def set_pre_module(self, pre_module): if pre_module is not None: diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 6a56300e..08d37b5d 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -9,7 +9,7 @@ def zero_pre_forward(module, inputs): enter = module._micro_idx == 0 pipe = True if enter: - zero_level = config['zero_level'] + zero_level = module._zero_level #config['zero_level'] forward_flag = 1 if zero_level == 2 else 0 if zero_level == 2 and module._ref_count > 1: forward_flag = 2 # repeating forward in same layer @@ -19,7 +19,8 @@ def zero_pre_forward(module, inputs): module._forward_block_ctx.enter(forward_flag) def zero_post_forward(module, inputs, outputs): - forward_flag = 1 if config['zero_level'] == 2 else 0 + #forward_flag = 1 if config['zero_level'] == 2 else 0 + forward_flag = 1 if module._zero_level == 2 else 0 if module.all_param_no_grad: forward_flag = 0 exit = True @@ -31,7 +32,8 @@ def zero_post_forward(module, inputs, outputs): module._ref_count += 1 def zero_pre_backward(module, grad_outputs): - backward_flag = 2 if config['zero_level'] == 2 else 0 + #backward_flag = 2 if config['zero_level'] == 2 else 0 + backward_flag = 2 if module._zero_level == 2 else 0 if module._mode != "PIPE": module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict) module._backward_block_ctx.enter(backward_flag, True) @@ -43,7 +45,8 @@ def zero_pre_backward(module, grad_outputs): module._backward_block_ctx.enter(backward_flag, True) def zero_post_backward(module, grad_inputs, grad_outputs): - backward_flag = 2 if config['zero_level'] == 2 else 0 + #backward_flag = 2 if config['zero_level'] == 2 else 0 + backward_flag = 2 if module._zero_level == 2 else 0 if module._mode != "PIPE": if module._is_first_layer: module.backward_release(backward_flag) diff --git a/bmtrain/init.py b/bmtrain/init.py index 1fa0712d..f3c1faa7 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -14,7 +14,6 @@ def init_distributed( init_method : str = "env://", seed : int = 0, - zero_level: int = 3, pipe_size: int = -1, num_micro_batches: int = None, ): @@ -24,7 +23,6 @@ def init_distributed( Args: seed (int): The random seed. - zero_level (int): The ZeRO optimization level. 2 for stage-2, 3 for stage-3. 
**init_distributed** reads the following environment variables: @@ -74,7 +72,6 @@ def init_distributed( config["load_stream"] = torch.cuda.Stream(priority=-1) config['barrier_stream'] = torch.cuda.Stream() config["load_event"] = torch.cuda.Event() - config["zero_level"] = zero_level config["topology"] = topology(config) config["zero_rank"] = config["topology"].get_group_rank("zero") if pipe_size > 1 else config['rank'] cpus_this_worker = None From 604ddfee44578148794561c7dd50a122c7110b9a Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 25 Aug 2023 16:53:19 +0800 Subject: [PATCH 16/25] fix overlap --- bmtrain/nn/parallel_linear_func.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py index 9e76db9e..df4fb0b5 100644 --- a/bmtrain/nn/parallel_linear_func.py +++ b/bmtrain/nn/parallel_linear_func.py @@ -31,8 +31,10 @@ def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=Fal ctx.split_input = split_input ctx.gather_input = gather_input ctx.reduce_output_type = reduce_output_type + all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) out = F.linear(all_input, weight, bias) + if gather_output: all_output_list = all_gather(out, config['tp_comm']) all_output_list = all_output_list.chunk(config['tp_size'], dim=0) @@ -44,12 +46,15 @@ def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=Fal if reduce_output_type == ReduceType.ALL_REDUCE: nccl.allReduce(out.storage(), out.storage(), "sum", config['tp_comm']) return out + elif reduce_output_type == ReduceType.REDUCE_SCATTER: shape = list(out.shape) shape[0] = shape[0] // config['tp_size'] reduce_out = torch.empty(shape, dtype=out.dtype, device=out.device) nccl.reduceScatter(out.storage(), reduce_out.storage(), "sum", config['tp_comm']) return reduce_out + else: + assert False, "no support reduce type{}".format(reduce_output_type) @staticmethod def backward(ctx, grad_output): @@ -72,20 +77,23 @@ def backward(ctx, grad_output): all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) if input.requires_grad: - #gather can async with grad_out.matmul(weight) - #TODO: gather on load_stream grad_all_input = grad_output.matmul(weight) grad_input = torch.empty_like(input) - current_stream = torch.cuda.current_stream() - config['tp_comm_stream'].wait_stream(current_stream) if ctx.gather_input: with torch.cuda.stream(config['tp_comm_stream']): + current_stream = torch.cuda.current_stream() + config['tp_comm_stream'].wait_stream(current_stream) + grad_input.record_stream(config['tp_comm_stream']) + grad_all_input.record_stream(config['tp_comm_stream']) nccl.reduceScatter(grad_all_input.storage(), grad_input.storage(), "sum", config['tp_comm']) else: grad_input = grad_all_input if ctx.split_input: with torch.cuda.stream(config['tp_comm_stream']): + current_stream = torch.cuda.current_stream() + config['tp_comm_stream'].wait_stream(current_stream) + grad_input.record_stream(config['tp_comm_stream']) grad_input = all_gather(grad_input, config['tp_comm']) if weight.requires_grad: From 0aee817ef193e8f00a44b46017f60e1bcdc8fd9f Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 25 Aug 2023 21:25:49 +0800 Subject: [PATCH 17/25] gather once in atten --- bmtrain/nn/parallel_linear_func.py | 3 +-- example/layers/attention.py | 16 ++++++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py index 
df4fb0b5..f7c4573a 100644 --- a/bmtrain/nn/parallel_linear_func.py +++ b/bmtrain/nn/parallel_linear_func.py @@ -77,11 +77,11 @@ def backward(ctx, grad_output): all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) if input.requires_grad: + current_stream = torch.cuda.current_stream() grad_all_input = grad_output.matmul(weight) grad_input = torch.empty_like(input) if ctx.gather_input: with torch.cuda.stream(config['tp_comm_stream']): - current_stream = torch.cuda.current_stream() config['tp_comm_stream'].wait_stream(current_stream) grad_input.record_stream(config['tp_comm_stream']) grad_all_input.record_stream(config['tp_comm_stream']) @@ -91,7 +91,6 @@ def backward(ctx, grad_output): if ctx.split_input: with torch.cuda.stream(config['tp_comm_stream']): - current_stream = torch.cuda.current_stream() config['tp_comm_stream'].wait_stream(current_stream) grad_input.record_stream(config['tp_comm_stream']) grad_input = all_gather(grad_input, config['tp_comm']) diff --git a/example/layers/attention.py b/example/layers/attention.py index a49edabb..61eeb9f2 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -19,9 +19,9 @@ def __init__(self, super().__init__() if config['tp_size'] > 1: - self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) + self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) + self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) else: self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) @@ -44,9 +44,12 @@ def forward(self, batch_size, seq_q, dim_model = hidden_q.size() seq_kv = hidden_kv.size(1) + if config['tp_size'] > 1: + hidden_q = all_gather(hidden_q, comm=config['tp_comm']).flatten(0,1) + h_q : torch.Tensor = self.project_q(hidden_q) - h_k : torch.Tensor = self.project_k(hidden_kv) - h_v : torch.Tensor = self.project_v(hidden_kv) + h_k : torch.Tensor = self.project_k(hidden_q) + h_v : torch.Tensor = self.project_v(hidden_q) if config['tp_size'] > 1: #batch_size will changed in TensorParallel batch_size = h_v.shape[0] @@ -74,7 +77,8 @@ def forward(self, score = score + position_bias.view(batch_size, -1, seq_q, seq_kv) if config['tp_size'] > 1: - mask = all_gather(mask, config['tp_comm']) + with torch.no_grad(): + mask = all_gather(mask, config['tp_comm']).flatten(0,1) score = torch.where( mask.view(batch_size, 1, seq_q, seq_kv), From bd0bad0e647c7387836a1f515ae18f6f0d7e6b42 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 25 Aug 2023 22:25:17 +0800 Subject: [PATCH 18/25] fix sub grad_input in parallel linear --- bmtrain/nn/parallel_linear_func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py index f7c4573a..1f0a362a 100644 --- a/bmtrain/nn/parallel_linear_func.py +++ b/bmtrain/nn/parallel_linear_func.py @@ -79,7 +79,7 @@ def backward(ctx, grad_output): if input.requires_grad: current_stream = torch.cuda.current_stream() grad_all_input = 
grad_output.matmul(weight) - grad_input = torch.empty_like(input) + grad_input = torch.zeros_like(input) if ctx.gather_input: with torch.cuda.stream(config['tp_comm_stream']): config['tp_comm_stream'].wait_stream(current_stream) From 15460b6da3722b3d82aa4ef136e4a650331da0e5 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sat, 26 Aug 2023 11:16:23 +0800 Subject: [PATCH 19/25] fix gather_output --- bmtrain/nn/parallel_linear_func.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py index 1f0a362a..55741530 100644 --- a/bmtrain/nn/parallel_linear_func.py +++ b/bmtrain/nn/parallel_linear_func.py @@ -16,7 +16,7 @@ def preprocess_input(input, gather_input, split_input): input = input.flatten(0, 1) if split_input: - all_input_list = input.chunk(config['tp_size'], dim=1) + all_input_list = input.chunk(config['tp_size'], dim=-1) input = all_input_list[config['topology'].tp_id] return input @@ -68,7 +68,7 @@ def backward(ctx, grad_output): if gather_output: tp_size = config['tp_size'] tp_id = config['topology'].tp_id - grad_output_list = grad_output.chunk(tp_size, dim=1) + grad_output_list = grad_output.chunk(tp_size, dim=-1) grad_output = grad_output_list[tp_id] grad_input = grad_weight = grad_bias = None From b44a62e9c5570c34861673a4b7c57a6250aeeb02 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sat, 26 Aug 2023 15:13:27 +0800 Subject: [PATCH 20/25] fix train.py --- example/train.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/example/train.py b/example/train.py index 44a7d5d2..8aaf65e4 100644 --- a/example/train.py +++ b/example/train.py @@ -9,7 +9,6 @@ def main(): bmt.init_distributed( seed=0, - zero_level=2, tp_size=2, ) @@ -54,7 +53,7 @@ def main(): break if config['tp_size'] > 1: - loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100) + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) else: loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) From 100cd55bae8e32d1f46aacc97485c3f00703204f Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sat, 26 Aug 2023 16:21:20 +0800 Subject: [PATCH 21/25] fused q,k,v --- example/layers/attention.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/example/layers/attention.py b/example/layers/attention.py index 61eeb9f2..8fbb7510 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -44,12 +44,18 @@ def forward(self, batch_size, seq_q, dim_model = hidden_q.size() seq_kv = hidden_kv.size(1) - if config['tp_size'] > 1: - hidden_q = all_gather(hidden_q, comm=config['tp_comm']).flatten(0,1) + assert hidden_q.data_ptr() == hidden_kv.data_ptr() + + hidden_q = bmt.nn.OpParallelLinear.apply( + hidden_q, + torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0), + torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0), + True, False, + False, None + ) + + h_q, h_k, h_v = hidden_q.chunk(3, dim=-1) - h_q : torch.Tensor = self.project_q(hidden_q) - h_k : torch.Tensor = self.project_k(hidden_q) - h_v : torch.Tensor = self.project_v(hidden_q) if config['tp_size'] > 1: #batch_size will changed in TensorParallel batch_size = h_v.shape[0] From fa09468b0312ac0c7854f71232f1d62f11a51ebb Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sat, 26 Aug 2023 17:51:47 +0800 Subject: [PATCH 22/25] fix row parallel linear --- bmtrain/nn/row_parallel_linear.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/bmtrain/nn/row_parallel_linear.py b/bmtrain/nn/row_parallel_linear.py index 71a4297d..bbe94cfa 100644 --- a/bmtrain/nn/row_parallel_linear.py +++ b/bmtrain/nn/row_parallel_linear.py @@ -28,7 +28,8 @@ def forward(self, input): gather_output = False reduce_output_type = ReduceType.ALL_REDUCE if self.all_reduce_output else ReduceType.REDUCE_SCATTER out = OpParallelLinear.apply(input, self.weight, None, gather_input, gather_output, self.split_input, reduce_output_type) - out = out + self.bias + if self.bias is not None: + out = out + self.bias return out def extra_repr(self) -> str: From 37bc403d344054d7b4878c12da22050dfd747e0b Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sat, 26 Aug 2023 17:53:52 +0800 Subject: [PATCH 23/25] fix cross entropy --- bmtrain/loss/cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/loss/cross_entropy.py b/bmtrain/loss/cross_entropy.py index c4b36a7f..982a6469 100644 --- a/bmtrain/loss/cross_entropy.py +++ b/bmtrain/loss/cross_entropy.py @@ -219,7 +219,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ret = w * ret - if config['tp_size'] > 1: + if self.parallel and config['tp_size'] > 1: ret_list = ret.chunk(config['tp_size'], dim=0) ret = ret_list[config['topology'].tp_id] w_list = w.chunk(config['tp_size'], dim=0) From 15c2c48ce1eb2b49396e473d9b9619a1d7fb82b3 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 28 Aug 2023 03:12:52 -0700 Subject: [PATCH 24/25] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index ad1c8905..2bbb55d8 100644 --- a/setup.py +++ b/setup.py @@ -92,7 +92,7 @@ def build_extension(self, ext): ] setup( name='bmtrain', - version='0.2.3.post3', + version='0.2.3.post2', author="Guoyang Zeng", author_email="qbjooo@qq.com", description="A toolkit for training big models", From 42663c8263b267ad4d93bfac9cb8c35dc13e4a31 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 28 Aug 2023 20:33:52 +0800 Subject: [PATCH 25/25] overlap send communication in pipeline --- bmtrain/init.py | 1 + bmtrain/pipe_layer.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/bmtrain/init.py b/bmtrain/init.py index b1d92939..a6214d78 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -74,6 +74,7 @@ def init_distributed( config["calc_stream"] = torch.cuda.current_stream() config["load_stream"] = torch.cuda.Stream(priority=-1) config["tp_comm_stream"] = torch.cuda.Stream(priority=-1) + config["pp_comm_stream"] = torch.cuda.Stream(priority=-1) config['barrier_stream'] = torch.cuda.Stream() config["load_event"] = torch.cuda.Event() config["tp_size"] = tp_size if tp_size > 0 else 1 diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index a21bb9af..c28ab87d 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -139,7 +139,11 @@ def forward(ctx, input, stage_id): def backward(ctx, grad_outputs): if not ctx.is_first_stage: send_data = grad_outputs[0] if isinstance(grad_outputs, tuple) else grad_outputs - send_activations(send_data, ctx.stage_id - 1, config['pipe_comm']) + current_stream = torch.cuda.current_stream() + with torch.cuda.stream(config['pp_comm_stream']): + config['pp_comm_stream'].wait_stream(current_stream) + send_data.record_stream(current_stream) + send_activations(send_data, ctx.stage_id - 1, config['pipe_comm']) return grad_outputs, None class StagePostFunction(torch.autograd.Function): @@ -150,7 +154,11 @@ def forward(ctx, outputs, stage_id): ctx.is_last_stage 
= stage_id == config['pipe_size'] - 1 if not ctx.is_last_stage: send_data = outputs[0] if isinstance(outputs, tuple) else outputs - send_activations(send_data.detach(), stage_id + 1, config['pipe_comm']) + current_stream = torch.cuda.current_stream() + with torch.cuda.stream(config['pp_comm_stream']): + config['pp_comm_stream'].wait_stream(current_stream) + send_data.record_stream(current_stream) + send_activations(send_data.detach(), stage_id + 1, config['pipe_comm']) return outputs @staticmethod
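# --------------------------------------------------------------------------
# Editor's sketch (not part of the patches above). Several commits in this
# series -- the parallel-linear backward fixes and the pipeline send overlap
# -- follow the same pattern: issue the NCCL collective on a dedicated side
# stream (config['tp_comm_stream'] / config['pp_comm_stream']), make that
# stream wait on the compute stream so it starts only after the data is
# ready, and call record_stream on the tensors so the caching allocator does
# not recycle their memory before the collective finishes. The minimal,
# self-contained illustration below is a sketch of that pattern only: it
# assumes a NCCL process group is already initialized and uses plain
# torch.distributed in place of bmtrain's internal nccl wrapper; the name
# overlapped_all_reduce is hypothetical.
import torch
import torch.distributed as dist

comm_stream = torch.cuda.Stream()                 # analogue of config['tp_comm_stream']

def overlapped_all_reduce(grad: torch.Tensor) -> torch.Tensor:
    compute_stream = torch.cuda.current_stream()
    with torch.cuda.stream(comm_stream):
        comm_stream.wait_stream(compute_stream)   # start only after grad has been produced
        grad.record_stream(comm_stream)           # keep grad's memory alive for the collective
        dist.all_reduce(grad)                     # enqueued on comm_stream, overlaps later compute
    return grad

# Before the result is consumed, the caller is expected to make the compute
# stream wait on comm_stream (e.g. torch.cuda.current_stream().wait_stream(
# comm_stream)); that synchronization is not shown in these hunks.
# --------------------------------------------------------------------------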