diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py
index 31144f7a69b4..af4df2a10ebf 100644
--- a/deepspeed/runtime/pipe/module.py
+++ b/deepspeed/runtime/pipe/module.py
@@ -253,7 +253,7 @@ def _build(self):
         # All pipeline parameters should be considered as model parallel in the context
         # of our FP16 optimizer
         for p in self.parameters():
-            p.model_parallel = True
+            p.ds_pipe_replicated = False
 
     def _count_layer_params(self):
         """Count the trainable parameters in individual layers.
@@ -472,7 +472,7 @@ def _index_tied_modules(self):
                             # Only count the tied module once in the eyes of the FP16 optimizer
                             if self.global_rank != tied_ranks[0]:
                                 for p in self.tied_modules[key].parameters():
-                                    p.model_parallel = False
+                                    p.ds_pipe_replicated = True
         '''
        if len(tied_comms) > 0:
            print(f'RANK={self.global_rank} tied_comms={tied_comms}')
diff --git a/deepspeed/runtime/pipe/p2p.py b/deepspeed/runtime/pipe/p2p.py
index 13a448484061..e129d3d5b177 100644
--- a/deepspeed/runtime/pipe/p2p.py
+++ b/deepspeed/runtime/pipe/p2p.py
@@ -8,12 +8,22 @@
 import torch
 import torch.distributed as dist
 
+# To query whether we have send/recv support
+from packaging.version import Version
+from deepspeed.git_version_info import torch_info
+
 _groups = None
 _grid = None
 
 _async = []
 
 
+def can_send_recv() -> bool:
+    torch_version = Version(torch_info['version'])
+    sendrecv_min = Version('1.8')
+    return torch_version >= sendrecv_min
+
+
 #initializes adjacent process groups
 #run this only after torch.distributed.init_process_group() has been called
 def init_process_groups(grid):
@@ -22,6 +32,9 @@ def init_process_groups(grid):
 
     assert _grid.pipe_parallel_size > 1, "There is no pipeline parallelism"
 
+    if not can_send_recv():
+        _groups = [dist.new_group(ranks=group) for group in _grid.p2p_groups]
+
 
 def _is_valid_send_recv(src_stage, dest_stage):
     first_stage = 0
@@ -34,7 +47,7 @@ def _is_valid_send_recv(src_stage, dest_stage):
 
 def send(tensor, dest_stage, async_op=False):
     global _groups
-    #assert async_op == False, "Doesnt support async_op true"
+    assert async_op == False, "Doesnt support async_op true"
     src_stage = _grid.get_stage_id()
     _is_valid_send_recv(src_stage, dest_stage)
 
@@ -44,12 +57,18 @@ def send(tensor, dest_stage, async_op=False):
         op = dist.isend(tensor, dest_rank)
         _async.append(op)
     else:
-        return dist.send(tensor, dest_rank)
+
+        if can_send_recv():
+            return dist.send(tensor, dest_rank)
+        else:
+            group = _get_send_recv_group(src_stage, dest_stage)
+            src_rank = _grid.stage_to_global(stage_id=src_stage)
+            return dist.broadcast(tensor, src_rank, group=group, async_op=async_op)
 
 
 def recv(tensor, src_stage, async_op=False):
     global _groups
-    #assert async_op == False, "Doesnt support async_op true"
+    assert async_op == False, "Doesnt support async_op true"
     dest_stage = _grid.get_stage_id()
     _is_valid_send_recv(src_stage, dest_stage)
 
@@ -60,7 +79,11 @@ def recv(tensor, src_stage, async_op=False):
         op = dist.irecv(tensor, src_rank)
         _async.append(op)
     else:
-        return dist.recv(tensor, src_rank)
+        if can_send_recv():
+            return dist.recv(tensor, src_rank)
+        else:
+            group = _get_send_recv_group(src_stage, dest_stage)
+            return dist.broadcast(tensor, src_rank, group=group, async_op=async_op)
 
 
 def wait():
@@ -135,3 +158,27 @@ def _to(x):
 
     msg = _to(msg)
     return msg
+
+
+def _get_send_recv_group(src_stage, dest_stage):
+    '''the group id is always the smaller rank unless it's a wrap-around'''
+
+    stage_id = None
+
+    first_stage = 0
+    last_stage = _grid.pipe_parallel_size - 1
+
+    if (src_stage == first_stage and dest_stage == last_stage
+            or dest_stage == first_stage and src_stage == last_stage):
+        stage_id = last_stage
+    elif src_stage > dest_stage:
+        stage_id = dest_stage
+    else:
+        stage_id = src_stage
+    '''group_id corresponds to group of [group_id, group_id+1]
+       unless group_id is the rank of the last stage
+       in which case group_id corresponds to group[group_id-num_stages+1, group_id]
+    '''
+    group_id = _grid.stage_to_global(stage_id=stage_id)
+
+    return _groups[group_id]
diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index f3a4c4c0f1b3..8bf8abe21b33 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -49,6 +49,10 @@ def set_random_seed(seed):
     torch.manual_seed(seed)
 
 
+def is_model_parallel_parameter(p) -> bool:
+    return hasattr(p, 'model_parallel') and p.model_parallel
+
+
 def bwc_tensor_model_parallel_rank(mpu=None):
     """Backwards-compatible way of querying the tensor model parallel rank from
     an ``mpu`` object.
@@ -269,15 +273,19 @@ def get_grad_norm(parameters, norm_type=2, mpu=None):
         total_norm = total_norm_cuda[0].item()
     else:
         total_norm = 0.
+        tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
         for p in parameters:
-            if mpu is not None:
-                if (mpu.get_model_parallel_rank() == 0
-                    ) or is_model_parallel_parameter(p):
-                    param_norm = p.grad.data.float().norm(norm_type)
-                    total_norm += param_norm.item()**norm_type
-            else:
-                param_norm = p.grad.data.float().norm(norm_type)
-                total_norm += param_norm.item()**norm_type
+            # Pipeline parallelism may replicate parameters. Avoid multi-counting.
+            if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated:
+                continue
+
+            # Filter to avoid over-counting replicated tensors from tensor
+            # model parallelism
+            if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
+                continue
+
+            param_norm = p.grad.data.float().norm(norm_type)
+            total_norm += param_norm.item()**norm_type
 
     # Sum across all model parallel GPUs.
     total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
@@ -294,6 +302,48 @@
     return total_norm
 
 
+def get_grad_zeros(parameters, mpu=None):
+    """Compute the number of grads with zero values.
+
+    This is adapted from get_grad_norm.
+
+    Arguments:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor whose gradients will be inspected
+
+    Returns:
+        Total number of zero-valued gradient elements (viewed as a single vector).
+    """
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = list(filter(lambda p: p.grad is not None, parameters))
+
+    total_zeros = 0.
+    tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
+    for p in parameters:
+        # Pipeline parallelism may replicate parameters. Avoid multi-counting.
+        if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated:
+            continue
+
+        # Filter to avoid over-counting replicated tensors from tensor
+        # model parallelism
+        if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
+            continue
+
+        count_zeros = p.grad.numel() - torch.count_nonzero(p.grad)
+        total_zeros += count_zeros.item()
+
+    # Sum across all model parallel GPUs.
+    total_zeros_cuda = torch.cuda.FloatTensor([float(total_zeros)])
+    if mpu is not None:
+        torch.distributed.all_reduce(total_zeros_cuda,
+                                     op=torch.distributed.ReduceOp.SUM,
+                                     group=mpu.get_model_parallel_group())
+        total_zeros = total_zeros_cuda[0].item()
+
+    return total_zeros
+
+
 def get_weight_norm(parameters, norm_type=2, mpu=None):
     """Clips gradient norm of an iterable of parameters.
@@ -326,24 +376,19 @@ def get_weight_norm(parameters, norm_type=2, mpu=None):
         total_norm = total_norm_cuda[0].item()
     else:
         total_norm = 0.
+        tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
         for p in parameters:
-            if mpu is not None:
-                if (mpu.get_model_parallel_rank() == 0
-                    ) or is_model_parallel_parameter(p):
-                    try:
-                        param_norm = float(torch.norm(p, norm_type, dtype=torch.float32))
-                    except TypeError as err:
-                        param_norm = float(torch.norm(p.float(), norm_type))
-
-                    #param_norm = p.data.float().norm(norm_type)
-                    total_norm += param_norm**norm_type
-            else:
-                try:
-                    param_norm = float(torch.norm(p, norm_type, dtype=torch.float32))
-                except TypeError as err:
-                    param_norm = float(torch.norm(p.float(), norm_type))
-                #param_norm = p.data.float().norm(norm_type)
-                total_norm += param_norm**norm_type
+            # Pipeline parallelism may replicate parameters. Avoid multi-counting.
+            if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated:
+                continue
+
+            # Filter to avoid over-counting replicated tensors from tensor
+            # model parallelism
+            if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
+                continue
+
+            param_norm = p.data.float().norm(norm_type)
+            total_norm += param_norm**norm_type
 
     # Sum across all model parallel GPUs.
     total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py
index ad529aa96391..64955ff400dc 100755
--- a/deepspeed/runtime/zero/stage2.py
+++ b/deepspeed/runtime/zero/stage2.py
@@ -13,7 +13,7 @@
 import collections
 
 from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
-from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
+from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank, see_memory_usage, is_model_parallel_parameter
 from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
 from deepspeed.ops.adam import DeepSpeedCPUAdam
 from deepspeed.ops.op_builder import UtilsBuilder
@@ -149,7 +149,7 @@ def __init__(self,
             self.model_parallel_rank = 0
         else:
             self.model_parallel_group = mpu.get_model_parallel_group()
-            self.model_parallel_rank = mpu.get_model_parallel_rank()
+            self.model_parallel_rank = bwc_tensor_model_parallel_rank(mpu)
 
         self.overflow = False
         self.clip_grad = clip_grad
@@ -970,6 +970,10 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
         total_norm = 0.0
         norm_type = 2.0
         for p in params:
+            # Pipeline parallelism may replicate parameters. Avoid multi-counting.
+            if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated:
+                continue
+
             if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
                 param_id = self.get_param_id(p)
                 # as some model have trainable parameters but skipped in training,
@@ -1366,6 +1370,9 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
         #if dist.get_rank() == 0:
         #    logger.info(f"Total Norm begining {total_norm}")
         for g, p in zip(gradients, params):
+            # Pipeline parallelism may replicate parameters. Avoid multi-counting.
+            if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated:
+                continue
             if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
                 param_norm = g.data.double().norm(2)
                 total_norm += param_norm.item()**2
@@ -1499,11 +1506,6 @@ def step(self, closure=None):
 
             see_memory_usage('After overflow after clearing gradients')
 
-            logger.info(
-                "[deepspeed] fp16 dynamic loss scale overflow! Rank {} Skipping step. Attempted loss scale: {}, "
-                "reducing to {}".format(dist.get_rank(),
-                                        prev_scale,
-                                        self.loss_scale))
             self.start_timers(timer_names)
             self.stop_timers(timer_names)
             return
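
Note on the p2p.py changes above: on torch versions where can_send_recv() returns False, send() and recv() fall back to dist.broadcast over the two-rank groups built in init_process_groups(), and _get_send_recv_group() selects the group shared by the two stages (the smaller stage's group, except for the first/last wrap-around pair). The sketch below shows the same pattern in isolation; it is not part of the patch, and build_p2p_groups, group_for, p2p_send, and p2p_recv are hypothetical names, assuming one rank per pipeline stage and at least three stages.

# Minimal, self-contained sketch (assumptions above): emulate send/recv with a
# broadcast restricted to a two-rank group containing only the endpoints.
import torch
import torch.distributed as dist


def build_p2p_groups(num_stages):
    # One two-rank group per adjacent pair, plus a wrap-around pair for the
    # first and last stages. Every rank creates every group so that the
    # dist.new_group() calls run in the same order on all processes.
    pairs = [sorted([i, (i + 1) % num_stages]) for i in range(num_stages)]
    return [dist.new_group(ranks=pair) for pair in pairs]


def group_for(src, dst, num_stages):
    # Mirrors _get_send_recv_group(): use the smaller endpoint's group unless
    # the transfer wraps around the ends of the pipeline.
    if {src, dst} == {0, num_stages - 1}:
        return num_stages - 1
    return min(src, dst)


def p2p_send(tensor, dst, groups, num_stages):
    # A broadcast rooted at the sender, inside the group that contains only
    # sender and receiver, behaves like a send.
    me = dist.get_rank()
    return dist.broadcast(tensor, src=me, group=groups[group_for(me, dst, num_stages)])


def p2p_recv(tensor, src, groups, num_stages):
    # The receiver joins the same broadcast; `tensor` is overwritten in place.
    me = dist.get_rank()
    return dist.broadcast(tensor, src=src, group=groups[group_for(me, src, num_stages)])


# Example usage after dist.init_process_group(): stage 0 sends to stage 1.
#   groups = build_p2p_groups(num_stages)
#   if dist.get_rank() == 0:
#       p2p_send(torch.arange(4.0), 1, groups, num_stages)
#   elif dist.get_rank() == 1:
#       buf = torch.empty(4)
#       p2p_recv(buf, 0, groups, num_stages)

Restricting the broadcast to a group containing only the two endpoints is what makes it semantically equivalent to a point-to-point send/recv; a broadcast in the default group would involve every rank.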