diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 667826d363d2..32301c08a5c7 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -388,6 +388,9 @@ def zero_prefetch_bucket_size(self):
     def zero_param_persistence_threshold(self):
         return self._config.zero_config.param_persistence_threshold
 
+    def zero_gather_fp16_weights_on_model_save(self):
+        return self._config.zero_config.gather_fp16_weights_on_model_save
+
     def fp16_enabled(self):
         return self._config.fp16_enabled
 
@@ -1714,3 +1717,98 @@ def _save_zero_checkpoint(self, save_path, tag):
         torch.save(zero_sd, zero_checkpoint_name)
         self._copy_recovery_script(save_path)
         logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name))
+
+    def _zero3_consolidated_fp16_state_dict(self):
+        """
+
+        Get a full non-partitioned state_dict with fp16 weights on cpu.
+
+        This is similar to nn.Module.state_dict (modelled after _save_to_state_dict), but:
+
+        1. consolidates the weights from the different partitions on gpu0
+        2. works on one layer at a time to require as little gpu0 memory as possible, by
+           moving the already consolidated weights to cpu
+        3. takes care to keep shared params shared when gradually copying the params to cpu
+
+        Returns:
+            a consolidated fp16 ``state_dict`` on cpu on rank 0, ``None`` on other ranks
+
+        """
+        import deepspeed
+
+        if not self.zero_optimization_partition_weights():
+            raise ValueError("this function requires ZeRO-3 mode")
+
+        state_dict = OrderedDict() if torch.distributed.get_rank() == 0 else None
+        shared_weights = {}
+
+        def get_layer_state_dict(module, prefix=""):
+            # gather one layer at a time to be memory-efficient
+            with deepspeed.zero.GatheredParameters(list(
+                    module.parameters(recurse=False))):
+                if torch.distributed.get_rank() == 0:
+                    for name, param in module.named_parameters(recurse=False):
+                        if param is None:
+                            continue
+                        key = prefix + name
+                        # for shared weights we want to make sure not to unshare them when copying to cpu
+                        data_ptr_id = param.storage().data_ptr()
+                        if data_ptr_id in shared_weights:
+                            # shared weights
+                            # print(f"`{key}` is shared with `{shared_weights[data_ptr_id]}`")
+                            state_dict[key] = state_dict[shared_weights[data_ptr_id]]
+                        else:
+                            state_dict[key] = param.detach().cpu()
+                            shared_weights[data_ptr_id] = key
+                        #print(f"param {name} {param.shape}")
+                        #print(f"param {key} {param.shape} {state_dict[key].storage().data_ptr()}")
+
+                    # now buffers - not sure if we need to take care of potentially shared weights here
+                    for name, buf in module.named_buffers(recurse=False):
+                        if buf is not None and name not in module._non_persistent_buffers_set:
+                            state_dict[prefix + name] = buf.detach().cpu()
+
+            for name, child in module.named_children():
+                if child is not None:
+                    get_layer_state_dict(child, prefix + name + ".")
+
+        see_memory_usage("before get_layer_state_dict", force=False)
+        get_layer_state_dict(self.module, prefix="")
+        see_memory_usage("after get_layer_state_dict", force=False)
+
+        return state_dict
+
+    def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"):
+        r"""Save fp16 model weights
+
+        This method saves the fp16 model weights at the desired destination.
+
+        Arguments:
+            save_dir: Required. Directory for saving the model
+            save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin``
+
+        Important: all processes must call this method, not just the process with rank 0. This is
+        because the processes need to work in sync to gather the weights. The method will hang,
+        waiting to synchronize with the other processes, if it is called only by the rank 0 process.
+
+        """
+
+        path = os.path.join(save_dir, save_filename)
+
+        if self.zero_optimization_partition_weights():
+            if self.zero_gather_fp16_weights_on_model_save():
+                # consolidation is expensive in both time and memory, so it isn't enabled by default
+                state_dict = self._zero3_consolidated_fp16_state_dict()
+            else:
+                # the model will be bogus if not consolidated, so don't confuse the user by saving it
+                logger.info(
+                    f"Did not save the model {path} because `stage3_gather_fp16_weights_on_model_save` is False"
+                )
+                return
+        else:
+            state_dict = self.module.state_dict()
+
+        if torch.distributed.get_rank() == 0:
+            os.makedirs(save_dir, exist_ok=True)
+            logger.info(f"Saving model weights to {path}")
+            torch.save(state_dict, path)
diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py
index ac61a9dd52b3..622ffa9ba1cb 100755
--- a/deepspeed/runtime/zero/config.py
+++ b/deepspeed/runtime/zero/config.py
@@ -34,6 +34,7 @@ def __init__(self, param_dict):
         self.param_persistence_threshold = None
         self.max_live_parameters = None
         self.max_reuse_distance = None
+        self.gather_fp16_weights_on_model_save = None
 
         #Stage3 Specific Parameters
         self.prefetch_bucket_size = None
@@ -150,3 +151,8 @@ def _initialize(self, zero_config_dict):
             zero_config_dict,
             ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD,
             ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT)
+
+        self.gather_fp16_weights_on_model_save = get_scalar_param(
+            zero_config_dict,
+            ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE,
+            ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT)
diff --git a/deepspeed/runtime/zero/constants.py b/deepspeed/runtime/zero/constants.py
index 8d4cf2c5d293..e5812980a337 100755
--- a/deepspeed/runtime/zero/constants.py
+++ b/deepspeed/runtime/zero/constants.py
@@ -99,6 +99,10 @@
 ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD = 'stage3_param_persistence_threshold'
 ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT = 100000
 
+# gathers params for saving a model - inefficient, but required in certain situations
+ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE = 'stage3_gather_fp16_weights_on_model_save'
+ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT = False
+
 
 ZERO_OPTIMIZATION_DEFAULT = {
     ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
@@ -133,5 +137,7 @@
     ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE:
     ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT,
     ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD:
-    ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT
+    ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT,
+    ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE:
+    ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT
 }
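
Usage sketch (not part of the patch): a minimal end-to-end example of the new config key and entry point. The key `stage3_gather_fp16_weights_on_model_save` and the `engine.save_fp16_model()` call come from this diff; the toy model, batch size, optimizer settings, and save directory are illustrative placeholders, and `deepspeed.initialize` is shown with the `config_params` keyword accepted by this era of the API.

import torch
import deepspeed

# ZeRO-3 config enabling the flag added by this patch; other values are placeholders.
ds_config = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {
        "stage": 3,
        "stage3_gather_fp16_weights_on_model_save": True,
    },
}

model = torch.nn.Linear(10, 10)  # stand-in for a real network

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config_params=ds_config)

# ... training loop elided ...

# Every rank must make this call: the ranks synchronize to gather the
# partitioned fp16 weights, and only rank 0 writes
# <save_dir>/pytorch_model.bin.
engine.save_fp16_model("./final-model")

This would be run under a distributed launcher (e.g. the `deepspeed` CLI) so that all ranks execute the save together; calling it on rank 0 alone hangs, as the docstring warns.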
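The shared-weights bookkeeping in `_zero3_consolidated_fp16_state_dict` may be easier to follow in isolation. Below is a standalone sketch of the same technique in plain PyTorch (no DeepSpeed; the tied embedding/head modules are hypothetical): parameters sharing storage are detected via `storage().data_ptr()`, so their cpu copies stay tied instead of being silently duplicated.

import torch
from collections import OrderedDict

# Two modules with tied weights, as in models with tied input/output embeddings.
emb = torch.nn.Embedding(10, 4)
head = torch.nn.Linear(4, 10, bias=False)
head.weight = emb.weight  # tie the weights: one storage, two names

state_dict = OrderedDict()
shared = {}  # storage data_ptr -> first key that owns the cpu copy

for key, param in [("emb.weight", emb.weight), ("head.weight", head.weight)]:
    ptr = param.storage().data_ptr()
    if ptr in shared:
        # reuse the existing cpu tensor rather than copying again
        state_dict[key] = state_dict[shared[ptr]]
    else:
        state_dict[key] = param.detach().cpu()
        shared[ptr] = key

# both keys reference one tensor, mirroring the tying in the live model
assert state_dict["emb.weight"] is state_dict["head.weight"]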