From 31e0b909f9a4186c395f8913fdab3466b1ff1317 Mon Sep 17 00:00:00 2001
From: Tunji Ruwase
Date: Sat, 31 Jul 2021 15:22:06 +0000
Subject: [PATCH 1/2] Use mpu in DeepSpeedConfig() call

---
 .../runtime/zero/partition_parameters.py     | 62 ++++++++++---------
 1 file changed, 34 insertions(+), 28 deletions(-)

diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 6ef87f9e00aa..5856d0277e8d 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -215,11 +215,16 @@ def recurse(cl):
 # Inserts _post_init_method at the end of init method
 # for all sub classes of torch.nn.Module
 class InsertPostInitMethodToModuleSubClasses(object):
-    def __init__(self, enabled=True, mem_efficient_linear=True, config=None, dtype=None):
+    def __init__(self,
+                 enabled=True,
+                 mem_efficient_linear=True,
+                 ds_config=None,
+                 dtype=None):
         self.mem_efficient_linear = mem_efficient_linear
         self.enabled = enabled
-        self._set_dtype(config, dtype)
-        assert self.dtype in [torch.half, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.float]"
+        self._set_dtype(ds_config, dtype)
+        assert self.dtype in [
+            torch.half, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.float]"
 
     def __enter__(self):
         if not self.enabled:
@@ -287,8 +292,8 @@ def _disable_class(cls):
         torch.Tensor.__new__ = torch.Tensor.__old_new__
         torch.empty = _orig_torch_empty
 
-        #un doing it here will undo it during training
-        #if self.mem_efficient_linear:
+        # undoing it here will undo it during training
+        # if self.mem_efficient_linear:
         #    torch.nn.functional.linear = self.linear_bk
         # if self.mem_efficient_linear:
         #    torch.nn.functional.linear = self.linear_bk
@@ -303,8 +308,7 @@ def _post_init_method(self, module):
 
     def _set_dtype(self, ds_config, dtype):
         if ds_config is not None and dtype is None:
-            _ds_config = DeepSpeedConfig(ds_config)
-            self.dtype = torch.half if _ds_config.fp16_enabled else torch.float
+            self.dtype = torch.half if ds_config.fp16_enabled else torch.float
         elif dtype is None:
             self.dtype = torch.half
         else:
@@ -323,7 +327,8 @@ def __init__(self,
                  pin_memory=False,
                  config=None,
                  enabled=True,
-                 dtype=None):
+                 dtype=None,
+                 mpu=None):
         """A context to enable massive model construction for training with
         ZeRO-3. Models are automatically partitioned (or, sharded) across the
         system and converted to half precision.
@@ -349,6 +354,7 @@ def __init__(self,
                 effect. Defaults to ``True``.
             dtype (``dtype``, optional): Can be used to change the data type of the parameters.
                 Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None``
+            mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}
 
         This context accelerates model initialization and enables models that
         are too large to allocate in their entirety in CPU memory. It has the
@@ -420,9 +426,10 @@ def get_model():
                 model = deepspeed.zero.Init(module=model)

         """
+        _ds_config = DeepSpeedConfig(config, mpu) if config is not None else None
         super().__init__(enabled=enabled,
                          mem_efficient_linear=mem_efficient_linear,
-                         config=config,
+                         ds_config=_ds_config,
                          dtype=dtype)
         if not torch.distributed.is_initialized():
             init_distributed()
@@ -435,21 +442,20 @@ def get_model():
         self.rank = torch.distributed.get_rank(group=self.ds_process_group)
         self.world_size = torch.distributed.get_world_size(group=self.ds_process_group)
 
-        #Local device is the device where the parameters are consumed
-        #It is the device where parameters are fully instantiated using allgather
+        # Local device is the device where the parameters are consumed
+        # It is the device where parameters are fully instantiated using allgather
         self.local_device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))
 
-        self._validate_remote_device(remote_device, config)
+        self._validate_remote_device(remote_device, _ds_config)
 
-        #Remote device is the device where parameter partiitons are stored
-        #It can be same as local_device or it could be CPU or NVMe.
+        # Remote device is the device where parameter partitions are stored
+        # It can be the same as local_device, or it could be CPU or NVMe.
         self.remote_device = self.local_device if remote_device is None else remote_device
         self.pin_memory = pin_memory if (
             self.remote_device == OFFLOAD_CPU_DEVICE) else False
 
         # Enable fp16 param swapping to NVMe
         if self.remote_device == OFFLOAD_NVME_DEVICE:
-            _ds_config = DeepSpeedConfig(config)
             self.param_swapper = AsyncPartitionedParameterSwapper(_ds_config)
         else:
             self.param_swapper = None
@@ -463,22 +469,21 @@ def get_model():
                     self._convert_to_deepspeed_param(param)
                     param.partition()
 
-    def _validate_remote_device(self, remote_device, ds_config):
-        if ds_config is not None:
-            _ds_config = DeepSpeedConfig(ds_config)
+    def _validate_remote_device(self, remote_device, _ds_config):
+        if _ds_config is not None:
             if remote_device in [None, OFFLOAD_CPU_DEVICE]:
                 if _ds_config.zero_config.offload_param is not None:
                     offload_param_device = _ds_config.zero_config.offload_param[
                         OFFLOAD_PARAM_DEVICE]
                     assert offload_param_device != OFFLOAD_NVME_DEVICE, \
-                    f"{OFFLOAD_PARAM_DEVICE} in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}."
+                        f"{OFFLOAD_PARAM_DEVICE} in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}."
 
             if remote_device == OFFLOAD_NVME_DEVICE:
                 assert _ds_config.zero_config.offload_param is not None, \
-                f'{OFFLOAD_PARAM} must be defined in DeepSpeed Config if remote device is {OFFLOAD_NVME_DEVICE}.'
+                    f'{OFFLOAD_PARAM} must be defined in DeepSpeed Config if remote device is {OFFLOAD_NVME_DEVICE}.'
                 assert _ds_config.zero_config.offload_param[OFFLOAD_PARAM_NVME_PATH] is not None, \
-                f'{OFFLOAD_PARAM_NVME_PATH} in DeepSpeed Config cannot be None if remote device is {OFFLOAD_NVME_DEVICE}'
+                    f'{OFFLOAD_PARAM_NVME_PATH} in DeepSpeed Config cannot be None if remote device is {OFFLOAD_NVME_DEVICE}'
 
     def _post_init_method(self, module):
         #see_memory_usage(f"Before converting parmas in {module.__class__.__name__}", force=False)
@@ -624,7 +629,7 @@ def _ensure_availability_of_partitioned_params(self, params):
 
     def _all_gather(self, param_list, async_op=False, hierarchy=None):
 
-        #fetches from nvme if the partition is not available and in nvme
+        # fetches from nvme if the partition is not available and in nvme
         self._ensure_availability_of_partitioned_params(param_list)
 
         handles = []
@@ -651,10 +656,10 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None):
     def _partition(self, param_list, force=False, has_been_updated=False):
         for param in param_list:
             #print_rank_0(f"Before Partitioning Param {param.ds_id}")
-            #self._param_status(param)
+            # self._param_status(param)
             self._partition_param(param, has_been_updated=has_been_updated)
             param.ds_status = ZeroParamStatus.NOT_AVAILABLE
-            #if param.ds_tensor is not None:
+            # if param.ds_tensor is not None:
             #    assert id(param.data) == id(param.ds_tensor.data), \
             #    "After the parameters are initially partitioned, make sure we are not recreating the partition."
             #print_rank_0(f"After Partitioning Param {param.ds_id}")
@@ -678,7 +683,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
             #    if numel in empty_buffers:
             #        empty_buffers[numel].append(buffer)
 
-            #if torch.distributed.get_rank():
+            # if torch.distributed.get_rank():
             #    print(f"Releasing {param.data.numel()}")
 
             if param.ds_tensor is not None and not has_been_updated:
@@ -687,7 +692,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
                 see_memory_usage(
                     f'Before partitioning param {param.ds_id} {param.shape}',
                     force=False)
-                #param.data does not store anything meaningful in partitioned state
+                # param.data does not store anything meaningful in partitioned state
                 param.data = torch.ones(1, dtype=self.dtype).to(param.device)
                 see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}',
                                  force=False)
@@ -765,7 +770,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
 
             #param.data = param.ds_tensor.data
 
-            #param.data does not store anything meaningful in partitioned state
+            # param.data does not store anything meaningful in partitioned state
             see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}',
                              force=False)
@@ -1002,7 +1007,8 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
                                            dtype=param.dtype,
                                            device=param.device)
         else:
-            assert partition_buffer.numel() >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}"
+            assert partition_buffer.numel(
+            ) >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}"
 
         rank = torch.distributed.get_rank(group=self.ds_process_group)
         start = partition_size * rank
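[Reviewer note, not part of the patch] PATCH 1/2 threads a new ``mpu`` argument
into the internal DeepSpeedConfig() call, and the new docstring says the object
must implement get_{model,data}_parallel_{rank,group,world_size}. Below is a
minimal sketch of such an object; the name SimpleMPU and its trivial group
layout (model-parallel size 1, data parallelism over all ranks) are
illustrative assumptions, not DeepSpeed API.

# Hypothetical helper for illustration; assumes torch.distributed is initialized.
import torch.distributed as dist


class SimpleMPU:
    """Trivial layout: model-parallel size 1, data parallelism over all ranks."""

    def __init__(self):
        rank = dist.get_rank()
        # dist.new_group() must be called by every rank for every group, so
        # create one single-rank model-parallel group per rank and keep ours.
        for r in range(dist.get_world_size()):
            group = dist.new_group(ranks=[r])
            if r == rank:
                self._mp_group = group
        self._dp_group = dist.group.WORLD

    def get_model_parallel_rank(self):
        return dist.get_rank(group=self._mp_group)

    def get_model_parallel_group(self):
        return self._mp_group

    def get_model_parallel_world_size(self):
        return dist.get_world_size(group=self._mp_group)

    def get_data_parallel_rank(self):
        return dist.get_rank(group=self._dp_group)

    def get_data_parallel_group(self):
        return self._dp_group

    def get_data_parallel_world_size(self):
        return dist.get_world_size(group=self._dp_group)

Passing mpu lets DeepSpeedConfig resolve batch-size fields against the
data-parallel world size rather than the global world size, which is presumably
why the context now forwards it.
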
From b1bd1c3c8008ba2720d8e32cb6fad0e913255ab2 Mon Sep 17 00:00:00 2001
From: Tunji Ruwase
Date: Mon, 9 Aug 2021 13:40:15 +0000
Subject: [PATCH 2/2] Improve argument naming

---
 deepspeed/runtime/zero/partition_parameters.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 5856d0277e8d..7dc9076eed3e 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -325,6 +325,7 @@ def __init__(self,
                  mem_efficient_linear=True,
                  remote_device=None,
                  pin_memory=False,
+                 config_dict_or_path=None,
                  config=None,
                  enabled=True,
                  dtype=None,
@@ -348,8 +349,9 @@ def __init__(self,
             pin_memory (bool, optional): Potentially increase performance by
                 using pinned memory for model weights. ``remote_device`` must be
                 ``"cpu"``. Defaults to ``False``.
-            config (``json file`` or dict, optional): If provided, provides configuration
+            config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration
                 for swapping fp16 params to NVMe.
+            config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead.
             enabled (bool, optional): If ``False``, this context has no effect.
                 Defaults to ``True``.
             dtype (``dtype``, optional): Can be used to change the data type of the parameters.
@@ -426,7 +428,8 @@ def get_model():
                 model = deepspeed.zero.Init(module=model)

         """
-        _ds_config = DeepSpeedConfig(config, mpu) if config is not None else None
+        _ds_config = DeepSpeedConfig(config_dict_or_path,
+                                     mpu) if config_dict_or_path is not None else None
         super().__init__(enabled=enabled,
                          mem_efficient_linear=mem_efficient_linear,
                          ds_config=_ds_config,
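
[Reviewer note, not part of the patch] After PATCH 2/2 the renamed argument is
used as sketched below. The config values are assumptions for illustration: the
fp16 section drives the torch.half choice in _set_dtype(), and the
offload_param section drives NVMe swapping; SimpleMPU is the sketch after
PATCH 1/2 above.

# Hypothetical usage; the ds_config values and nvme_path are assumptions.
import torch

import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "fp16": {"enabled": True},  # _set_dtype() picks torch.half
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "nvme", "nvme_path": "/local_nvme"},
    },
}

deepspeed.init_distributed()  # zero.Init also does this if needed

# config_dict_or_path replaces the deprecated config argument; mpu is
# forwarded into the internal DeepSpeedConfig() call.
with deepspeed.zero.Init(remote_device="nvme",
                         config_dict_or_path=ds_config,
                         mpu=SimpleMPU()):
    model = torch.nn.Linear(4096, 4096)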