diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py index f56e4b853313..efe95f91bac5 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -24,7 +24,7 @@ from deepspeed.runtime.config import DeepSpeedConfig from deepspeed.utils import logger -from deepspeed.runtime.utils import move_to_device, see_memory_usage +from deepspeed.runtime.utils import move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers # DeepSpeed Checkpointing Enabled or Disabled @@ -213,9 +213,12 @@ def model_parallel_cuda_manual_seed(seed): model parallel regions. """ global mpu + + tp_rank = bwc_tensor_model_parallel_rank(mpu) + # 2718 is just for fun and any POSITIVE value will work. offset = seed + 2718 - model_parallel_seed = offset + mpu.get_model_parallel_rank() + model_parallel_seed = offset + tp_rank # Data parallel gets the original sedd. 
data_parallel_seed = seed @@ -225,7 +228,7 @@ def model_parallel_cuda_manual_seed(seed): 'model parallel rank {}, and data parallel rank {} with ' 'model parallel seed: {} and data parallel seed: {}'.format( torch.distributed.get_rank(), - mpu.get_model_parallel_rank(), + tp_rank, mpu.get_data_parallel_rank(), model_parallel_seed, data_parallel_seed), @@ -384,9 +387,14 @@ def save_args_for_backward(*all_args): global data_offsets, size_offsets if mp_rank is None: if mpu is not None: - mp_rank = mpu.get_model_parallel_rank() - mp_size = mpu.get_model_parallel_world_size() - mp_group = mpu.get_model_parallel_group() + if hasattr(mpu, 'get_tensor_model_parallel_rank'): + mp_rank = mpu.get_tensor_model_parallel_rank() + mp_size = mpu.get_tensor_model_parallel_world_size() + mp_group = mpu.get_tensor_model_parallel_group() + else: + mp_rank = mpu.get_model_parallel_rank() + mp_size = mpu.get_model_parallel_world_size() + mp_group = mpu.get_model_parallel_group() else: mp_rank = 0 mp_size = 1 diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 325222db4669..a782581dde8d 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -36,6 +36,7 @@ ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS, ZERO_OPTIMIZATION_WEIGHTS from deepspeed.runtime.csr_tensor import CSRTensor import deepspeed.runtime.lr_schedules as lr_schedules +from deepspeed.runtime.utils import get_grad_norm from deepspeed.utils import logger, log_dist, init_distributed from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer from deepspeed.utils.debug import debug_extract_module_and_param_names @@ -122,6 +123,8 @@ def __init__(self, self.block_eigenvalue = None self.gas_boundary_ctr = 0 self.dist_backend = "nccl" + self._step_applied = False + self._global_grad_norm = None # for debug purposes - can then debug print: debug_get_module_name(module) debug_extract_module_and_param_names(model) @@ -234,6 +237,40 @@ def 
get_batch_info(self): """ return self.train_batch_size, self.train_micro_batch_size_per_gpu, self.gradient_accumulation_steps + def set_train_batch_size(self, train_batch_size): + """Adjust the global batch size by increasing or decreasing the number of + micro-batches (i.e., gradient accumulation steps). The size of each micro-batch + (i.e., ``train_micro_batch_size_per_gpu``) is not changed. + Args: + train_batch_size (int): The new global batch size for training. + Raises: + ValueError: if ``train_batch_size`` is not divisible by the + configured micro-batch size and data parallelism. + """ + if train_batch_size % (self.train_micro_batch_size_per_gpu() * + self.dp_world_size) != 0: + #print(f'{train_batch_size=} {self.train_micro_batch_size_per_gpu()=} {self.dp_world_size=}') + raise ValueError( + f'Train batch size must be divisible by micro-batch data parallelism') + new_gas = train_batch_size // (self.train_micro_batch_size_per_gpu() * + self.dp_world_size) + # overwrite config + self._config.train_batch_size = train_batch_size + self._config.gradient_accumulation_steps = new_gas + + def get_global_grad_norm(self) -> float: + """Return the 2-norm of all gradients. If there is model parallelism, + the norm will be global. + + The computed norm will be cached and reused until the next step() pass. + .. note:: + In the presence of model parallelism, this is a collective call + and acts as a barrier among ``mpu.get_model_parallel_group()``. 
+ Returns: + float: norm + """ + return self._global_grad_norm + def checkpoint_tag_validation_enabled(self): return self._config.checkpoint_tag_validation_enabled @@ -479,6 +516,9 @@ def steps_per_print(self): def zero_allgather_partitions(self): return self._config.zero_config.allgather_partitions + def zero_round_robin_gradients(self): + return self._config.zero_config.round_robin_gradients + def dump_state(self): return self._config.dump_state @@ -570,7 +610,7 @@ def _configure_with_arguments(self, args, mpu): if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: ompi_local_rank = os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK") local_rank = os.environ.get('LOCAL_RANK', ompi_local_rank) - assert ompi_local_rank == local_rank, f"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({mpi_local_rank}), " \ + assert ompi_local_rank == local_rank, f"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({ompi_local_rank}), " \ "not sure how to proceed as we're seeing conficting local rank info." 
os.environ['LOCAL_RANK'] = local_rank @@ -904,6 +944,15 @@ def _configure_zero_optimizer(self, optimizer): gradient_predivide=self.gradient_predivide) elif zero_stage <= ZERO_OPTIMIZATION_GRADIENTS: overlap_comm = self.zero_overlap_comm() + contiguous_gradients = self.zero_contiguous_gradients() + round_robin_gradients = self.zero_round_robin_gradients() + + # Overlap and contiguous grads are meaningless in stage 1 and are ignored + if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES: + overlap_comm = False + contiguous_gradients = False + round_robin_gradients = False + if isinstance(self.module, PipelineModule): if overlap_comm: logger.warning( @@ -918,7 +967,7 @@ def _configure_zero_optimizer(self, optimizer): dynamic_loss_scale=self.dynamic_loss_scale(), dynamic_loss_args=self.dynamic_loss_scale_args(), clip_grad=self.gradient_clipping(), - contiguous_gradients=self.zero_contiguous_gradients(), + contiguous_gradients=contiguous_gradients, reduce_bucket_size=self.zero_reduce_bucket_size(), allgather_bucket_size=self.zero_allgather_bucket_size(), dp_process_group=self.data_parallel_group, @@ -930,7 +979,8 @@ def _configure_zero_optimizer(self, optimizer): gradient_predivide_factor=self.gradient_predivide_factor(), gradient_accumulation_steps=self.gradient_accumulation_steps(), ignore_unused_parameters=self.zero_ignore_unused_parameters(), - partition_grads=zero_stage == ZERO_OPTIMIZATION_GRADIENTS) + partition_grads=zero_stage == ZERO_OPTIMIZATION_GRADIENTS, + round_robin_gradients=round_robin_gradients) elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS: logger.info("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3 @@ -992,6 +1042,18 @@ def is_iterable_style_dataset(obj): torch.utils.data.IterableDataset ) # hasattr(obj, "__iter__") should work as well + def was_step_applied(self) -> bool: + """Returns True if the latest ``step()`` resulted in parameter updates.
+ + Note that a ``False`` return is not an error condition. Steps are frequently + no-ops, such as between gradient accumulation boundaries or when overflows + occur. + + Returns: + bool: Whether the latest ``step()`` modified model parameters. + """ + return self._step_applied + def deepspeed_io(self, dataset, batch_size=None, @@ -1129,6 +1191,10 @@ def forward(self, *inputs, **kwargs): return loss def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): + # Pass (PP) gas boundary flag to optimizer (required for zero) + self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary( + ) + # ZeRO stage 2 communicates during non gradient accumulation boundaries as well if self.zero_optimization_partition_gradients(): self.optimizer.overlapping_partition_gradients_reduce_epilogue() @@ -1264,6 +1330,9 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}): self.optimizer.step() + if hasattr(self.optimizer, '_global_grad_norm'): + self._global_grad_norm = self.optimizer._global_grad_norm + # Quantize the updated parameter if there no overflow if self.quantizer: self.quantizer.quantize( @@ -1286,12 +1355,19 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}): overflow = False if hasattr(self.optimizer, 'overflow'): overflow = self.optimizer.overflow + self._step_applied = not overflow if overflow: self.skipped_steps += 1 else: if self.lr_scheduler is not None: - self.lr_scheduler.step(**(lr_kwargs or {})) + try: + self.lr_scheduler.step(**(lr_kwargs or {})) + except TypeError: + # XXX Hack to work with Megatron 2.0 and DeepSpeed pipelines. 
+ # We don't currently have a way to specify lr_kwargs from + # pipe_engine.train_batch() + self.lr_scheduler.step(increment=self.train_batch_size()) if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0: self._report_progress(self.global_steps + 1) @@ -1311,6 +1387,8 @@ def step(self, lr_kwargs=None): "init in order to use step" report_progress = self.global_rank == 0 if self.global_rank else True + self._step_applied = False # assume False, will flip to True + # Update the model when we reach gradient accumulation boundaries if self.is_gradient_accumulation_boundary(): self.gas_boundary_ctr += 1 @@ -1674,9 +1752,12 @@ def load_checkpoint(self, load_lr_scheduler_states=load_lr_scheduler_states) if self.zero_optimization() and load_path is not None: - self._load_zero_checkpoint(load_dir, - tag, - load_optimizer_states=load_optimizer_states) + success = self._load_zero_checkpoint( + load_dir, + tag, + load_optimizer_states=load_optimizer_states) + if not success: + self.optimizer._restore_from_fp16_weights() return load_path, client_states @@ -1746,7 +1827,7 @@ def _load_checkpoint(self, def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag) if zero_sd_list is None: - return + return False self.optimizer.load_state_dict( state_dict_list=zero_sd_list, @@ -1755,6 +1836,7 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): print( f'loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}' ) + return True def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size): zero_ckpt_names = [] @@ -1973,7 +2055,7 @@ def _copy_recovery_script(self, save_path): script = "zero_to_fp32.py" src = os.path.join(base_dir, "utils", script) dst = os.path.join(save_path, script) - logger.info(f"creating recovery script {dst}") + #logger.info(f"creating recovery script {dst}") copyfile(src, dst) # make 
executable os.chmod(dst, os.stat(dst).st_mode | stat.S_IEXEC) @@ -1986,7 +2068,7 @@ def _save_zero_checkpoint(self, save_path, tag): ds_version=version) torch.save(zero_sd, zero_checkpoint_name) self._copy_recovery_script(save_path) - logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name)) + #logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name)) def _zero3_consolidated_fp16_state_dict(self): """ diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index fba0d6b1fd59..72dd2c161845 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -9,7 +9,7 @@ import math from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, get_weight_norm +from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE from deepspeed.utils import logger, log_dist @@ -45,6 +45,8 @@ def __init__(self, self.fp16_groups_flat = [] self.fp32_groups_flat = [] + self._global_grad_norm = 0. 
+ # loop to deal with groups for i, param_group in enumerate(self.optimizer.param_groups): # push this group to list before modify @@ -161,8 +163,11 @@ def step_fused_adam(self, closure=None): "scale: {}, reducing to {}".format(prev_scale, self.cur_scale)) return self.overflow + + self._global_grad_norm = get_global_norm(norm_list=norm_groups) + combined_scale = self.unscale_and_clip_grads(grads_groups_flat, - norm_groups, + self._global_grad_norm, apply_scale=False) # norm is in fact norm*cur_scale self.optimizer.step(grads=[[g] for g in grads_groups_flat], @@ -251,8 +256,10 @@ def step(self, closure=None): all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu) self.stop_timers([COMPUTE_NORM]) + self._global_grad_norm = get_global_norm(norm_list=[all_groups_norm]) + self.start_timers([UNSCALE_AND_CLIP]) - self.unscale_and_clip_grads(grads_groups_flat, [all_groups_norm]) + self.unscale_and_clip_grads(grads_groups_flat, self._global_grad_norm) self.stop_timers([UNSCALE_AND_CLIP]) self.start_timers([BASIC_STEP]) @@ -277,12 +284,7 @@ def step(self, closure=None): return self.overflow - def unscale_and_clip_grads(self, grad_groups_flat, norm_groups, apply_scale=True): - total_norm = 0.0 - for norm in norm_groups: - total_norm += norm**2.0 - total_norm = math.sqrt(total_norm) - + def unscale_and_clip_grads(self, grad_groups_flat, total_norm, apply_scale=True): # compute combined scale factor for this group combined_scale = self.cur_scale if self.clip_grad > 0.: diff --git a/deepspeed/runtime/fp16/unfused_optimizer.py b/deepspeed/runtime/fp16/unfused_optimizer.py index c30df0bef1d0..08a4a7c41cb2 100755 --- a/deepspeed/runtime/fp16/unfused_optimizer.py +++ b/deepspeed/runtime/fp16/unfused_optimizer.py @@ -9,7 +9,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors import math -from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, get_weight_norm +from deepspeed.runtime.utils import get_global_norm, get_grad_norm, 
CheckOverflow, get_weight_norm from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE from deepspeed.utils import logger @@ -32,6 +32,7 @@ def __init__(self, fused_lamb_legacy=False): self.fused_lamb_legacy = fused_lamb_legacy + self._global_grad_norm = 0. if torch.distributed.get_rank() == 0: logger.info(f'Fused Lamb Legacy : {self.fused_lamb_legacy} ') @@ -148,7 +149,9 @@ def step_fused_lamb(self, closure=None): self.cur_scale)) return self.overflow - combined_scale = self.unscale_and_clip_grads(norm_groups, apply_scale=False) + self._global_grad_norm = get_global_norm(norm_list=norm_groups) + combined_scale = self.unscale_and_clip_grads(self._global_grad_norm, + apply_scale=False) self.optimizer.step(grads=grads_groups, output_params=self.fp16_groups, scale=combined_scale) @@ -197,7 +200,8 @@ def step(self, closure=None): else: fp32_param.grad = fp16_param.grad.to(fp32_param.dtype) - self.unscale_and_clip_grads(norm_groups) + self._global_grad_norm = get_global_norm(norm_list=norm_groups) + self.unscale_and_clip_grads(self._global_grad_norm) self.optimizer.step() @@ -212,12 +216,7 @@ def step(self, closure=None): return self.overflow - def unscale_and_clip_grads(self, norm_groups, apply_scale=True): - total_norm = 0.0 - for norm in norm_groups: - total_norm += norm**2.0 - total_norm = math.sqrt(total_norm) - + def unscale_and_clip_grads(self, total_norm, apply_scale=True): # compute combined scale factor for this group combined_scale = self.cur_scale if self.clip_grad > 0.: diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 196cbe8c6217..cadb5b82e36f 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -110,8 +110,9 @@ def __init__(self, *super_args, **super_kwargs): self.is_model_parallel = self.grid.model_parallel_size > 1 # Partition input/output buffers + # XXX temporarily disable while I revert some partition hacks. 
self.is_pipe_partitioned = self.is_model_parallel - self.is_grad_partitioned = False + self.is_grad_partitioned = self.is_model_parallel model_parameters = filter(lambda p: p.requires_grad, self.module.parameters()) num_params = sum([p.numel() for p in model_parameters]) @@ -393,6 +394,19 @@ def eval_batch(self, data_iter, compute_loss=True, reduce_output='avg'): return eval_output + def set_train_batch_size(self, train_batch_size): + """Adjust the global batch size by increasing or decreasing the number of + micro-batches (i.e., gradient accumulation steps). The size of each micro-batch + (i.e., ``train_micro_batch_size_per_gpu``) is not changed. + Args: + train_batch_size (int): The new global batch size for training. + Raises: + ValueError: if ``train_batch_size`` is not divisible by the + configured micro-batch size and data parallelism. + """ + super().set_train_batch_size(train_batch_size) + self.micro_batches = self.gradient_accumulation_steps() + def is_first_stage(self): """True if this process is in the first stage in the pipeline.""" return self.stage_id == 0 @@ -553,12 +567,18 @@ def _exec_forward_pass(self, buffer_id): local_part=inputs[1], group=self.grid.get_slice_parallel_group()) + inputs = part_input.full() + inputs.requires_grad = True + part_input = None + self.pipe_buffers['inputs'][buffer_id] = inputs + ''' inputs = tuple([part_input.full(), inputs[2]]) inputs[0].requires_grad = True # skip mask #inputs[1].requires_grad = True part_input = None self.pipe_buffers['inputs'][buffer_id] = inputs + ''' # Zero out the gradients each time we use the tensor because only the data in # tensor changes across batches @@ -568,13 +588,14 @@ def _exec_forward_pass(self, buffer_id): # Partition the outputs if we are not the last stage if self.is_pipe_partitioned and not self.is_last_stage(): - part = PartitionedTensor(tensor=outputs[0], + assert torch.is_tensor(outputs) + part = PartitionedTensor(tensor=outputs, group=self.grid.get_slice_parallel_group()) # 
Clear the large output data, but save the computation graph - outputs[0].data = torch.zeros(1) - self.pipe_buffers['output_tensors'][buffer_id] = outputs[0] + outputs.data = torch.zeros(1) + self.pipe_buffers['output_tensors'][buffer_id] = outputs # Inject the partitioned tensor into the output before sending - outputs = tuple([part.to_meta(), part.data(), outputs[1]]) + outputs = tuple([part.to_meta(), part.data()]) part = None self.pipe_buffers['outputs'][buffer_id] = outputs @@ -632,15 +653,11 @@ def _exec_backward_pass(self, buffer_id): local_part=outputs[1], group=self.grid.get_slice_parallel_group()) self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full() - outputs = tuple( - [self.pipe_buffers['output_tensors'][buffer_id], - outputs[2]]) + outputs = self.pipe_buffers['output_tensors'][buffer_id] else: # Already restored from partition - self.pipe_buffers['output_tensors'][buffer_id].data = outputs[0] - outputs = tuple( - [self.pipe_buffers['output_tensors'][buffer_id], - outputs[1]]) + self.pipe_buffers['output_tensors'][buffer_id].data = outputs + outputs = self.pipe_buffers['output_tensors'][buffer_id] grad_tensors = self.grad_layer if self.is_grad_partitioned: @@ -649,7 +666,7 @@ def _exec_backward_pass(self, buffer_id): meta=self.grad_layer[0], local_part=self.grad_layer[1], group=self.grid.get_slice_parallel_group()) - grad_tensors = tuple([part_grad.full(), self.grad_layer[2]]) + grad_tensors = part_grad.full() part_grad = None #print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}') @@ -872,13 +889,10 @@ def _exec_send_grads(self, buffer_id): # Partition the gradient if self.is_grad_partitioned: - part = PartitionedTensor(tensor=inputs[0].grad, + assert torch.is_tensor(inputs) + part = PartitionedTensor(tensor=inputs.grad, group=self.grid.get_slice_parallel_group()) - # Clear the large output data, but save the computation graph - # Inject the partitoned tensor into the output 
before sending - - # XXX Hack - inputs = tuple([part.to_meta(), part.data(), inputs[1]]) + inputs = tuple([part.to_meta(), part.data()]) # XXX Terrible hack # Drop the attention mask from the input buffer here. It does not have @@ -899,8 +913,6 @@ def _exec_send_grads(self, buffer_id): # First two sends are partitioned gradient p2p.send(inputs[0], self.prev_stage) p2p.send(inputs[1], self.prev_stage) - # XXX hack hack hack - #p2p.send(inputs[2].grad, self.prev_stage) else: for idx, buffer in enumerate(inputs): # Skip tensors that will not produce a grad @@ -974,7 +986,7 @@ def _exec_recv_grads(self, buffer_id): local_part=outputs[1], group=self.grid.get_slice_parallel_group()) outputs[0].data = part_output.full() - outputs = tuple([outputs[0], outputs[2]]) + outputs = outputs[0] # save for backward self.pipe_buffers['outputs'][buffer_id] = outputs @@ -984,7 +996,7 @@ def _exec_recv_grads(self, buffer_id): s = list(outputs.size()) self.grad_layer = self._allocate_buffer(s, num_buffers=1)[0] else: - sizes = [list(t.size()) for t in outputs if t.is_floating_point()] + sizes = [list(t.size()) for t in outputs] # if t.is_floating_point()] self.grad_layer = self._allocate_buffers(sizes, num_buffers=1)[0] if isinstance(self.grad_layer, torch.Tensor): diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index c4e111e47315..e61dd1c72878 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -253,7 +253,7 @@ def _build(self): # All pipeline parameters should be considered as model parallel in the context # of our FP16 optimizer for p in self.parameters(): - p.model_parallel = True + p.ds_pipe_replicated = False def _count_layer_params(self): """Count the trainable parameters in individual layers. 
@@ -472,7 +472,7 @@ def _index_tied_modules(self): # Only count the tied module once in the eyes of the FP16 optimizer if self.global_rank != tied_ranks[0]: for p in self.tied_modules[key].parameters(): - p.model_parallel = False + p.ds_pipe_replicated = True ''' if len(tied_comms) > 0: print(f'RANK={self.global_rank} tied_comms={tied_comms}') @@ -559,7 +559,18 @@ def save_state_dict(self, save_dir): model_ckpt_path = self.ckpt_layer_path(save_dir, idx) if not hasattr(layer, 'state_dict'): continue - torch.save(layer.state_dict(), model_ckpt_path) + # We pass cloned tensors to torch.save() to avoid checkpoint bloat which occurs because torch.save() + # saves the underlying storage rather than the slice of the storage corresponding to individual tensors. + # This is a problem in DeepSpeed because we often allocate tensors using slices of large flattened buffers. + # Tensor cloning helps to avoid this problem because the storage of cloned tensors are closer to the true size. + # It is expected that the garbage collector will reclaim the cloned tensor storage to avoid memory bloat. 
+ # See https://pytorch.org/docs/stable/notes/serialization.html#preserve-storage-sharing + orig_state_dict = layer.state_dict() + final_state_dict = type(orig_state_dict)( + {k: v.clone() + for k, + v in orig_state_dict.items()}) + torch.save(final_state_dict, model_ckpt_path) def load_state_dir(self, load_dir, strict=True): for idx, layer in enumerate(self.forward_funcs): @@ -577,15 +588,18 @@ def load_state_dir(self, load_dir, strict=True): layer.load_state_dict(checkpoint) - if self._grid.data_parallel_id == 0: - logger.info( - f'RANK={self.global_rank} Loaded layer={idx+self._local_start} file={load_path}' - ) + # if self._grid.data_parallel_id == 0: + # logger.info( + # f'RANK={self.global_rank} Loaded layer={idx+self._local_start} file={load_path}' + # ) self._synchronize_tied_weights() def _is_checkpointable(self, funcs): - if self.__class__.__name__ == 'GPT2ModelPipe': + # This is an unfortunate hack related to torch and deepspeed activation checkpoint implementations. + # Some layers like torch.nn.Embedding will not receive grads if checkpointed, which breaks things. + # I presume it's related to the discrete inputs that cannot require_grad? Need to revisit. 
+ if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'): return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs) diff --git a/deepspeed/runtime/pipe/p2p.py b/deepspeed/runtime/pipe/p2p.py index 24c0f250a4b9..e129d3d5b177 100644 --- a/deepspeed/runtime/pipe/p2p.py +++ b/deepspeed/runtime/pipe/p2p.py @@ -2,11 +2,27 @@ Copyright 2019 The Microsoft DeepSpeed Team ''' +import pickle +import typing + +import torch import torch.distributed as dist +# To query whether we have send/recv support +from packaging.version import Version +from deepspeed.git_version_info import torch_info + _groups = None _grid = None +_async = [] + + +def can_send_recv() -> bool: + torch_version = Version(torch_info['version']) + sendrecv_min = Version('1.8') + return torch_version >= sendrecv_min + #initializes adjacent process groups #run this only after torch.distributed.init_process_group() has been called @@ -16,7 +32,8 @@ def init_process_groups(grid): assert _grid.pipe_parallel_size > 1, "There is no pipeline parallelism" - _groups = [dist.new_group(ranks=group) for group in _grid.p2p_groups] + if not can_send_recv(): + _groups = [dist.new_group(ranks=group) for group in _grid.p2p_groups] def _is_valid_send_recv(src_stage, dest_stage): @@ -30,40 +47,117 @@ def _is_valid_send_recv(src_stage, dest_stage): def send(tensor, dest_stage, async_op=False): global _groups - - async_op = False + assert async_op == False, "Doesnt support async_op true" src_stage = _grid.get_stage_id() _is_valid_send_recv(src_stage, dest_stage) - group = _get_send_recv_group(src_stage, dest_stage) - src_rank = _grid.stage_to_global(stage_id=src_stage) + dest_rank = _grid.stage_to_global(stage_id=dest_stage) + if async_op: + global _async + op = dist.isend(tensor, dest_rank) + _async.append(op) + else: - return dist.broadcast(tensor, src_rank, group=group, async_op=async_op) + if can_send_recv(): + return dist.send(tensor, dest_rank) + else: + group = _get_send_recv_group(src_stage, 
dest_stage) + src_rank = _grid.stage_to_global(stage_id=src_stage) + return dist.broadcast(tensor, src_rank, group=group, async_op=async_op) def recv(tensor, src_stage, async_op=False): - global _groups - - async_op = False + assert async_op == False, "Doesnt support async_op true" dest_stage = _grid.get_stage_id() _is_valid_send_recv(src_stage, dest_stage) - group = _get_send_recv_group(src_stage, dest_stage) src_rank = _grid.stage_to_global(stage_id=src_stage) - return dist.broadcast(tensor, src_rank, group=group, async_op=async_op) - - -def barrier(stage_id): - global _groups, _grid - group_id = _grid.stage_to_global(stage_id=stage_id) - if (dist.get_rank() >= 0): - print("Barrier Group ID", group_id) - print("Barrier Group", _grid.p2p_groups[group_id]) - dist.barrier(group=_groups[group_id]) - if (dist.get_rank() >= 0): - print("Exiting Barrier ", group_id) + if async_op: + global _async + op = dist.irecv(tensor, src_rank) + _async.append(op) + else: + if can_send_recv(): + return dist.recv(tensor, src_rank) + else: + group = _get_send_recv_group(src_stage, dest_stage) + return dist.broadcast(tensor, src_rank, group=group, async_op=async_op) + + +def wait(): + global _async + for op in _async: + op.wait() + _async = [] + + torch.cuda.synchronize() + + +def send_obj(msg: typing.Any, dest: int): + """Send an arbitrary python object to ``dest``. + + Note: ``msg`` must be pickleable. + + WARN: This incurs a CPU -> GPU transfer and should be used sparingly + for performance reasons. + + Args: + msg (typing.Any): The object to send. + dest (int): Destination rank. 
+ """ + # serialize the message + msg = pickle.dumps(msg) + # construct a tensor to send + msg = torch.ByteTensor(torch.ByteStorage.from_buffer(msg)).cuda() + + # Send meta and message + length_tensor = torch.tensor([len(msg)], dtype=torch.long).cuda() + dist.send(length_tensor, dst=dest) + dist.send(msg, dst=dest) + + +def recv_obj(sender: int) -> typing.Any: + """Receive an arbitrary python object from ``sender``. + + WARN: This incur a CPU <-> GPU transfers and should be used sparingly + for performance reasons. + + Args: + sender (int): The rank sending the message. + """ + # Get message meta + length = torch.tensor([0], dtype=torch.long).cuda() + dist.recv(length, src=sender) + + # Receive and deserialize + msg = torch.empty(length.item(), dtype=torch.uint8).cuda() + dist.recv(msg, src=sender) + + msg = pickle.loads(msg.cpu().numpy().tobytes()) + + def _to(x): + """Recursively move to the current device.""" + if torch.is_tensor(x): + return x.cuda() + if isinstance(x, (tuple, list)): + ret = [_to(x_) for x_ in x] + if isinstance(x, tuple): + ret = tuple(ret) + return ret + # handle kwargs + if isinstance(x, dict): + ret = dict() + for key, val in x.items(): + ret[_to(key)] = _to(val) + return ret + + # Anything else is a no-op + return x + + msg = _to(msg) + return msg def _get_send_recv_group(src_stage, dest_stage): diff --git a/deepspeed/runtime/state_dict_factory.py b/deepspeed/runtime/state_dict_factory.py index f6be1dbbe57b..f5562e141c3f 100755 --- a/deepspeed/runtime/state_dict_factory.py +++ b/deepspeed/runtime/state_dict_factory.py @@ -51,10 +51,6 @@ def load(self, self.module_key = module_key num_ckpt = len(self.ckpt_list) idx = mp_rank * num_ckpt // mp_world_size - - logger.info( - f'mp_world_size: {mp_world_size}, mp_rank: {mp_rank}, module_key: {module_key}' - ) """ We have multiple cases to handle here for both training and inference: 1. PipeModule loading mp_rank_*.pt files, is_pipe_parallel=True, module_key is not None a. 
if no mp_size/pp_size resizing occurs, for both training & inference, loading @@ -82,7 +78,7 @@ def load(self, merge_count = 1 if num_ckpt == mp_world_size: assert os.path.exists(load_path) - logger.info(f'rank: {mp_rank} loading checkpoint: {load_path}') + #logger.info(f'rank: {mp_rank} loading checkpoint: {load_path}') sd = torch.load(load_path, map_location=lambda storage, loc: storage) if quantize: @@ -162,7 +158,7 @@ def set_module(self, sd, module): return sd def check_ckpt_list(self): - logger.info(f'checkpoint file list: {self.ckpt_list}') + #logger.info(f'checkpoint file list: {self.ckpt_list}') assert len(self.ckpt_list) > 0 sd = torch.load(self.ckpt_list[0], map_location=lambda storage, loc: storage) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index c792d2a6d0db..da9727284232 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -9,7 +9,7 @@ import os import psutil import gc -from math import ceil +from math import ceil, sqrt from math import floor from bisect import bisect_left, bisect_right @@ -49,6 +49,48 @@ def set_random_seed(seed): torch.manual_seed(seed) +def is_model_parallel_parameter(p) -> bool: + return hasattr(p, 'model_parallel') and p.model_parallel + + +def bwc_tensor_model_parallel_rank(mpu=None): + """Backwards-compatible way of querying the tensor model parallel rank from + an ``mpu`` object. + + *Tensor* model parallelism means that tensors are physically split across + processes. This contrasts with *pipeline* model parallelism, in which the + layers are partitioned but tensors left intact. + + The API for tensor model parallelism has changed across versions and this + helper provides a best-effort implementation across versions of ``mpu`` + objects. The preferred mechanism is + ``mpu.get_tensor_model_parallel_rank()``. + + This should "just work" with both Megatron-LM and DeepSpeed's pipeline + parallelism. + + Args: + mpu (model parallel unit, optional): The model parallel unit to query.
+ If ``mpu=None``, returns 0. Defaults to ``None``. + + Returns: + int: the rank + """ + if mpu is None: + # No model parallelism: easy case :) + return 0 + + if hasattr(mpu, 'get_tensor_model_parallel_rank'): + # New Megatron and DeepSpeed convention (post pipeline-parallelism release) + return mpu.get_tensor_model_parallel_rank() + elif hasattr(mpu, 'get_slice_parallel_rank'): + # Some DeepSpeed + pipeline parallelism versions + return mpu.get_slice_parallel_rank() + else: + # Deprecated Megatron and DeepSpeed convention + return mpu.get_model_parallel_rank() + + def move_to_device(item, device): """ Move tensor onto device. Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts. @@ -198,6 +240,15 @@ def _handle_overflow(cpu_sum, x, i): ) +def get_global_norm(norm_list): + """ Compute total from a list of norms + """ + total_norm = 0.0 + for norm in norm_list: + total_norm += norm**2.0 + return sqrt(total_norm) + + def get_grad_norm(parameters, norm_type=2, mpu=None): """Clips gradient norm of an iterable of parameters. @@ -231,15 +282,19 @@ def get_grad_norm(parameters, norm_type=2, mpu=None): total_norm = total_norm_cuda[0].item() else: total_norm = 0. + tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu) for p in parameters: - if mpu is not None: - if (mpu.get_model_parallel_rank() == 0 - ) or is_model_parallel_parameter(p): - param_norm = p.grad.data.float().norm(norm_type) - total_norm += param_norm.item()**norm_type - else: - param_norm = p.grad.data.float().norm(norm_type) - total_norm += param_norm.item()**norm_type + # Pipeline parallelism may replicate parameters. Avoid multi-counting.
+ if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated: + continue + + # Filter to avoid over-counting replicated tensors from tensor + # model parallelism + if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p): + continue + + param_norm = p.grad.data.float().norm(norm_type) + total_norm += param_norm.item()**norm_type # Sum across all model parallel GPUs. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) @@ -256,6 +311,48 @@ def get_grad_norm(parameters, norm_type=2, mpu=None): return total_norm +def get_grad_zeros(parameters, mpu=None): + """Compute the number of grads with zero values. + + This is adapted from get_grad_norm + + Arguments: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + + Returns: + Total number of params with zero values (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + + total_zeros = 0. + tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu) + for p in parameters: + # Pipeline parallelism may replicate parameters. Avoid multi-counting. + if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated: + continue + + # Filter to avoid over-counting replicated tensors from tensor + # model parallelism + if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p): + continue + + count_zeros = p.grad.numel() - torch.count_nonzero(p.grad) + total_zeros += count_zeros.item() + + # Sum across all model parallel GPUs. + total_zeros_cuda = torch.cuda.FloatTensor([float(total_zeros)]) + if mpu is not None: + torch.distributed.all_reduce(total_zeros_cuda, + op=torch.distributed.ReduceOp.SUM, + group=mpu.get_model_parallel_group()) + total_zeros = total_zeros_cuda[0].item() + + return total_zeros + + def get_weight_norm(parameters, norm_type=2, mpu=None): """Clips gradient norm of an iterable of parameters. 
@@ -288,24 +385,19 @@ def get_weight_norm(parameters, norm_type=2, mpu=None): total_norm = total_norm_cuda[0].item() else: total_norm = 0. + tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu) for p in parameters: - if mpu is not None: - if (mpu.get_model_parallel_rank() == 0 - ) or is_model_parallel_parameter(p): - try: - param_norm = float(torch.norm(p, norm_type, dtype=torch.float32)) - except TypeError as err: - param_norm = float(torch.norm(p.float(), norm_type)) - - #param_norm = p.data.float().norm(norm_type) - total_norm += param_norm**norm_type - else: - try: - param_norm = float(torch.norm(p, norm_type, dtype=torch.float32)) - except TypeError as err: - param_norm = float(torch.norm(p.float(), norm_type)) - #param_norm = p.data.float().norm(norm_type) - total_norm += param_norm**norm_type + # Pipeline parallelism may replicate parameters. Avoid multi-counting. + if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated: + continue + + # Filter to avoid over-counting replicated tensors from tensor + # model parallelism + if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p): + continue + + param_norm = p.data.float().norm(norm_type) + total_norm += param_norm**norm_type # Sum across all model parallel GPUs. 
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 377ad94549a7..a48dd4e620b4 100755 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -39,6 +39,7 @@ def __init__(self, param_dict): self.gather_fp16_weights_on_model_save = None self.ignore_unused_parameters = None + self.round_robin_gradients = None if ZERO_OPTIMIZATION in param_dict.keys(): zero_config_dict = param_dict[ZERO_OPTIMIZATION] @@ -184,3 +185,8 @@ def _initialize(self, zero_config_dict): self.legacy_stage1 = get_scalar_param(zero_config_dict, ZERO_OPTIMIZATION_LEGACY_STAGE1, ZERO_OPTIMIZATION_LEGACY_STAGE1_DEFAULT) + + self.round_robin_gradients = get_scalar_param( + zero_config_dict, + ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS, + ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS_DEFAULT) diff --git a/deepspeed/runtime/zero/constants.py b/deepspeed/runtime/zero/constants.py index 7beebe00e717..e0a26d53609f 100755 --- a/deepspeed/runtime/zero/constants.py +++ b/deepspeed/runtime/zero/constants.py @@ -30,7 +30,8 @@ "sub_group_size" : 1000000000000, "offload_param": {...}, "offload_optimizer": {...}, - "ignore_unused_parameters": [true|false] + "ignore_unused_parameters": [true|false], + "round_robin_gradients": [true|false] } } ''' @@ -124,6 +125,10 @@ ZERO_OPTIMIZATION_LEGACY_STAGE1 = "legacy_stage1" ZERO_OPTIMIZATION_LEGACY_STAGE1_DEFAULT = False +# Stage 2 - partition gradients in a round robin fashsion to load-balance reduction and offload copying +ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS = 'round_robin_gradients' +ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS_DEFAULT = False + #yapf: disable ZERO_OPTIMIZATION_DEFAULT = { ZERO_OPTIMIZATION_STAGE: @@ -161,5 +166,7 @@ ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS: ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS_DEFAULT, ZERO_OPTIMIZATION_LEGACY_STAGE1: - ZERO_OPTIMIZATION_LEGACY_STAGE1_DEFAULT + ZERO_OPTIMIZATION_LEGACY_STAGE1_DEFAULT, + 
ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS: + ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS_DEFAULT } diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 6ef87f9e00aa..7dc9076eed3e 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -215,11 +215,16 @@ def recurse(cl): # Inserts _post_init_method at the end of init method # for all sub classes of torch.nn.Module class InsertPostInitMethodToModuleSubClasses(object): - def __init__(self, enabled=True, mem_efficient_linear=True, config=None, dtype=None): + def __init__(self, + enabled=True, + mem_efficient_linear=True, + ds_config=None, + dtype=None): self.mem_efficient_linear = mem_efficient_linear self.enabled = enabled - self._set_dtype(config, dtype) - assert self.dtype in [torch.half, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.float]" + self._set_dtype(ds_config, dtype) + assert self.dtype in [ + torch.half, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.float]" def __enter__(self): if not self.enabled: @@ -287,8 +292,8 @@ def _disable_class(cls): torch.Tensor.__new__ = torch.Tensor.__old_new__ torch.empty = _orig_torch_empty - #un doing it here will undo it during training - #if self.mem_efficient_linear: + # un doing it here will undo it during training + # if self.mem_efficient_linear: # torch.nn.functional.linear = self.linear_bk # if self.mem_efficient_linear: # torch.nn.functional.linear = self.linear_bk @@ -303,8 +308,7 @@ def _post_init_method(self, module): def _set_dtype(self, ds_config, dtype): if ds_config is not None and dtype is None: - _ds_config = DeepSpeedConfig(ds_config) - self.dtype = torch.half if _ds_config.fp16_enabled else torch.float + self.dtype = torch.half if ds_config.fp16_enabled else torch.float elif dtype is None: self.dtype = torch.half else: @@ -321,9 +325,11 @@ def __init__(self, 
mem_efficient_linear=True, remote_device=None, pin_memory=False, + config_dict_or_path=None, config=None, enabled=True, - dtype=None): + dtype=None, + mpu=None): """A context to enable massive model construction for training with ZeRO-3. Models are automatically partitioned (or, sharded) across the system and converted to half precision. @@ -343,12 +349,14 @@ def __init__(self, pin_memory (bool, optional): Potentially increase performance by using pinned memory for model weights. ``remote_device`` must be ``"cpu"``. Defaults to ``False``. - config (``json file`` or dict, optional): If provided, provides configuration + config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration for swapping fp16 params to NVMe. + config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead. enabled (bool, optional): If ``False``, this context has no effect. Defaults to ``True``. dtype (``dtype``, optional): Can be used to change the data type of the parameters. Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None`` + mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,wolrd_size} This context accelerates model initialization and enables models that are too large to allocate in their entirety in CPU memory. 
It has the @@ -420,9 +428,11 @@ def get_model(): model = deepspeed.zero.Init(module=model) """ + _ds_config = DeepSpeedConfig(config_dict_or_path, + mpu) if config_dict_or_path is not None else None super().__init__(enabled=enabled, mem_efficient_linear=mem_efficient_linear, - config=config, + ds_config=_ds_config, dtype=dtype) if not torch.distributed.is_initialized(): init_distributed() @@ -435,21 +445,20 @@ def get_model(): self.rank = torch.distributed.get_rank(group=self.ds_process_group) self.world_size = torch.distributed.get_world_size(group=self.ds_process_group) - #Local device is the device where the parameters are consumed - #It is the device where parameters are fully instantiated using allgather + # Local device is the device where the parameters are consumed + # It is the device where parameters are fully instantiated using allgather self.local_device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])) - self._validate_remote_device(remote_device, config) + self._validate_remote_device(remote_device, _ds_config) - #Remote device is the device where parameter partiitons are stored - #It can be same as local_device or it could be CPU or NVMe. + # Remote device is the device where parameter partiitons are stored + # It can be same as local_device or it could be CPU or NVMe. 
self.remote_device = self.local_device if remote_device is None else remote_device self.pin_memory = pin_memory if ( self.remote_device == OFFLOAD_CPU_DEVICE) else False # Enable fp16 param swapping to NVMe if self.remote_device == OFFLOAD_NVME_DEVICE: - _ds_config = DeepSpeedConfig(config) self.param_swapper = AsyncPartitionedParameterSwapper(_ds_config) else: self.param_swapper = None @@ -463,22 +472,21 @@ def get_model(): self._convert_to_deepspeed_param(param) param.partition() - def _validate_remote_device(self, remote_device, ds_config): - if ds_config is not None: - _ds_config = DeepSpeedConfig(ds_config) + def _validate_remote_device(self, remote_device, _ds_config): + if _ds_config is not None: if remote_device in [None, OFFLOAD_CPU_DEVICE]: if _ds_config.zero_config.offload_param is not None: offload_param_device = _ds_config.zero_config.offload_param[ OFFLOAD_PARAM_DEVICE] assert offload_param_device != OFFLOAD_NVME_DEVICE, \ - f"{OFFLOAD_PARAM_DEVICE} in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}." + f"{OFFLOAD_PARAM_DEVICE} in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}." if remote_device == OFFLOAD_NVME_DEVICE: assert _ds_config.zero_config.offload_param is not None, \ - f'{OFFLOAD_PARAM} must be defined in DeepSpeed Config if remote device is {OFFLOAD_NVME_DEVICE}.' + f'{OFFLOAD_PARAM} must be defined in DeepSpeed Config if remote device is {OFFLOAD_NVME_DEVICE}.' 
assert _ds_config.zero_config.offload_param[OFFLOAD_PARAM_NVME_PATH] is not None, \ - f'{OFFLOAD_PARAM_NVME_PATH} in DeepSpeed Config cannot be None if remote device is {OFFLOAD_NVME_DEVICE}' + f'{OFFLOAD_PARAM_NVME_PATH} in DeepSpeed Config cannot be None if remote device is {OFFLOAD_NVME_DEVICE}' def _post_init_method(self, module): #see_memory_usage(f"Before converting parmas in {module.__class__.__name__}", force=False) @@ -624,7 +632,7 @@ def _ensure_availability_of_partitioned_params(self, params): def _all_gather(self, param_list, async_op=False, hierarchy=None): - #fetches from nvme if the partition is not available and in nvme + # fetches from nvme if the partition is not available and in nvme self._ensure_availability_of_partitioned_params(param_list) handles = [] @@ -651,10 +659,10 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None): def _partition(self, param_list, force=False, has_been_updated=False): for param in param_list: #print_rank_0(f"Before Partitioning Param {param.ds_id}") - #self._param_status(param) + # self._param_status(param) self._partition_param(param, has_been_updated=has_been_updated) param.ds_status = ZeroParamStatus.NOT_AVAILABLE - #if param.ds_tensor is not None: + # if param.ds_tensor is not None: # assert id(param.data) == id(param.ds_tensor.data), \ # "After the parameters are initially partitioned, make sure we are not recreating the partition." 
#print_rank_0(f"After Partitioning Param {param.ds_id}") @@ -678,7 +686,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): # if numel in empty_buffers: # empty_buffers[numel].append(buffer) - #if torch.distributed.get_rank(): + # if torch.distributed.get_rank(): # print(f"Releasing {param.data.numel()}") if param.ds_tensor is not None and not has_been_updated: @@ -687,7 +695,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): see_memory_usage( f'Before partitioning param {param.ds_id} {param.shape}', force=False) - #param.data does not store anything meaningful in partitioned state + # param.data does not store anything meaningful in partitioned state param.data = torch.ones(1, dtype=self.dtype).to(param.device) see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False) @@ -765,7 +773,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): #param.data = param.ds_tensor.data - #param.data does not store anything meaningful in partitioned state + # param.data does not store anything meaningful in partitioned state see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False) @@ -1002,7 +1010,8 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False): dtype=param.dtype, device=param.device) else: - assert partition_buffer.numel() >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}" + assert partition_buffer.numel( + ) >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}" rank = torch.distributed.get_rank(group=self.ds_process_group) start = partition_size * rank diff --git a/deepspeed/runtime/zero/stage1.py b/deepspeed/runtime/zero/stage1.py index 7660c9917b84..20a6c5a21944 100755 --- a/deepspeed/runtime/zero/stage1.py +++ b/deepspeed/runtime/zero/stage1.py @@ -5,7 
+5,7 @@ from deepspeed.runtime.zero.utils import _initialize_parameter_parallel_groups from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler -from deepspeed.runtime.utils import get_grad_norm, CheckOverflow +from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_OPTIMIZER_STATES from deepspeed.utils import logger, log_dist from deepspeed.ops.op_builder import UtilsBuilder @@ -104,6 +104,7 @@ def __init__(self, self.postscale_gradients = postscale_gradients self.gradient_predivide_factor = gradient_predivide_factor self.gradient_average = gradient_average + self._global_grad_norm = 0. # TODO: automatically turn off if #params > some_limit self.all_gather_partitions = all_gather_partitions @@ -683,8 +684,11 @@ def step(self, closure=None): local_sub_partitions_grad_groups.append(local_grad_sub_partitions) + self._global_grad_norm = get_global_norm(norm_list=norm_groups) + #RS: update unscale/clip with sub partitions - self.unscale_and_clip_grads(local_sub_partitions_grad_groups, norm_groups) + self.unscale_and_clip_grads(local_sub_partitions_grad_groups, + self._global_grad_norm) self.optimizer.step() @@ -720,12 +724,7 @@ def step(self, closure=None): return self.overflow - def unscale_and_clip_grads(self, grad_groups_flat, norm_groups): - total_norm = 0.0 - for norm in norm_groups: - total_norm += norm**2.0 - total_norm = math.sqrt(total_norm) - + def unscale_and_clip_grads(self, grad_groups_flat, total_norm): # compute combined scale factor for this group combined_scale = self.loss_scale if self.clip_grad > 0.: diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py index ad529aa96391..09cdfe80b59b 100755 --- a/deepspeed/runtime/zero/stage2.py +++ b/deepspeed/runtime/zero/stage2.py @@ -13,7 +13,7 @@ import collections from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler -from deepspeed.runtime.utils import 
see_memory_usage, is_model_parallel_parameter +from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank, get_global_norm, see_memory_usage, is_model_parallel_parameter from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.ops.op_builder import UtilsBuilder @@ -99,12 +99,14 @@ def __init__(self, gradient_predivide_factor=1.0, gradient_accumulation_steps=1, ignore_unused_parameters=True, - partition_grads=True): + partition_grads=True, + round_robin_gradients=False): if dist.get_rank() == 0: logger.info(f"Reduce bucket size {reduce_bucket_size}") logger.info(f"Allgather bucket size {allgather_bucket_size}") logger.info(f"CPU Offload: {cpu_offload}") + logger.info(f'Round robin gradient partitioning: {round_robin_gradients}') # The fused optimizer does all the work. We need this layer for two reason: # 1. maintain same user API from apex.fp16_utils # 2. keep common stuff here in case we need to add ne552w fused optimizer later @@ -144,12 +146,14 @@ def __init__(self, self.is_gradient_accumulation_boundary = True + self._global_grad_norm = 0. + if mpu is None: self.model_parallel_group = None self.model_parallel_rank = 0 else: self.model_parallel_group = mpu.get_model_parallel_group() - self.model_parallel_rank = mpu.get_model_parallel_rank() + self.model_parallel_rank = bwc_tensor_model_parallel_rank(mpu) self.overflow = False self.clip_grad = clip_grad @@ -159,6 +163,7 @@ def __init__(self, self.gradient_accumulation_steps = gradient_accumulation_steps self.micro_step_id = 0 self.ignore_unused_parameters = ignore_unused_parameters + self.round_robin_gradients = round_robin_gradients self.extra_large_param_to_reduce = None @@ -232,10 +237,15 @@ def __init__(self, # This ensures that gradients are reduced in a fashion such that ownership round robins among the ranks. 
# For example, rather than 3 gradients (g_n+2, g_n+1, g_n) that are reduced consecutively belonging # to the same rank, instead they will belong to 3 ranks (r_m+2, r_m+1, r_m). - round_robin_tensors, round_robin_indices = self._round_robin_reorder( - self.fp16_groups[i], - dist.get_world_size(group=self.dp_process_group) - ) + if self.round_robin_gradients: + round_robin_tensors, round_robin_indices = self._round_robin_reorder( + self.fp16_groups[i], + dist.get_world_size(group=self.dp_process_group) + ) + else: + round_robin_tensors = self.fp16_groups[i] + round_robin_indices = list(range(len(self.fp16_groups[i]))) + self.round_robin_fp16_groups.append(round_robin_tensors) self.round_robin_fp6_indices.append(round_robin_indices) @@ -462,7 +472,7 @@ def initialize_optimizer_states(self): if not self.cpu_offload: for group in self.single_partition_of_fp32_groups: - group.grad = None + group.grad = None #class init return @@ -486,7 +496,8 @@ def reduce_gradients(self, pipeline_parallel=False): if not self.overlap_comm: for i, group in enumerate(self.fp16_groups): for param in group: - self.reduce_ready_partitions_and_remove_grads(param, i) + if param.grad is not None: + self.reduce_ready_partitions_and_remove_grads(param, i) # reduce any pending grads in either hook/non-hook case self.overlapping_partition_gradients_reduce_epilogue() @@ -964,12 +975,16 @@ def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param): src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements).float() dest_tensor.copy_(src_tensor, non_blocking=True) - param.grad = None + param.grad = None #offload only def complete_grad_norm_calculation_for_cpu_offload(self, params): total_norm = 0.0 norm_type = 2.0 for p in params: + # Pipeline parallelism may replicate parameters. Avoid multi-counting. 
+ if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated: + continue + if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): param_id = self.get_param_id(p) # as some model have trainable parameters but skipped in training, @@ -1082,17 +1097,18 @@ def reduce_ipg_grads(self): self.params_already_reduced[param_id] = True - if not self.is_param_in_current_partition[param_id]: - if self.overlap_comm and self.contiguous_gradients is False: - # Clear grads of other partitions during the next reduction - # to avoid clearing them before the reduction is complete. - if self.previous_reduced_grads is None: - self.previous_reduced_grads = [] - self.previous_reduced_grads.append(param) - else: - param.grad = None - elif self.contiguous_gradients: - self.copy_grads_in_partition(param) + if self.partition_gradients: + if not self.is_param_in_current_partition[param_id]: + if self.overlap_comm and self.contiguous_gradients is False: + # Clear grads of other partitions during the next reduction + # to avoid clearing them before the reduction is complete. 
+ if self.previous_reduced_grads is None: + self.previous_reduced_grads = [] + self.previous_reduced_grads.append(param) + else: + param.grad = None #only if self.partition_gradients + elif self.contiguous_gradients: + self.copy_grads_in_partition(param) self.grads_in_ipg_bucket = [] self.params_in_ipg_bucket = [] @@ -1112,7 +1128,7 @@ def are_all_related_partitions_reduced(params_id): for params_id in self.is_grad_computed[i][partition_id]: if are_all_related_partitions_reduced(params_id): - self.param_dict[params_id].grad = None + self.param_dict[params_id].grad = None # dead code def flatten_and_print(self, message, tensors, start=0, n=5): flatten_tensor = self.flatten(tensors) @@ -1201,7 +1217,7 @@ def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=N def _clear_previous_reduced_grads(self): if self.previous_reduced_grads is not None: for param in self.previous_reduced_grads: - param.grad = None + param.grad = None # overlap enabled self.previous_reduced_grads = None #if rank is specified do a reduction instead of an allreduce @@ -1316,7 +1332,7 @@ def zero_grad(self, set_grads_to_None=True): for group in self.fp16_groups: for p in group: if set_grads_to_None: - p.grad = None + p.grad = None # epilogue and in step else: if p.grad is not None: p.grad.detach_() @@ -1366,6 +1382,9 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): #if dist.get_rank() == 0: # logger.info(f"Total Norm begining {total_norm}") for g, p in zip(gradients, params): + # Pipeline parallelism may replicate parameters. Avoid multi-counting. 
+ if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated: + continue if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): param_norm = g.data.double().norm(2) total_norm += param_norm.item()**2 @@ -1442,7 +1461,7 @@ def get_flat_partition(self, def free_grad_in_param_list(self, param_list): for p in param_list: - p.grad = None + p.grad = None # in step def reset_cpu_buffers(self): self.norm_for_param_grads = {} @@ -1499,11 +1518,6 @@ def step(self, closure=None): see_memory_usage('After overflow after clearing gradients') - logger.info( - "[deepspeed] fp16 dynamic loss scale overflow! Rank {} Skipping step. Attempted loss scale: {}, " - "reducing to {}".format(dist.get_rank(), - prev_scale, - self.loss_scale)) self.start_timers(timer_names) self.stop_timers(timer_names) return @@ -1548,7 +1562,8 @@ def step(self, closure=None): single_partition_grad_groups.append(single_grad_partition) - self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups) + self._global_grad_norm = get_global_norm(norm_list=norm_groups) + self.unscale_and_clip_grads(single_partition_grad_groups, self._global_grad_norm) self.stop_timers([OPTIMIZER_GRADIENTS]) self.start_timers([OPTIMIZER_STEP]) @@ -1570,7 +1585,7 @@ def step(self, closure=None): #get rid of the fp32 gradients. 
Not needed anymore if not self.cpu_offload: for group in self.single_partition_of_fp32_groups: - group.grad = None + group.grad = None # in step for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups): fp16_partitions[partition_id].data.copy_(fp32_partition.data) @@ -1624,12 +1639,7 @@ def step(self, closure=None): return - def unscale_and_clip_grads(self, grad_groups_flat, norm_groups): - total_norm = 0.0 - for norm in norm_groups: - total_norm += norm**2.0 - total_norm = math.sqrt(total_norm) - + def unscale_and_clip_grads(self, grad_groups_flat, total_norm): # compute combined scale factor for this group combined_scale = self.loss_scale if self.clip_grad > 0.: diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 2806f5d420da..9b06d177fefd 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -16,7 +16,7 @@ from deepspeed.utils.logging import logger from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler -from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter +from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.partition_parameters import _init_external_params from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS @@ -650,6 +650,7 @@ def __init__(self, self.flatten = util_ops.flatten self.unflatten = util_ops.unflatten self.dtype = self.optimizer.param_groups[0]['params'][0].dtype + self._global_grad_norm = 0. 
if not all(is_zero_param(p) for p in module.parameters()): group = None @@ -2764,6 +2765,7 @@ def step(self, closure=None): return norm_groups = self._get_norm_groups() + self._global_grad_norm = get_global_norm(norm_list=norm_groups) timer_names = set() @@ -2777,7 +2779,7 @@ def step(self, closure=None): self._prepare_sub_group(sub_group_id, timer_names) #scale the fp32 gradients - self.unscale_and_clip_grads(sub_group_id, norm_groups) + self.unscale_and_clip_grads(sub_group_id, self._global_grad_norm) #apply the optimizer step on the sub group and copy fp32 parameters to fp16 self._optimizer_step(sub_group_id) @@ -2826,15 +2828,9 @@ def dump_post_step_gradients(self): norm_list = [param_norm, ds_norm] + unflat_norm print(f'Post-Step Norms {i} {param_id} = {norm_list}') - def unscale_and_clip_grads(self, sub_group_id, norm_groups): - + def unscale_and_clip_grads(self, sub_group_id, total_norm): grad_groups_flat = [self.fp32_partitioned_groups_flat[sub_group_id].grad] - total_norm = 0.0 - for norm in norm_groups: - total_norm += norm**2.0 - total_norm = math.sqrt(total_norm) - # compute combined scale factor for this group combined_scale = self.loss_scale if self.clip_grad > 0.: diff --git a/op_builder/builder.py b/op_builder/builder.py index 21547f896473..70f0fd3d3e55 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -63,7 +63,8 @@ def get_default_compute_capatabilities(): 11: ["11.0", "11.1", "11.2", - "11.3"], + "11.3", + "11.4"], } diff --git a/tests/unit/simple_model.py b/tests/unit/simple_model.py index 9c6062d79faa..15c40976b6a1 100755 --- a/tests/unit/simple_model.py +++ b/tests/unit/simple_model.py @@ -17,10 +17,7 @@ def __init__(self, hidden_dim, empty_grad=False): def forward(self, x, y): hidden_dim = x - if self.empty_grad and torch.distributed.get_rank() == 0: - hidden_dim = self.linear(hidden_dim) + self.linear2(hidden_dim) - else: - hidden_dim = self.linear(hidden_dim) + hidden_dim = self.linear(hidden_dim) return 
self.cross_entropy_loss(hidden_dim, y) diff --git a/tests/unit/test_fp16.py b/tests/unit/test_fp16.py index b2e76f0b7b82..0c0ef3edd3a8 100755 --- a/tests/unit/test_fp16.py +++ b/tests/unit/test_fp16.py @@ -856,3 +856,38 @@ def _go(args): model.step() _go(args=args) + + +@pytest.mark.parametrize('stage', [1, 2, 3]) +def test_zero_empty_grad(tmpdir, stage): + config_dict = { + "train_batch_size": 1, + "steps_per_print": 1, + "fp16": { + "enabled": True + }, + "zero_optimization": { + "stage": stage + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[1]) + def _go(args, model, hidden_dim): + optimizer = torch.optim.Adam(model.parameters()) + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + optimizer=optimizer) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + _go(args=args, model=model, hidden_dim=hidden_dim)