67 changes: 38 additions & 29 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -215,11 +215,16 @@ def recurse(cl):
# Inserts _post_init_method at the end of init method
# for all sub classes of torch.nn.Module
class InsertPostInitMethodToModuleSubClasses(object):
def __init__(self, enabled=True, mem_efficient_linear=True, config=None, dtype=None):
def __init__(self,
enabled=True,
mem_efficient_linear=True,
ds_config=None,
dtype=None):
self.mem_efficient_linear = mem_efficient_linear
self.enabled = enabled
self._set_dtype(config, dtype)
assert self.dtype in [torch.half, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.float]"
self._set_dtype(ds_config, dtype)
assert self.dtype in [
torch.half, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.float]"

def __enter__(self):
if not self.enabled:
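
For context, the mechanism the comment above describes — appending _post_init_method to the __init__ of every torch.nn.Module subclass — follows a standard constructor-wrapping pattern. The sketch below is a minimal, hypothetical illustration of that pattern only, not the code being merged (the actual context manager also patches torch.Tensor.__new__ and torch.empty, as later hunks show):

import functools
import torch

def _wrap_init(orig_init, post_init_hook):
    # Run the module's own constructor first, then the post-init hook,
    # which in ZeRO-3 would partition the freshly created parameters.
    @functools.wraps(orig_init)
    def wrapped_init(module, *args, **kwargs):
        orig_init(module, *args, **kwargs)
        post_init_hook(module)
    return wrapped_init

def patch_module_subclasses(post_init_hook):
    # Walk every currently defined subclass of torch.nn.Module and wrap
    # its __init__, mirroring the recurse(cl) helper in the hunk above.
    def recurse(cls):
        for subclass in cls.__subclasses__():
            recurse(subclass)
            subclass.__init__ = _wrap_init(subclass.__init__, post_init_hook)
    recurse(torch.nn.Module)
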
@@ -287,8 +292,8 @@ def _disable_class(cls):
torch.Tensor.__new__ = torch.Tensor.__old_new__
torch.empty = _orig_torch_empty

#undoing it here will undo it during training
#if self.mem_efficient_linear:
# undoing it here will undo it during training
# if self.mem_efficient_linear:
# torch.nn.functional.linear = self.linear_bk
# if self.mem_efficient_linear:
# torch.nn.functional.linear = self.linear_bk
@@ -303,8 +308,7 @@ def _post_init_method(self, module):

def _set_dtype(self, ds_config, dtype):
if ds_config is not None and dtype is None:
_ds_config = DeepSpeedConfig(ds_config)
self.dtype = torch.half if _ds_config.fp16_enabled else torch.float
self.dtype = torch.half if ds_config.fp16_enabled else torch.float
elif dtype is None:
self.dtype = torch.half
else:
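
Put differently, after this change the dtype is resolved from an already-constructed DeepSpeedConfig object (its fp16_enabled flag) rather than by re-parsing the config inside _set_dtype. A hedged sketch of the resulting precedence — the final branch is truncated by the hunk and is assumed to keep an explicitly passed dtype:

import torch

def resolve_dtype(ds_config, dtype):
    if ds_config is not None and dtype is None:
        # the DeepSpeedConfig object's fp16_enabled flag decides half vs. float
        return torch.half if ds_config.fp16_enabled else torch.float
    elif dtype is None:
        # no config and no explicit dtype: default to half precision
        return torch.half
    else:
        # an explicitly passed dtype always wins
        return dtype
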
@@ -321,9 +325,11 @@ def __init__(self,
mem_efficient_linear=True,
remote_device=None,
pin_memory=False,
config_dict_or_path=None,
config=None,
enabled=True,
dtype=None):
dtype=None,
mpu=None):
"""A context to enable massive model construction for training with
ZeRO-3. Models are automatically partitioned (or, sharded) across the
system and converted to half precision.
@@ -343,12 +349,14 @@ def __init__(self,
pin_memory (bool, optional): Potentially increase performance by
using pinned memory for model weights. ``remote_device`` must be
``"cpu"``. Defaults to ``False``.
config (``json file`` or dict, optional): If provided, provides configuration
config_dict_or_path (dict or ``json file``, optional): If provided, supplies the configuration
for swapping fp16 params to NVMe.
config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead.
enabled (bool, optional): If ``False``, this context has no
effect. Defaults to ``True``.
dtype (``dtype``, optional): Can be used to change the data type of the parameters.
Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None``
mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}.

This context accelerates model initialization and enables models that
are too large to allocate in their entirety in CPU memory. It has the
@@ -420,9 +428,11 @@ def get_model():
model = deepspeed.zero.Init(module=model)
"""

_ds_config = DeepSpeedConfig(config_dict_or_path,
mpu) if config_dict_or_path is not None else None
super().__init__(enabled=enabled,
mem_efficient_linear=mem_efficient_linear,
config=config,
ds_config=_ds_config,
dtype=dtype)
if not torch.distributed.is_initialized():
init_distributed()
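
A hedged usage sketch of the new config_dict_or_path argument documented above; the config keys assume the usual ZeRO-3 offload_param schema (device, nvme_path), and the nvme_path value and model are purely illustrative:

import deepspeed
import torch

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "fp16": {"enabled": True},           # parameters are created as torch.half
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "nvme",            # required when remote_device="nvme"
            "nvme_path": "/local_nvme"   # must not be None for NVMe offload
        }
    }
}

with deepspeed.zero.Init(remote_device="nvme",
                         config_dict_or_path=ds_config,
                         enabled=True):
    # Any nn.Module constructed inside the context is partitioned (and, with
    # the config above, swapped to NVMe) as each submodule's __init__ returns.
    model = torch.nn.Linear(4096, 4096)
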
@@ -435,21 +445,20 @@ def get_model():
self.rank = torch.distributed.get_rank(group=self.ds_process_group)
self.world_size = torch.distributed.get_world_size(group=self.ds_process_group)

#Local device is the device where the parameters are consumed
#It is the device where parameters are fully instantiated using allgather
# Local device is the device where the parameters are consumed
# It is the device where parameters are fully instantiated using allgather
self.local_device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))

self._validate_remote_device(remote_device, config)
self._validate_remote_device(remote_device, _ds_config)

#Remote device is the device where parameter partitions are stored
#It can be the same as local_device or it could be CPU or NVMe.
# Remote device is the device where parameter partitions are stored
# It can be the same as local_device or it could be CPU or NVMe.
self.remote_device = self.local_device if remote_device is None else remote_device
self.pin_memory = pin_memory if (
self.remote_device == OFFLOAD_CPU_DEVICE) else False

# Enable fp16 param swapping to NVMe
if self.remote_device == OFFLOAD_NVME_DEVICE:
_ds_config = DeepSpeedConfig(config)
self.param_swapper = AsyncPartitionedParameterSwapper(_ds_config)
else:
self.param_swapper = None
@@ -463,22 +472,21 @@ def get_model():
self._convert_to_deepspeed_param(param)
param.partition()

def _validate_remote_device(self, remote_device, ds_config):
if ds_config is not None:
_ds_config = DeepSpeedConfig(ds_config)
def _validate_remote_device(self, remote_device, _ds_config):
if _ds_config is not None:
if remote_device in [None, OFFLOAD_CPU_DEVICE]:
if _ds_config.zero_config.offload_param is not None:
offload_param_device = _ds_config.zero_config.offload_param[
OFFLOAD_PARAM_DEVICE]
assert offload_param_device != OFFLOAD_NVME_DEVICE, \
f"{OFFLOAD_PARAM_DEVICE} in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}."
f"{OFFLOAD_PARAM_DEVICE} in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}."

if remote_device == OFFLOAD_NVME_DEVICE:
assert _ds_config.zero_config.offload_param is not None, \
f'{OFFLOAD_PARAM} must be defined in DeepSpeed Config if remote device is {OFFLOAD_NVME_DEVICE}.'
f'{OFFLOAD_PARAM} must be defined in DeepSpeed Config if remote device is {OFFLOAD_NVME_DEVICE}.'

assert _ds_config.zero_config.offload_param[OFFLOAD_PARAM_NVME_PATH] is not None, \
f'{OFFLOAD_PARAM_NVME_PATH} in DeepSpeed Config cannot be None if remote device is {OFFLOAD_NVME_DEVICE}'
f'{OFFLOAD_PARAM_NVME_PATH} in DeepSpeed Config cannot be None if remote device is {OFFLOAD_NVME_DEVICE}'

def _post_init_method(self, module):
#see_memory_usage(f"Before converting parmas in {module.__class__.__name__}", force=False)
@@ -624,7 +632,7 @@ def _ensure_availability_of_partitioned_params(self, params):

def _all_gather(self, param_list, async_op=False, hierarchy=None):

#fetches from nvme if the partition is not available and in nvme
# fetches from nvme if the partition is not available and in nvme
self._ensure_availability_of_partitioned_params(param_list)

handles = []
@@ -651,10 +659,10 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None):
def _partition(self, param_list, force=False, has_been_updated=False):
for param in param_list:
#print_rank_0(f"Before Partitioning Param {param.ds_id}")
#self._param_status(param)
# self._param_status(param)
self._partition_param(param, has_been_updated=has_been_updated)
param.ds_status = ZeroParamStatus.NOT_AVAILABLE
#if param.ds_tensor is not None:
# if param.ds_tensor is not None:
# assert id(param.data) == id(param.ds_tensor.data), \
# "After the parameters are initially partitioned, make sure we are not recreating the partition."
#print_rank_0(f"After Partitioning Param {param.ds_id}")
@@ -678,7 +686,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
# if numel in empty_buffers:
# empty_buffers[numel].append(buffer)

#if torch.distributed.get_rank():
# if torch.distributed.get_rank():
# print(f"Releasing {param.data.numel()}")
if param.ds_tensor is not None and not has_been_updated:

@@ -687,7 +695,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
see_memory_usage(
f'Before partitioning param {param.ds_id} {param.shape}',
force=False)
#param.data does not store anything meaningful in partitioned state
# param.data does not store anything meaningful in partitioned state
param.data = torch.ones(1, dtype=self.dtype).to(param.device)
see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}',
force=False)
@@ -765,7 +773,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):

#param.data = param.ds_tensor.data

#param.data does not store anything meaningful in partitioned state
# param.data does not store anything meaningful in partitioned state

see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}',
force=False)
@@ -1002,7 +1010,8 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
dtype=param.dtype,
device=param.device)
else:
assert partition_buffer.numel() >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}"
assert partition_buffer.numel(
) >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}"

rank = torch.distributed.get_rank(group=self.ds_process_group)
start = partition_size * rank