Skip to content
33 changes: 20 additions & 13 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,14 @@ def recurse(cl):
# Inserts _post_init_method at the end of init method
# for all sub classes of torch.nn.Module
class InsertPostInitMethodToModuleSubClasses(object):
def __init__(self,
             enabled=True,
             mem_efficient_linear=True,
             ds_config=None,
             dtype=None):
    """Record the context flags and resolve the parameter data type.

    Args:
        enabled: Stored on ``self.enabled``.
        mem_efficient_linear: Stored on ``self.mem_efficient_linear``.
        ds_config: An already-parsed DeepSpeed config object, or ``None``.
            It is forwarded to ``_set_dtype``, which reads its
            ``fp16_enabled`` attribute when ``dtype`` is not given.
        dtype: Explicit data type override, or ``None`` to let
            ``_set_dtype`` choose one.

    Raises:
        AssertionError: If the resolved ``self.dtype`` is neither
            ``torch.half`` nor ``torch.float``.
    """
    self.mem_efficient_linear = mem_efficient_linear
    self.enabled = enabled
    # The config is parsed once by the caller; only the parsed object is
    # passed down so it is not re-parsed here.
    self._set_dtype(ds_config, dtype)
    assert self.dtype in [torch.half, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.float]"

def __enter__(self):
Expand Down Expand Up @@ -327,8 +331,7 @@ def _post_init_method(self, module):

def _set_dtype(self, ds_config, dtype):
if ds_config is not None and dtype is None:
_ds_config = DeepSpeedConfig(ds_config)
self.dtype = torch.half if _ds_config.fp16_enabled else torch.float
self.dtype = torch.half if ds_config.fp16_enabled else torch.float
elif dtype is None:
self.dtype = torch.half
else:
Expand All @@ -345,9 +348,11 @@ def __init__(self,
mem_efficient_linear=True,
remote_device=None,
pin_memory=False,
config_dict_or_path=None,
config=None,
enabled=True,
dtype=None):
dtype=None,
mpu=None):
"""A context to enable massive model construction for training with
ZeRO-3. Models are automatically partitioned (or, sharded) across the
system and converted to half precision.
Expand All @@ -367,12 +372,14 @@ def __init__(self,
pin_memory (bool, optional): Potentially increase performance by
using pinned memory for model weights. ``remote_device`` must be
``"cpu"``. Defaults to ``False``.
config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration.
config (``json file`` or dict, optional): If provided, provides configuration
for swapping fp16 params to NVMe.
enabled (bool, optional): If ``False``, this context has no
effect. Defaults to ``True``.
dtype (``dtype``, optional): Can be used to change the data type of the parameters.
Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None``
mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}.

This context accelerates model initialization and enables models that
are too large to allocate in their entirety in CPU memory. It has the
Expand Down Expand Up @@ -443,10 +450,11 @@ def get_model():

model = deepspeed.zero.Init(module=model)
"""

_ds_config = DeepSpeedConfig(config_dict_or_path,
mpu) if config_dict_or_path is not None else None
super().__init__(enabled=enabled,
mem_efficient_linear=mem_efficient_linear,
config=config,
ds_config=_ds_config,
dtype=dtype)
if not torch.distributed.is_initialized():
init_distributed()
Expand All @@ -463,7 +471,7 @@ def get_model():
#It is the device where parameters are fully instantiated using allgather
self.local_device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))

self._validate_remote_device(remote_device, config)
self._validate_remote_device(remote_device, _ds_config)

#Remote device is the device where parameter partitions are stored
#It can be same as local_device or it could be CPU or NVMe.
Expand All @@ -489,19 +497,18 @@ def get_model():

def _validate_remote_device(self, remote_device, ds_config):
    """Check that ``remote_device`` agrees with the offload settings in ``ds_config``.

    Args:
        remote_device: The requested remote device; compared against
            ``None``, ``OFFLOAD_CPU_DEVICE`` and ``OFFLOAD_NVME_DEVICE``.
        ds_config: An already-parsed DeepSpeed config object exposing
            ``zero_config.offload_param``, or ``None`` to skip validation.

    Raises:
        AssertionError: If the config requests NVMe parameter offload while
            ``remote_device`` is ``None``/CPU, or if ``remote_device`` is
            NVMe but the config has no offload-param section or no NVMe
            path in it.
    """
    # Nothing to cross-check without a config.
    if ds_config is None:
        return

    if remote_device in [None, OFFLOAD_CPU_DEVICE]:
        if ds_config.zero_config.offload_param is not None:
            offload_param_device = ds_config.zero_config.offload_param[
                OFFLOAD_PARAM_DEVICE]
            assert offload_param_device != OFFLOAD_NVME_DEVICE, \
                f"{OFFLOAD_PARAM_DEVICE} in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}."

    if remote_device == OFFLOAD_NVME_DEVICE:
        assert ds_config.zero_config.offload_param is not None, \
            f'{OFFLOAD_PARAM} must be defined in DeepSpeed Config if remote device is {OFFLOAD_NVME_DEVICE}.'

        assert ds_config.zero_config.offload_param[OFFLOAD_PARAM_NVME_PATH] is not None, \
            f'{OFFLOAD_PARAM_NVME_PATH} in DeepSpeed Config cannot be None if remote device is {OFFLOAD_NVME_DEVICE}'

def _post_init_method(self, module):
Expand Down