diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 4e9a6c01..8c88ec69 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, Iterator, Union, List -from .utils import round_up +from .utils import (round_up, tp_split_tensor) from .global_var import config import torch from . import nccl @@ -94,7 +94,8 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev "total": 0, "storage_type": storage_type, "requires_grad": param.requires_grad, - "group": param.group + "group": param.group, + "zero_comm" : param._zero_comm } param_shape = param._original_shape @@ -108,11 +109,14 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev offsets = {} # intialize storage buffers for kw, val in self._storage_info.items(): - val["world_size"] = config["world_size"] + comm = val['zero_comm'] + world_size = nccl.commCount(comm) + rank = nccl.commRank(comm) + val["world_size"] = world_size partition_size = round_up(val["total"], val["world_size"]) // val["world_size"] val["partition_size"] = partition_size - val["begin"] = config['rank'] * partition_size - val["end"] = (config['rank'] + 1) * partition_size + val["begin"] = rank * partition_size + val["end"] = (rank+1) * partition_size offsets[kw] = 0 @@ -302,13 +306,18 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, if key in state_dict: # load here input_param = state_dict[key] + param = it['parameter'] + tp_mode = param._tp_mode if input_param.__class__.__name__ == "DistributedTensorWrapper": input_param = input_param.broadcast() - if input_param.shape != it["shape"]: + + verify_shape = torch.Size(it["shape"] if not tp_mode else param._tp_original_shape) + if input_param.shape != verify_shape: error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' 'the shape in current model is {}.' 
- .format(key, input_param.shape, it["shape"])) + .format(key, input_param.shape, verify_shape)) continue + param_st = it["offset"] param_end = it["offset"] + it["size"] kw_name = it["kw_name"] @@ -322,8 +331,14 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, continue # copy to buffer - assert input_param.numel() == it["size"] + verify_size = verify_shape.numel() + assert input_param.numel() == verify_size + contiguous_param = input_param.to(it["parameter"].dtype).cuda().contiguous() + + tp_split_dim = param._tp_split_dim + if tp_mode and tp_split_dim >= 0: + contiguous_param = tp_split_tensor(contiguous_param, tp_split_dim) offset_st = max(storage_st - param_st, 0) offset_end = min(storage_end - param_st, contiguous_param.numel()) @@ -331,7 +346,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, to_offset_st = offset_st + param_st - storage_st to_offset_end = offset_end + param_st - storage_st - + # copy to buffer # PyTorch 1.11 changed the API of storage.__getitem__ d_dtype = self._storage_params[kw_name].dtype @@ -398,7 +413,7 @@ def init_parameters(self): param = it["parameter"] if isinstance(param, DistributedParameter) and param._init_method is not None: # initialzie here - tmp_tensor = torch.empty(it["shape"], device=param.device, dtype=param.dtype) + tmp_tensor = torch.empty(param._tp_original_shape, device=param.device, dtype=param.dtype) param._init_method(tmp_tensor) param_st = it["offset"] param_end = it["offset"] + it["size"] @@ -412,16 +427,15 @@ def init_parameters(self): if param_end <= storage_st: continue + if param._tp_mode and param._tp_split_dim >= 0: + tmp_tensor = tp_split_tensor(tmp_tensor, param._tp_split_dim) # copy to buffer assert tmp_tensor.is_contiguous() and it["size"] == tmp_tensor.numel() - offset_st = max(storage_st - param_st, 0) - offset_end = min(storage_end - param_st, tmp_tensor.numel()) + offset_st = max(storage_st - param_st, 0) + offset_end = min(storage_end - param_st, tmp_tensor.numel()) assert offset_st < offset_end - to_offset_st = offset_st + param_st - storage_st - to_offset_end = offset_end + param_st - storage_st - # copy to buffer # PyTorch 1.11 changed the API of storage.__getitem__ d_dtype = self._storage_params[kw_name].dtype diff --git a/bmtrain/checkpointing.py b/bmtrain/checkpointing.py index b2c9ec07..550225be 100644 --- a/bmtrain/checkpointing.py +++ b/bmtrain/checkpointing.py @@ -39,10 +39,7 @@ def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, pipe = Fal self._param_tensor = {} self._grad_tensor = {} self._need_release = False - if pipe: - self.comm = config["zero_comm"] - else: - self.comm = config["comm"] + def enter(self, flag=0, requires_grad=False): """ gather parameters @@ -74,7 +71,7 @@ def enter(self, flag=0, requires_grad=False): nccl.allGather( self.block._storage_params[kw].storage(), self._param_buffer[kw], - self.comm + val['zero_comm'] ) nccl.groupEnd() @@ -144,7 +141,7 @@ def exit(self, flag=0, backward=False): self._grad_buffer[kw], local_param.grad.storage(), "sum", - self.comm + val['zero_comm'] ) nccl.groupEnd() diff --git a/bmtrain/init.py b/bmtrain/init.py index f3c1faa7..a6214d78 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -10,12 +10,12 @@ from . import nccl from .synchronize import synchronize - def init_distributed( init_method : str = "env://", seed : int = 0, pipe_size: int = -1, num_micro_batches: int = None, + tp_size : int = 1, ): """Initialize distributed training. 
This function will initialize the distributed training, set the random seed and global configurations. @@ -23,6 +23,9 @@ def init_distributed( Args: seed (int): The random seed. + pipe_size (int) : the number of pipeline stages; all processes will be divided into groups of pipe_size for pipeline parallelism. + num_micro_batches (int) : the number of micro-batches that each input batch will be divided into; used in pipeline mode. + tp_size (int) : the size of each tensor parallel group. **init_distributed** reads the following environment variables: @@ -70,10 +73,15 @@ def init_distributed( config["world_size"] = world_size config["calc_stream"] = torch.cuda.current_stream() config["load_stream"] = torch.cuda.Stream(priority=-1) + config["tp_comm_stream"] = torch.cuda.Stream(priority=-1) + config["pp_comm_stream"] = torch.cuda.Stream(priority=-1) config['barrier_stream'] = torch.cuda.Stream() config["load_event"] = torch.cuda.Event() + config["tp_size"] = tp_size if tp_size > 0 else 1 config["topology"] = topology(config) - config["zero_rank"] = config["topology"].get_group_rank("zero") if pipe_size > 1 else config['rank'] + config["zero_rank"] = config['topology'].get_group_rank("zero") + config["tp_rank"] = config['topology'].get_group_rank("tp") + config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero") cpus_this_worker = None all_available_cpus = sorted(list(os.sched_getaffinity(0))) @@ -102,21 +110,34 @@ def init_distributed( unique_id = bytes.fromhex(store.get("BMTRAIN_UNIQUE_ID").decode()) config['comm'] = nccl.commInitRank(unique_id, world_size, rank) + topo = config['topology'] + if config['pipe_enabled']: config["micros"] = num_micro_batches if num_micro_batches else config["pipe_size"] - topo = config['topology'] if topo.stage_id == 0: unique_id = nccl.getUniqueId() store.set(f"PIPE_UNIQUE_ID{topo.pipe_idx}", unique_id.hex()) unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode()) config ['pipe_comm'] = nccl.commInitRank(unique_id, pipe_size, topo.stage_id) - if topo.zero_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"ZERO_UNIQUE_ID{topo.zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"ZERO_UNIQUE_ID{topo.zero_idx}").decode()) - config ['zero_comm'] = nccl.commInitRank(unique_id, world_size//pipe_size, topo.zero_id) - else: - config['zero_comm'] = config['comm'] + + if topo.tp_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex()) + unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) + config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) + + if topo.tp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) + config['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.tp_zero_id) + + if topo.zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"ZERO_UNIQUE_ID{topo.zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"ZERO_UNIQUE_ID{topo.zero_idx}").decode()) + config ['zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size']), topo.zero_id) + for i in range(world_size): if i == rank: print_dict("Initialization", { @@ -129,40 +150,57 @@ "cpus": cpus_this_worker }) synchronize() + class topology: def __init__(self,config): # pipe_idx is the idx of the pipeline in the group
self.rank = config['rank'] pp_size = config["pipe_size"] + tp_size = config["tp_size"] world_size = config["world_size"] - assert world_size % pp_size == 0, "The nums of GPUs must be divisible by the pipeline parallel size" - - dp_size = world_size // pp_size - topo=torch.tensor(range(dp_size*pp_size),dtype=torch.int,device='cuda') - topo=topo.view(pp_size,dp_size) - self.pp_group=topo.transpose(0,1).reshape(-1,pp_size) - self.dp_group=topo - self.stage_id = (self.pp_group == self.rank).nonzero()[0,-1].item() + assert world_size % (pp_size * tp_size) == 0, "The nums of GPUs must be divisible by the pipeline parallel size * tensor parallel size" + + dp_size = world_size // (pp_size * tp_size) + config['tp_zero_size'] = dp_size + config['zero_size'] = world_size // pp_size + topo=torch.tensor(range(dp_size*tp_size*pp_size),dtype=torch.int,device='cuda') + topo=topo.view(pp_size,dp_size*tp_size) self.stages = config['pipe_size'] - self.pipe_idx = (self.pp_group == self.rank).nonzero()[0, 0].item() # x axes - self.zero_id = self.pipe_idx - self.zero_idx = self.stage_id + + stage_size = world_size // pp_size + for i in range(world_size): + self.pipe_idx = self.rank % stage_size + self.stage_id = self.rank // stage_size + self.tp_id = self.rank % tp_size + self.tp_idx = self.rank // tp_size + self.zero_idx = self.stage_id + self.zero_id = self.pipe_idx + self.tp_zero_idx = self.stage_id * tp_size + self.tp_id + self.tp_zero_id = self.pipe_idx // tp_size + self.next_rank = self.stage_id+1 if self.stage_id < config['pipe_size'] - 1 else -1 self.prev_rank = self.stage_id-1 if self.stage_id > 0 else -1 - self.tails = self.pp_group[self.pipe_idx, self.stage_id:].tolist() - self.heads = self.pp_group[self.pipe_idx, :self.stage_id + 1].tolist() + def get_group_id(self,group_name): if group_name == "pipe": return self.pipe_idx elif group_name == "zero": return self.zero_idx + elif group_name == "tp_zero": + return self.tp_zero_idx + elif group_name == "tp": + return self.tp_idx def get_group_rank(self,group_name): if group_name == "pipe": return self.stage_id elif group_name == "zero": return self.zero_id + elif group_name == "tp_zero": + return self.tp_zero_id + elif group_name == "tp": + return self.tp_id def is_initialized() -> bool: return config["initialized"] diff --git a/bmtrain/layer.py b/bmtrain/layer.py index ebbef815..7de83e5e 100644 --- a/bmtrain/layer.py +++ b/bmtrain/layer.py @@ -1,6 +1,8 @@ import torch from .parameter import DistributedParameter +from .global_var import config import itertools +from .utils import tp_split_tensor class DistributedModule(torch.nn.Module): """ @@ -11,7 +13,7 @@ class DistributedModule(torch.nn.Module): def __getattr__(self, name: str): ret = super().__getattr__(name) # gather distributed parameters if not in CheckpointBlock - if isinstance(ret, DistributedParameter) and not ret._in_checkpoint_block: + if isinstance(ret, DistributedParameter) and not ret._in_checkpoint_block: return ret.gather() return ret @@ -30,8 +32,11 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): """ for name, param in self._parameters.items(): if param is not None: - if isinstance(param, DistributedParameter) and not param._in_checkpoint_block: - destination[prefix + name] = param.gather().detach().cpu() # sync operation + if isinstance(param, DistributedParameter):#and not param._in_checkpoint_block: + if param._in_checkpoint_block: + destination[prefix + name] = param.tp_gather().detach().cpu() # sync operation + else: + destination[prefix + name] = 
param.gather_all().detach().cpu() # sync operation else: destination[prefix + name] = param if keep_vars else param.detach().cpu() for name, buf in self._buffers.items(): @@ -81,6 +86,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, for name, param in local_state.items(): key = prefix + name if key in state_dict: + tp_mode = param._tp_mode input_param = state_dict[key] if input_param.__class__.__name__ == "DistributedTensorWrapper": input_param = input_param.broadcast() @@ -98,13 +104,17 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 'the shape in current model is {}.' .format(key, input_param.shape, param.shape)) continue - if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != param._original_shape: + verify_shape = torch.Size(param._original_shape if not tp_mode else param._tp_original_shape) + if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != verify_shape: error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' 'the shape in current model is {}.' - .format(key, input_param.shape, param.shape)) + .format(key, input_param.shape, verify_shape)) try: with torch.no_grad(): if isinstance(param, DistributedParameter): + tp_split_dim = param._tp_split_dim + if tp_mode and tp_split_dim >= 0: + input_param = tp_split_tensor(input_param, tp_split_dim) param._copy_data(input_param) else: param.copy_(input_param) diff --git a/bmtrain/loss/cross_entropy.py b/bmtrain/loss/cross_entropy.py index 31223640..982a6469 100644 --- a/bmtrain/loss/cross_entropy.py +++ b/bmtrain/loss/cross_entropy.py @@ -1,6 +1,10 @@ from typing import Optional import torch from . import _function as F +from bmtrain.nn import parallel_cross_entropy_func +from bmtrain.global_var import config +from bmtrain.distributed import all_gather + class OpFusedCrossEntropy(torch.autograd.Function): """ CrossEntropy dim = 1 @@ -176,6 +180,7 @@ def __init__(self, reduction: str = 'mean', label_smoothing: float = 0.0, # TODO not supported yet inplace: bool = False, + parallel: bool = False, ) -> None: super().__init__() self.weight = weight @@ -183,21 +188,26 @@ def __init__(self, self.reduction = reduction self.label_smoothing = label_smoothing self.inplace = inplace + self.parallel = parallel def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - if input.dtype == torch.float32: - return torch.nn.functional.cross_entropy( - input, - target.long(), - weight=self.weight, - ignore_index=self.ignore_index, - reduction=self.reduction, - label_smoothing=self.label_smoothing) - - if self.inplace: - ret = OpFusedCrossEntropyInplace.apply(input, target.int(), self.ignore_index) # return float tensor + if self.parallel: + target = all_gather(target, comm=config['tp_comm']).flatten(0,1) + ret = parallel_cross_entropy_func(input, target.long(), self.ignore_index) else: - ret = OpFusedCrossEntropy.apply(input, target.int(), self.ignore_index) # return float tensor + if input.dtype == torch.float32: + return torch.nn.functional.cross_entropy( + input, + target.long(), + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction, + label_smoothing=self.label_smoothing) + + if self.inplace: + ret = OpFusedCrossEntropyInplace.apply(input, target.int(), self.ignore_index) # return float tensor + else: + ret = OpFusedCrossEntropy.apply(input, target.int(), self.ignore_index) # return float tensor if self.weight is not None: if 
self.weight.dim() != 1 or self.weight.size(0) != input.size(1): @@ -208,6 +218,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: w = (target != self.ignore_index).int() ret = w * ret + + if self.parallel and config['tp_size'] > 1: + ret_list = ret.chunk(config['tp_size'], dim=0) + ret = ret_list[config['topology'].tp_id] + w_list = w.chunk(config['tp_size'], dim=0) + w = w_list[config['topology'].tp_id] if self.reduction == "none": return ret diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py index 67f9fdee..e22d8c55 100644 --- a/bmtrain/nn/__init__.py +++ b/bmtrain/nn/__init__.py @@ -1 +1,6 @@ -from .linear import Linear +from .linear import Linear, OpLinear +from .column_parallel_linear import ColumnParallelLinear +from .row_parallel_linear import RowParallelLinear +from .parallel_embedding import ParallelEmbedding +from .parallel_cross_entropy_func import parallel_cross_entropy_func +from .parallel_linear_func import OpParallelLinear \ No newline at end of file diff --git a/bmtrain/nn/column_parallel_linear.py b/bmtrain/nn/column_parallel_linear.py new file mode 100644 index 00000000..e8f554c8 --- /dev/null +++ b/bmtrain/nn/column_parallel_linear.py @@ -0,0 +1,36 @@ +import torch +from torch.nn.parameter import Parameter + +import bmtrain as bmt +from bmtrain.global_var import config +from .parallel_linear_func import ( + OpParallelLinear, + ReduceType) + +class ColumnParallelLinear(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None, gather_output=False, gather_input=True) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.gather_input = gather_input + tp_size = config['tp_size'] + assert out_features % tp_size == 0 + self.out_features_per_partition = out_features // tp_size + self.weight = bmt.DistributedParameter(torch.empty(self.out_features_per_partition, in_features, dtype=dtype, device="cuda"), init_method=torch.nn.init.xavier_normal_, tp_split_dim=0, tp_mode=True) + if bias: + self.bias = bmt.DistributedParameter(torch.empty(self.out_features_per_partition, dtype=dtype, device="cuda"), init_method=torch.nn.init.zeros_, tp_split_dim=0, tp_mode=True) + else: + self.register_parameter('bias', None) + + def forward(self, input): + gather_input = self.gather_input + split_input = False + reduce_output_type = None + return OpParallelLinear.apply(input, self.weight, self.bias, gather_input, self.gather_output, split_input, reduce_output_type) + + def extra_repr(self) -> str: + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features, self.out_features_per_partition, self.bias is not None + ) + diff --git a/bmtrain/nn/linear.py b/bmtrain/nn/linear.py index faf0770e..cb04863a 100644 --- a/bmtrain/nn/linear.py +++ b/bmtrain/nn/linear.py @@ -2,7 +2,7 @@ import torch.nn.functional as F import bmtrain as bmt -class CustomLinear(torch.autograd.Function): +class OpLinear(torch.autograd.Function): @staticmethod def forward(ctx, x, weight, bias=None): ctx.save_for_backward(x, weight, bias) @@ -35,7 +35,7 @@ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtyp self.register_parameter('bias', None) def forward(self, input): - return CustomLinear.apply(input, self.weight, self.bias) + return OpLinear.apply(input, self.weight, self.bias) def extra_repr(self) -> str: return 'in_features={}, out_features={}, bias={}'.format( diff --git 
a/bmtrain/nn/parallel_cross_entropy_func.py b/bmtrain/nn/parallel_cross_entropy_func.py new file mode 100644 index 00000000..cd1f63bf --- /dev/null +++ b/bmtrain/nn/parallel_cross_entropy_func.py @@ -0,0 +1,122 @@ +import torch +import bmtrain as bmt +from bmtrain.global_var import config +from bmtrain.distributed import all_reduce, all_gather + +class ParallelCrossEntropyFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, logits, target, label_smoothing=0.0): + comm = config['tp_comm'] + rank = config['topology'].tp_id + world_size = config['tp_size'] + + # local max + max_logits = torch.max(logits, dim=-1)[0] + # global max + max_logits = all_reduce(max_logits, op="max", comm=comm) + + logits = logits - max_logits.unsqueeze(dim=-1) + + partition_vocab_size = logits.size()[-1] + vocab_start_index = rank * partition_vocab_size + vocab_end_index = (rank + 1) * partition_vocab_size + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target.clone() - vocab_start_index + masked_target[target_mask] = 0 + + logits_2d = logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], + device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 + + # All reduce is needed to get the chunks from other GPUs. + predicted_logits = all_reduce(predicted_logits, op="sum", comm=comm) + # Sum of exponential of logits along vocab dimension across all GPUs. + exp_logits = logits + torch.exp(logits, out=exp_logits) + sum_exp_logits = exp_logits.sum(dim=-1) + sum_exp_logits = all_reduce(sum_exp_logits, op="sum", comm=comm) + + # Loss = log(sum(exp(logits))) - predicted-logit. + loss = torch.log(sum_exp_logits.view(predicted_logits.shape)) - predicted_logits + + # Normalize and optionally smooth logits + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + + vocab_size = exp_logits.size(-1) + if label_smoothing > 0: + """ + We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth. + = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt}) + = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i + = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i + = (K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i + = (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K + From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py + """ + assert 1.0 > label_smoothing > 0.0 + smoothing = label_smoothing * vocab_size / (vocab_size - 1) + + # Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs. + log_probs = torch.log(exp_logits) + mean_log_probs = log_probs.mean(dim=-1) + loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs + + ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + + # Store softmax, target-mask and masked-target for backward pass. + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + + return loss + + @staticmethod + def backward(ctx, grad_output): + # Retreive tensors from the forward path. 
+ softmax, target_mask, masked_target_1d = ctx.saved_tensors + label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size + + # All the inputs have softmax as their gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], + device=grad_2d.device) + + softmax_update = 1.0 - target_mask.view(-1).float() + + if label_smoothing > 0: + smoothing = label_smoothing * vocab_size / (vocab_size - 1) + grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update + average_grad = 1 / vocab_size + grad_2d[arange_1d, :] -= smoothing * average_grad + else: + grad_2d[arange_1d, masked_target_1d] -= softmax_update + + grad_input.mul_(grad_output.flatten(0,1).unsqueeze(dim=-1)) + + return grad_input, None, None + + +def parallel_cross_entropy_func(logits, target, label_smoothing=0.0): + """ + Performs cross entropy loss when logits are split across tensor parallel ranks + + Arguments: + logits: logits split across tensor parallel ranks; dimension is [batch * seq_len, vocab_size // tp_size]. + target: correct vocab ids of dimension [batch * seq_len]. + label_smoothing: smoothing factor, must be in range [0.0, 1.0). default is 0.0. + """ + out = ParallelCrossEntropyFunc.apply(logits.to(torch.float32), target, label_smoothing) + return out + diff --git a/bmtrain/nn/parallel_embedding.py b/bmtrain/nn/parallel_embedding.py new file mode 100644 index 00000000..cd567b4e --- /dev/null +++ b/bmtrain/nn/parallel_embedding.py @@ -0,0 +1,72 @@ +import torch +from torch.nn.parameter import Parameter +import torch.nn.functional as F +import math + +import bmtrain as bmt +from bmtrain.global_var import config +from bmtrain.distributed import all_reduce, all_gather +from .parallel_linear_func import OpParallelLinear + +class ParallelEmbedding(bmt.DistributedModule): + def __init__( + self, + vocab_size: int, + embedding_size: int, + dtype: torch.dtype = torch.half, + init_mean: float = 0.0, + init_std: float = 1, + ): + super().__init__() + + self.dim_model = embedding_size + assert vocab_size % config['tp_size'] == 0 + self.vocab_size_per_partition = vocab_size // config['tp_size'] + self.start_index = config['topology'].tp_id * self.vocab_size_per_partition + self.end_index = (config['topology'].tp_id+1) * self.vocab_size_per_partition + self.weight = bmt.DistributedParameter( + torch.empty(self.vocab_size_per_partition, embedding_size, dtype=dtype), + init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std), + tp_mode=True, + tp_split_dim=0, + ) + + def forward(self, ids: torch.Tensor, gather_input=True): + """ + Args: + ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens. + Return: + :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output. 
+ """ # noqa: E501 + + if config['tp_size'] > 1: + if gather_input: + ids = all_gather(ids, comm=config['tp_comm']) + input_mask = (ids < self.start_index) | (ids >= self.end_index) + ids = ids.clone() - self.start_index + ids[input_mask] = 0 + + embeds = F.embedding(ids, self.weight) + + if config['tp_size'] > 1: + embeds[input_mask, :] = 0.0 + embeds = all_reduce(embeds, op="sum", comm=config['tp_comm']) + embed_list = embeds.chunk(config['tp_size'], dim=0) + embeds = embed_list[config['topology'].tp_id].flatten(0,1) + + return embeds.clone() + + def projection(self, x: torch.Tensor): + """ + Projection based on the embedding's weight. For example, the embedding maps vocab_size to embed_size, then the projection maps embed_size back to vocab_size. + Args: + x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection + Returns: + :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output. + """ # noqa: E501 + gather_input = True + split_input = False + reduce_output_type = None + gather_output = False + out = OpParallelLinear.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) + return out diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py new file mode 100644 index 00000000..55741530 --- /dev/null +++ b/bmtrain/nn/parallel_linear_func.py @@ -0,0 +1,108 @@ +import torch +import torch.nn.functional as F +from bmtrain.global_var import config +from ..distributed import all_gather, all_reduce +from .. import nccl +import bmtrain as bmt +from enum import Enum + +class ReduceType(Enum): + ALL_REDUCE = 1 + REDUCE_SCATTER = 2 + +def preprocess_input(input, gather_input, split_input): + if gather_input: + input = all_gather(input, config['tp_comm']) + input = input.flatten(0, 1) + + if split_input: + all_input_list = input.chunk(config['tp_size'], dim=-1) + input = all_input_list[config['topology'].tp_id] + return input + +class OpParallelLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=False, split_input=False, reduce_output_type=None): + if reduce_output_type is not None: + reduce_output_type = ReduceType(reduce_output_type) + + ctx.save_for_backward(input, weight, bias) + ctx.gather_output = gather_output + ctx.split_input = split_input + ctx.gather_input = gather_input + ctx.reduce_output_type = reduce_output_type + + all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) + out = F.linear(all_input, weight, bias) + + if gather_output: + all_output_list = all_gather(out, config['tp_comm']) + all_output_list = all_output_list.chunk(config['tp_size'], dim=0) + out = torch.cat(all_output_list, dim=all_output_list[0].dim()-1).flatten(0,1) + + if reduce_output_type is None: + return out + + if reduce_output_type == ReduceType.ALL_REDUCE: + nccl.allReduce(out.storage(), out.storage(), "sum", config['tp_comm']) + return out + + elif reduce_output_type == ReduceType.REDUCE_SCATTER: + shape = list(out.shape) + shape[0] = shape[0] // config['tp_size'] + reduce_out = torch.empty(shape, dtype=out.dtype, device=out.device) + nccl.reduceScatter(out.storage(), reduce_out.storage(), "sum", config['tp_comm']) + return reduce_out + else: + assert False, "unsupported reduce type {}".format(reduce_output_type) + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + gather_output = ctx.gather_output + + if ctx.reduce_output_type == ReduceType.REDUCE_SCATTER: 
grad_output = all_gather(grad_output, config['tp_comm']) + grad_output = grad_output.flatten(0, 1) + + if gather_output: + tp_size = config['tp_size'] + tp_id = config['topology'].tp_id + grad_output_list = grad_output.chunk(tp_size, dim=-1) + grad_output = grad_output_list[tp_id] + + grad_input = grad_weight = grad_bias = None + + if input.requires_grad or weight.requires_grad: + all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) + + if input.requires_grad: + current_stream = torch.cuda.current_stream() + grad_all_input = grad_output.matmul(weight) + grad_input = torch.zeros_like(input) + if ctx.gather_input: + with torch.cuda.stream(config['tp_comm_stream']): + config['tp_comm_stream'].wait_stream(current_stream) + grad_input.record_stream(config['tp_comm_stream']) + grad_all_input.record_stream(config['tp_comm_stream']) + nccl.reduceScatter(grad_all_input.storage(), grad_input.storage(), "sum", config['tp_comm']) + else: + grad_input = grad_all_input + + if ctx.split_input: + with torch.cuda.stream(config['tp_comm_stream']): + config['tp_comm_stream'].wait_stream(current_stream) + grad_input.record_stream(config['tp_comm_stream']) + grad_input = all_gather(grad_input, config['tp_comm']) + + if weight.requires_grad: + dim = grad_output.dim() + grad_weight = grad_output.reshape(-1, + grad_output.shape[-1]).t().matmul(all_input.reshape(-1, all_input.shape[-1])) + + if bias is not None and bias.requires_grad: + grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0) + + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(config['tp_comm_stream']) + return grad_input, grad_weight, grad_bias, None, None, None, None diff --git a/bmtrain/nn/row_parallel_linear.py b/bmtrain/nn/row_parallel_linear.py new file mode 100644 index 00000000..bbe94cfa --- /dev/null +++ b/bmtrain/nn/row_parallel_linear.py @@ -0,0 +1,38 @@ +import torch +from torch.nn.parameter import Parameter + +import bmtrain as bmt +from bmtrain.global_var import config +from .parallel_linear_func import ( + OpParallelLinear, + ReduceType) + +class RowParallelLinear(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None, split_input=False, all_reduce_output=False) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.split_input = split_input + self.all_reduce_output = all_reduce_output + tp_size = config['tp_size'] + assert in_features % tp_size == 0 + self.in_features_per_partition = in_features // tp_size + self.weight = bmt.DistributedParameter(torch.empty(self.out_features, self.in_features_per_partition, dtype=dtype, device="cuda"), init_method=torch.nn.init.xavier_normal_, tp_split_dim=1, tp_mode=True) + if bias: + self.bias = bmt.DistributedParameter(torch.empty(self.out_features, dtype=dtype, device="cuda"), init_method=torch.nn.init.zeros_, tp_split_dim=-1, tp_mode=True) + else: + self.register_parameter('bias', None) + + def forward(self, input): + gather_input = self.split_input + gather_output = False + reduce_output_type = ReduceType.ALL_REDUCE if self.all_reduce_output else ReduceType.REDUCE_SCATTER + out = OpParallelLinear.apply(input, self.weight, None, gather_input, gather_output, self.split_input, reduce_output_type) + if self.bias is not None: + out = out + self.bias + return out + + def extra_repr(self) -> str: + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features_per_partition, self.out_features, self.bias is not None + ) diff 
--git a/bmtrain/param_init.py b/bmtrain/param_init.py index 8b74c580..d5c86225 100644 --- a/bmtrain/param_init.py +++ b/bmtrain/param_init.py @@ -13,17 +13,27 @@ def init_distributed_parameter(params : Iterable[torch.nn.Parameter]): continue with torch.no_grad(): partition_size = param.storage().size() - global_size = partition_size * config['world_size'] - + global_size = partition_size * config['tp_zero_size'] * config['tp_size'] tmp_storage = param.storage_type()(global_size) tmp_tensor = torch.tensor([], dtype=param.dtype, device="cuda") - tmp_tensor.set_(tmp_storage, 0, param._original_shape) + tmp_tensor.set_(tmp_storage, 0, param._tp_original_shape) param._init_method(tmp_tensor) + if param._tp_mode and param._tp_split_dim >= 0: + tensor_list = tmp_tensor.chunk(config['tp_size'], dim=param._tp_split_dim) + sub_tensor = tensor_list[config['topology'].tp_id].contiguous() + tmp_tensor = torch.empty(sub_tensor.shape, device=param.device, dtype=sub_tensor.dtype) + tmp_tensor.copy_(sub_tensor) + + if param._tp_mode: + begin = config['tp_zero_rank'] + else: + begin = config['zero_rank'] + end = begin + 1 # Pytorch 1.11 changed the API of storage.__getitem__ torch.tensor([], dtype=param.dtype, device=param.device).set_(param.storage())[:] = \ - torch.tensor([], dtype=param.dtype, device=param.device).set_(tmp_storage)[partition_size * config['rank'] : partition_size * (config['rank'] + 1)] + torch.tensor([], dtype=param.dtype, device=param.device).set_(tmp_tensor.storage())[partition_size * begin : partition_size * end] # param.storage().copy_(tmp_storage[partition_size * config['rank'] : partition_size * (config['rank'] + 1)]) def iterate_parameters(model : torch.nn.Module): diff --git a/bmtrain/parameter.py b/bmtrain/parameter.py index f965cdac..b5e5e9ae 100644 --- a/bmtrain/parameter.py +++ b/bmtrain/parameter.py @@ -3,6 +3,7 @@ from .utils import round_up from .global_var import config from . 
import nccl +from .distributed import all_gather class DistributedParameter(torch.nn.Parameter): r""" @@ -31,7 +32,9 @@ def __new__(cls, data : torch.Tensor, requires_grad : bool = True, init_method : Optional[Callable[['DistributedParameter'], None]] = None, - group : Optional[str] = None + group : Optional[str] = None, + tp_mode : bool = False, + tp_split_dim : int = -1, ): if not config["initialized"]: raise RuntimeError("BMTrain is not initialized") @@ -39,14 +42,24 @@ def __new__(cls, num_of_elements = data.numel() cuda_tensor = torch.tensor([], dtype=data.dtype, device="cuda") - cuda_storage_size = round_up(num_of_elements, config["world_size"]) // config["world_size"] + if tp_mode: + comm = config['tp_zero_comm'] + else: + comm = config['zero_comm'] + world_size = nccl.commCount(comm) + rank = nccl.commRank(comm) + cuda_storage_size = round_up(num_of_elements, world_size) // world_size original_shape = data.size() + tp_original_shape = original_shape + if tp_mode and tp_split_dim >= 0: + tp_original_shape = list(original_shape) + tp_original_shape[tp_split_dim] *= config['tp_size'] cuda_storage = cuda_tensor.storage_type()(cuda_storage_size) - start_of_partition = cuda_storage_size * config["rank"] - end_of_partition = min(num_of_elements, cuda_storage_size * (config["rank"] + 1)) + start_of_partition = cuda_storage_size * rank + end_of_partition = min(num_of_elements, cuda_storage_size * (rank + 1)) # FX: cuda_tensor_size < 0 if num_of_elements is too small cuda_tensor_size = max(end_of_partition - start_of_partition, 0) @@ -60,7 +73,12 @@ def __new__(cls, setattr(ret, "_end_partition", end_of_partition) setattr(ret, "_init_method", init_method) setattr(ret, "_in_checkpoint_block", False) - setattr(ret, "_group", group) + setattr(ret, "_group", group if not tp_mode else "tp") + + setattr(ret, "_tp_mode", tp_mode) + setattr(ret, "_zero_comm", comm) + setattr(ret, "_tp_split_dim", tp_split_dim) + setattr(ret, "_tp_original_shape", tp_original_shape) return ret @property @@ -83,24 +101,52 @@ def gather(self) -> torch.Tensor: current_stream.wait_stream(config['load_stream']) return output_tensor + def gather_all(self) -> torch.tensor: + zero_param = self.gather() + if config['tp_size'] > 1 and self._tp_split_dim >= 0: + output_tensor = all_gather(zero_param, config['tp_comm']) + if self._tp_split_dim == 1: + output_list = output_tensor.chunk(config['tp_size'], dim=0) + output = torch.cat(output_list, dim=output_list[0].dim()-1).flatten(0,1) + return output + else: + return output_tensor.flatten(0,1) + else: + return zero_param + + def tp_gather(self) -> torch.tensor: + if config['tp_size'] > 1 and self._tp_split_dim >= 0: + output_tensor = all_gather(self, config['tp_comm']) + if self._tp_split_dim == 1: + output_list = output_tensor.chunk(config['tp_size'], dim=0) + output = torch.cat(output_list, dim=output_list[0].dim()-1).flatten(0,1) + return output + else: + return output_tensor.flatten(0,1) + else: + return self + def _copy_data(self, data : torch.Tensor): self.data.copy_(data.view(-1)[self._start_partition : self._end_partition]) - class OpAllGather(torch.autograd.Function): @staticmethod def forward(ctx, value : DistributedParameter): assert isinstance(value, DistributedParameter) + comm = value._zero_comm #config['zero_comm'] + world_size = nccl.commCount(comm) + ctx.comm = comm + ctx.world_size = world_size partition_size = value.storage().size() - global_size = partition_size * config['world_size'] + global_size = partition_size * world_size storage = 
value.storage_type()(global_size) nccl.allGather( value.storage(), storage, - config['comm'] + comm ) output_tensor = torch.tensor([], dtype=value.dtype, device="cuda") @@ -117,15 +163,15 @@ def backward(ctx, grad_output : torch.Tensor): grad_storage = grad_output.storage_type()(ctx.partition_size) grad_output_storage = grad_output.storage() - if grad_output_storage.size() == ctx.partition_size * config['world_size']: + if grad_output_storage.size() == ctx.partition_size * ctx.world_size: pass else: - grad_output_storage.resize_(ctx.partition_size * config['world_size']) + grad_output_storage.resize_(ctx.partition_size * ctx.world_size) nccl.reduceScatter( grad_output_storage, grad_storage, 'sum', - config['comm'] + ctx.comm ) grad_tensor = torch.tensor([], dtype=grad_output.dtype, device="cuda") grad_tensor.set_(grad_storage, 0, (ctx.tensor_size,)) diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 0a34ac46..c28ab87d 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -139,7 +139,11 @@ def forward(ctx, input, stage_id): def backward(ctx, grad_outputs): if not ctx.is_first_stage: send_data = grad_outputs[0] if isinstance(grad_outputs, tuple) else grad_outputs - send_activations(send_data, ctx.stage_id - 1, config['pipe_comm']) + current_stream = torch.cuda.current_stream() + with torch.cuda.stream(config['pp_comm_stream']): + config['pp_comm_stream'].wait_stream(current_stream) + send_data.record_stream(current_stream) + send_activations(send_data, ctx.stage_id - 1, config['pipe_comm']) return grad_outputs, None class StagePostFunction(torch.autograd.Function): @@ -150,7 +154,11 @@ def forward(ctx, outputs, stage_id): ctx.is_last_stage = stage_id == config['pipe_size'] - 1 if not ctx.is_last_stage: send_data = outputs[0] if isinstance(outputs, tuple) else outputs - send_activations(send_data.detach(), stage_id + 1, config['pipe_comm']) + current_stream = torch.cuda.current_stream() + with torch.cuda.stream(config['pp_comm_stream']): + config['pp_comm_stream'].wait_stream(current_stream) + send_data.record_stream(current_stream) + send_activations(send_data.detach(), stage_id + 1, config['pipe_comm']) return outputs @staticmethod @@ -189,7 +197,6 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: rank = config['rank'] topo = config['topology'] self.layer_ids = [] - pipe_group = topo.pp_group self.stages = topo.stages self.stage_id = topo.stage_id self.pipe_idx = topo.pipe_idx @@ -218,11 +225,6 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: self._modules[str(self.layer_ids[0])]._is_first_layer = True self._modules[str(self.layer_ids[-1])]._is_last_layer = True - self.partition_modules(self.layer_ids) - self.next_rank = pipe_group[self.pipe_idx, self.stage_id + 1] if self.stage_id < config['pipe_size'] - 1 else -1 - self.prev_rank = pipe_group[self.pipe_idx, self.stage_id - 1] if self.stage_id > 0 else -1 - # self.micro_batches = config['num_micro_batches'] - self.save_list = [(i, i) for i in range(len(self.layer_ids))] def __len__(self) -> int: @@ -295,76 +297,6 @@ def get_stage_by_layer_id(self, layer_id : int) -> int: else: return rest + (layer_id - rest * (part_len+1)) // part_len - def partition_modules(self, idxs) -> None: - for i in range(len(self)): - contiguous_params = {} - for kw, val in self[i]._storage_info.items(): - storage_type = val["storage_type"] - contiguous_params[kw] = storage_type(round_up(val["total"], config["world_size"] // config["pipe_size"])) - nccl.allGather( - 
self[i]._storage_params[kw].storage(), - contiguous_params[kw], - config["comm"] - ) - - if i not in idxs: - for name, param in self[i]._module.named_parameters(): - param.data = torch.tensor([], dtype = param.dtype, device = param.device) - for kw, val in self[i]._storage_info.items(): - val["begin"] = self.stage_id - val["end"] = self.stage_id + 1 - val["partition_size"] = 1 - val["total"] = val["world_size"] - dtype = self[i]._storage_params[kw].dtype - device = self[i]._storage_params[kw].device - self[i]._storage_params[kw] = \ - torch.nn.Parameter(torch.tensor([0], dtype = dtype, device=device)) - else: - for kw, val in self[i]._storage_info.items(): - storage_type = val["storage_type"] - val["world_size"] = config["world_size"] // config["pipe_size"] - partition_size = round_up(val["total"], val["world_size"]) // val["world_size"] - val["partition_size"] = partition_size - val["begin"] = config['zero_rank'] * partition_size - val["end"] = (config['zero_rank'] + 1) * partition_size - storage_param_buffer = storage_type(partition_size) - dtype = storage_param_buffer.dtype - device = storage_param_buffer.device - self[i]._storage_params[kw] = torch.nn.Parameter( - torch.tensor([], dtype=dtype, device=device).set_(storage_param_buffer) - ) - if val["requires_grad"]: - self[i]._storage_params[kw].requires_grad_(True) - else: - self[i]._storage_params[kw].requires_grad_(False) - ordered_parameters = list(self[i]._module.named_parameters()) - for idx, named_param in enumerate(ordered_parameters): - name, param = named_param - param_info = self[i]._param_info[idx] - kw_name = _get_param_kw(param) - storage_info = self[i]._storage_info[kw_name] - storage_st = storage_info["begin"] - storage_end = storage_info["end"] - param_st = param_info["offset"] - param_end = param_st + param_info["size"] - if not (param_st >= storage_end or param_end <= storage_st): - # copy offset in parameter storage - offset_st = max(storage_st - param_st, 0) - offset_end = min(storage_end - param_st, param_info["size"]) - assert offset_st < offset_end - to_offset_st = offset_st + param_st - storage_st - to_offset_end = offset_end + param_st - storage_st - d_dtype = self[i]._storage_params[kw_name].dtype - d_device = self[i]._storage_params[kw_name].device - param.data = torch.tensor([], dtype=param.dtype, device=param.device).set_(self[i]._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,)) - param_info["begin"] = to_offset_st - param_info["end"] = (to_offset_end - to_offset_st,) - param.data[:] = \ - torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_params[kw], storage_st+to_offset_st, (to_offset_end - to_offset_st,))[:] - else: - param.data = torch.tensor([], dtype=param.dtype, device=param.device) - del contiguous_params - def _save_to_state_dict(self, destination, prefix, keep_vars): for name, module in self._modules.items(): idx = int(name) diff --git a/bmtrain/utils.py b/bmtrain/utils.py index a5687c7d..8cb87808 100644 --- a/bmtrain/utils.py +++ b/bmtrain/utils.py @@ -107,6 +107,13 @@ def see_memory(message, detail=False): """) torch.cuda.reset_peak_memory_stats() +def tp_split_tensor(tensor, split_dim): + tensor_list = tensor.chunk(config['tp_size'], dim=split_dim) + sub_tensor = tensor_list[config['topology'].tp_id].contiguous() + tmp_tensor = torch.empty(sub_tensor.shape, device=sub_tensor.device, dtype=sub_tensor.dtype) + tmp_tensor.copy_(sub_tensor) + return tmp_tensor + class AverageRecorder: """A utility class to record the average value of a quantity over 
time. diff --git a/example/layers/attention.py b/example/layers/attention.py index 243df3ea..8fbb7510 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -1,8 +1,14 @@ from typing import Optional import torch import bmtrain as bmt -from bmtrain.nn import Linear +from bmtrain.nn import ( + Linear, + ColumnParallelLinear, + RowParallelLinear, +) import math +from bmtrain.global_var import config +from bmtrain.distributed import all_gather class Attention(bmt.DistributedModule): def __init__(self, @@ -12,11 +18,17 @@ def __init__(self, ) -> None: super().__init__() - self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + if config['tp_size'] > 1: + self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) + self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) + self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) + self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) + else: + self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) - self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) self.softmax = torch.nn.Softmax(dim=-1) self.num_heads = num_heads @@ -32,32 +44,48 @@ def forward(self, batch_size, seq_q, dim_model = hidden_q.size() seq_kv = hidden_kv.size(1) - h_q : torch.Tensor = self.project_q(hidden_q) - h_k : torch.Tensor = self.project_k(hidden_kv) - h_v : torch.Tensor = self.project_v(hidden_kv) + assert hidden_q.data_ptr() == hidden_kv.data_ptr() - h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head) - h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head) - h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head) + hidden_q = bmt.nn.OpParallelLinear.apply( + hidden_q, + torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0), + torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0), + True, False, + False, None + ) + + h_q, h_k, h_v = hidden_q.chunk(3, dim=-1) + + if config['tp_size'] > 1: + #batch_size will changed in TensorParallel + batch_size = h_v.shape[0] + + h_q = h_q.view(batch_size, seq_q, -1, self.dim_head) + h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head) + h_v = h_v.view(batch_size, seq_kv, -1, self.dim_head) h_q = h_q.permute(0, 2, 1, 3).contiguous() h_k = h_k.permute(0, 2, 1, 3).contiguous() h_v = h_v.permute(0, 2, 1, 3).contiguous() - h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head) - h_k = h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head) - h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head) + h_q = h_q.view(-1, seq_q, self.dim_head) + h_k = h_k.view(-1, seq_kv, self.dim_head) + h_v = h_v.view(-1, seq_kv, self.dim_head) score = torch.bmm( h_q, h_k.transpose(1, 2) ) score = score / math.sqrt(self.dim_head) - score = score.view(batch_size, self.num_heads, seq_q, seq_kv) + score = 
score.view(batch_size, -1, seq_q, seq_kv) if position_bias is not None: - score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv) - + score = score + position_bias.view(batch_size, -1, seq_q, seq_kv) + + if config['tp_size'] > 1: + with torch.no_grad(): + mask = all_gather(mask, config['tp_comm']).flatten(0,1) + score = torch.where( mask.view(batch_size, 1, seq_q, seq_kv), score, @@ -70,14 +98,14 @@ def forward(self, torch.scalar_tensor(0, device=score.device, dtype=score.dtype) ) - score = score.view(batch_size * self.num_heads, seq_q, seq_kv) + score = score.view(-1, seq_q, seq_kv) h_out = torch.bmm( score, h_v ) - h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head) + h_out = h_out.view(batch_size, -1, seq_q, self.dim_head) h_out = h_out.permute(0, 2, 1, 3).contiguous() - h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head) + h_out = h_out.view(batch_size, seq_q, -1) attn_out = self.project_out(h_out) return attn_out diff --git a/example/layers/embedding.py b/example/layers/embedding.py index 13c47384..f62151c4 100644 --- a/example/layers/embedding.py +++ b/example/layers/embedding.py @@ -77,11 +77,13 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tensor: if not projection: - return F.embedding( + out = F.embedding( input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse) + return out else: - return F.linear(input, self.weight) / math.sqrt(self.embedding_dim) + out = F.linear(input, self.weight) + return out def extra_repr(self) -> str: s = '{num_embeddings}, {embedding_dim}' @@ -97,4 +99,4 @@ def extra_repr(self) -> str: s += ', sparse=True' return s.format(**self.__dict__) - \ No newline at end of file + diff --git a/example/layers/feedforward.py b/example/layers/feedforward.py index 99d2dc3b..e88d2495 100644 --- a/example/layers/feedforward.py +++ b/example/layers/feedforward.py @@ -1,16 +1,23 @@ import torch import bmtrain as bmt -from bmtrain.nn import Linear +from bmtrain.nn import ( + Linear, + ColumnParallelLinear, + RowParallelLinear) +from bmtrain.global_var import config class Feedforward(bmt.DistributedModule): def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = None) -> None: super().__init__() - self.w_in = Linear(dim_model, dim_ff, bias = bias, dtype=dtype) - self.w_out = Linear(dim_ff, dim_model, bias = bias, dtype=dtype) + if config['tp_size'] > 1: + self.w_in = ColumnParallelLinear(dim_model, dim_ff, bias = bias, dtype=dtype) + self.w_out = RowParallelLinear(dim_ff, dim_model, bias = bias, dtype=dtype) + else: + self.w_in = Linear(dim_model, dim_ff, bias=bias, dtype=dtype) + self.w_out = Linear(dim_ff, dim_model, bias=bias, dtype=dtype) self.relu = torch.nn.ReLU() def forward(self, input : torch.Tensor) -> torch.Tensor: - return self.w_out(self.relu(self.w_in(input))) diff --git a/example/models/gpt.py b/example/models/gpt.py index 78d77a7d..64474ba8 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -1,6 +1,7 @@ import torch import bmtrain as bmt from layers import TransformerEncoder, Layernorm, Embedding, TransformerEncoder +from bmtrain.global_var import config class GPT(bmt.DistributedModule): def __init__(self, @@ -13,7 +14,10 @@ def __init__(self, self.max_distance = max_distance - self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) + if config['tp_size'] > 1: + self.word_emb = 
bmt.nn.ParallelEmbedding(vocab_size, dim_model, dtype=dtype) + else: + self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype) self.transformers = bmt.TransformerBlockList([ @@ -42,7 +46,10 @@ def forward(self, out = self.transformers(out, mask_2d, None) out = self.layernorm(out) - logits = self.word_emb(out, projection=True) + if config['tp_size'] > 1: + logits = self.word_emb.projection(out) + else: + logits = self.word_emb(out, projection=True) bmt.inspect.record_tensor(logits, "logits") - return logits \ No newline at end of file + return logits diff --git a/example/train.py b/example/train.py index 1a744e20..8aaf65e4 100644 --- a/example/train.py +++ b/example/train.py @@ -3,12 +3,13 @@ from models import GPT import time from bmtrain import optim +from bmtrain.global_var import config from bmtrain import inspect def main(): bmt.init_distributed( seed=0, - zero_level=2, + tp_size=2, ) model = GPT( @@ -24,7 +25,6 @@ def main(): ) bmt.init_parameters(model) - # print_inspect(model, "*") bmt.print_rank("Model memory") bmt.print_rank(torch.cuda.memory_summary()) @@ -52,7 +52,11 @@ def main(): if i == bmt.rank(): break - loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) + if config['tp_size'] > 1: + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) + else: + loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) + optimizer = optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) @@ -77,7 +81,10 @@ def main(): ) batch, seq_len, vocab_out_size = logits.size() - loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) + if config['tp_size'] > 1: + loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets) + else: + loss = loss_func(logits.float().view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) global_loss = bmt.sum_loss(loss).item() diff --git a/tests/test_all.py b/tests/test_all.py index 6682aa93..aa382676 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -25,6 +25,8 @@ ("send_recv", 4), ("nccl_backward", 4), ("no_grad", 1), + ("column_parallel_linear", 2), + ("row_parallel_linear", 2), ("training", 4), ]) diff --git a/tests/test_column_parallel_linear.py b/tests/test_column_parallel_linear.py new file mode 100644 index 00000000..1c49570a --- /dev/null +++ b/tests/test_column_parallel_linear.py @@ -0,0 +1,55 @@ +import torch +import bmtrain as bmt +from bmtrain.global_var import config +import numpy as np + +def run_bmt(x, gather_output, ckp_path, tp_size=2): + linear = bmt.nn.ColumnParallelLinear(8,8, gather_output=gather_output) + linear = bmt.CheckpointBlock(linear) + bmt.init_parameters(linear) + y = linear(x[config['topology'].tp_id]) + y.sum().backward() + bmt.save(linear, ckp_path) + bmt.synchronize() + return y, linear._parameters['weight'].grad, linear._parameters['bias'].grad + +def run_torch(x, ckp_path): + linear = torch.nn.Linear(8, 8) + linear_dict = torch.load(ckp_path) + linear.load_state_dict(linear_dict) + linear = linear.cuda() + linear.weight.requires_grad_() + y = linear(x) + y.sum().backward() + return y, linear.weight.grad, linear.bias.grad + +def run(gather_output, ckp_path): + tp_size = bmt.config['tp_size'] + torch.cuda.manual_seed(100) + x = torch.randn(tp_size, 8,8, device='cuda').requires_grad_() + y1, weight_grad1, bias_grad1 = run_bmt(x, gather_output, ckp_path) + y2, 
weight_grad2, bias_grad2 = run_torch(x, ckp_path) + tp_rank = config['topology'].tp_id + if gather_output: + assert np.allclose(y1.detach().cpu().numpy(), y2.flatten(0,1).detach().cpu().numpy()) + else: + torch_out_list = torch.split(y2, y2.size()[-1] // tp_size, dim=y2.dim()-1) + assert np.allclose(y1.detach().cpu().numpy(), torch_out_list[tp_rank].flatten(0,1).detach().cpu().numpy()) + + weight_grad_list = weight_grad2.chunk(tp_size, dim=0) + assert np.allclose(weight_grad1.reshape(weight_grad_list[tp_rank].shape).cpu().numpy(), weight_grad_list[tp_rank].cpu().numpy()) + + bias_grad_list = bias_grad2.chunk(tp_size, dim=0) + assert np.allclose(bias_grad1.reshape(bias_grad_list[tp_rank].shape).cpu().numpy(), bias_grad_list[tp_rank].cpu().numpy()) + +def test_gather_output(): + run(True, 'linear.ckp') + +def test_no_gather_output(): + run(False, 'linear_no_gather.ckp') + +if __name__ == "__main__": + bmt.init_distributed(tp_size=2) + test_gather_output() + test_no_gather_output() + diff --git a/tests/test_row_parallel_linear.py b/tests/test_row_parallel_linear.py new file mode 100644 index 00000000..f89b6dc5 --- /dev/null +++ b/tests/test_row_parallel_linear.py @@ -0,0 +1,54 @@ +import torch +import bmtrain as bmt +from bmtrain.global_var import config +import numpy as np + +def run_bmt(x, ckp_path, split_input=True, use_checkpoint_block=True): + linear = bmt.nn.RowParallelLinear(8,8, split_input=split_input, all_reduce_output=True) + if use_checkpoint_block: + linear = bmt.CheckpointBlock(linear) + bmt.init_parameters(linear) + y = linear(x) + y.sum().backward() + bmt.save(linear, ckp_path) + bmt.synchronize() + return y, linear._parameters['weight'].grad, linear._parameters['bias'].grad + +def run_torch(x, ckp_path): + linear = torch.nn.Linear(8, 8) + linear_dict = torch.load(ckp_path) + linear.load_state_dict(linear_dict) + linear = linear.cuda() + linear.weight.requires_grad_() + y = linear(x) + y.sum().backward() + return y, linear.weight.grad, linear.bias.grad + +def run(split_input, use_checkpoint_block, ckp_path): + tp_size = bmt.config['tp_size'] + torch.cuda.manual_seed(100) + tp_rank = config['topology'].tp_id + x = torch.randn(8,8, device='cuda').requires_grad_() + rank_x = x.chunk(tp_size, dim=0 if split_input else 1)[tp_rank] + y1, weight_grad1, bias_grad1 = run_bmt(rank_x, ckp_path, split_input, use_checkpoint_block) + y2, weight_grad2, bias_grad2 = run_torch(x, ckp_path) + np.testing.assert_allclose(y1.detach().cpu().numpy(), y2.detach().cpu().numpy(), atol=1e-5) + + weight_grad_list = weight_grad2.chunk(tp_size, dim=1) + assert np.allclose(weight_grad1.reshape(weight_grad_list[tp_rank].shape).cpu().numpy(), weight_grad_list[tp_rank].cpu().numpy()) + + assert np.allclose(bias_grad1.cpu().numpy(), bias_grad2.cpu().numpy()) + +def test_split_input(): + run(True, False, 'row_parallel_linear.ckp') + run(True, True, 'row_parallel_linear.ckp') + +def test_no_split_input(): + run(False, False, 'row_parallel_linear_no_split.ckp') + run(False, True, 'row_parallel_linear_no_split.ckp') + +if __name__ == "__main__": + bmt.init_distributed(tp_size=2) + test_no_split_input() + test_split_input() +
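
A minimal usage sketch of the tensor-parallel pieces introduced above, assuming only the APIs added in this diff (bmt.init_distributed(tp_size=...), bmt.nn.ColumnParallelLinear, bmt.nn.RowParallelLinear); the layer sizes and batch shape are illustrative and mirror the layout used in example/layers/feedforward.py and the new tests.

import torch
import bmtrain as bmt
from bmtrain.nn import ColumnParallelLinear, RowParallelLinear

def main():
    # Two-way tensor parallelism; launch with e.g. `torchrun --nproc_per_node=2 <this_file>.py`.
    bmt.init_distributed(seed=0, tp_size=2)

    # Column-parallel projection up, row-parallel projection back down,
    # the same pattern as the Feedforward layer in example/layers/feedforward.py.
    w_in = ColumnParallelLinear(256, 1024, dtype=torch.half)
    w_out = RowParallelLinear(1024, 256, dtype=torch.half)
    bmt.init_parameters(w_in)
    bmt.init_parameters(w_out)

    # Each rank feeds its own shard of the batch; ColumnParallelLinear all-gathers the
    # input (gather_input=True by default) and RowParallelLinear reduce-scatters the output.
    x = torch.randn(4, 32, 256, device="cuda", dtype=torch.half)
    h = w_out(torch.relu(w_in(x)))
    h.sum().backward()

if __name__ == "__main__":
    main()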