Commit b6bd632: fix_group_sharded_note

Baibaifan committed Mar 13, 2022
1 parent ec09ef2 commit b6bd632
Showing 2 changed files with 15 additions and 12 deletions.
@@ -912,7 +912,6 @@ def _device2cpu(trans_param, convert_dtype=False):

 def _cpu2device(param):
     tmp_p = param.fw_storage.cuda(DEV_ID)
-    param.fw_storage._clear()
     if tmp_p.dtype == Type.fp32.value and param2dtype[
             param.name] == Type.fp16.value:
         tmp_p = paddle.cast(tmp_p, Type.fp16.value)
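For context, `_cpu2device` implements a transfer-then-cast pattern: the parameter's CPU-side storage is moved back onto the GPU and, if the parameter was originally registered as fp16 but is stored as fp32, it is cast back down. Below is a minimal standalone sketch of that pattern, assuming a CUDA build of paddle; the function name and the string dtype argument are illustrative, not the module's own API:

    import paddle

    def cpu_to_device_cast(cpu_tensor, target_dtype="float16"):
        # Move the CPU-side storage back onto the default CUDA device.
        gpu_tensor = cpu_tensor.cuda()
        # Cast down only when the stored copy is fp32 but the parameter
        # was originally registered as fp16.
        if gpu_tensor.dtype == paddle.float32 and target_dtype == "float16":
            gpu_tensor = paddle.cast(gpu_tensor, "float16")
        return gpu_tensor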
python/paddle/distributed/sharding/group_sharded.py — 26 changes: 15 additions & 11 deletions
@@ -39,19 +39,20 @@ def group_sharded_parallel(model,
                            segment_size=2**20,
                            sync_comm=False):
     """
     Use this module to configure and wrap up the parameters of the group sharded module.
     Using group_sharded_parallel, you can apply group sharded configuration to the model, optimizer and GradScaler. `level` accepts three string options, 'os', 'os_g' and 'p_g_os', corresponding to three usage scenarios: optimizer state segmentation, optimizer state + gradient segmentation, and parameter + gradient + optimizer state segmentation.
     Usually, optimizer state + gradient segmentation is a further optimization of optimizer state segmentation, so optimizer state + gradient segmentation can also be used to realize optimizer state segmentation.
     Args:
         model (Layer): The layer to be wrapped with group_sharded_parallel.
         optimizer (Optimizer): The optimizer to be wrapped with group_sharded_parallel.
         level (str): The group sharded level, one of `os`, `os_g` and `p_g_os`.
-        scaler (GradScaler, optional): The scaler to be wrapped with group_sharded_parallel. Defaults to None.
-        group (Group, optional): The group instance. Defaults to None.d
-        offload (bool, optional): Whether to perform optimizer state and gradient transfer CPU. Defaults to False.
-        sync_buffers (bool, optional): Whether to broadcast model buffers. Defaults to False.
-        buffer_max_size (int, optional): The max size of the buffer used to integrate gradient in `os_g`. Defaults to 2**23.
-        segment_size (int, optional): The smallest size of parameter to be sharded in `p_g_os`. Defaults to 2**20.
-        sync_comm (bool, optional): Whether to use synchronous communication, only in `p_g_os` used. Defaults to False.
+        scaler (GradScaler, optional): If AMP is used, you need to pass a GradScaler. Defaults to None, indicating that GradScaler is not used.
+        group (Group, optional): The group instance. Defaults to None, indicating that the default environment group is used.
+        offload (bool, optional): Whether to offload the optimizer state and gradients to CPU. Defaults to False, which means that offload is not used.
+        sync_buffers (bool, optional): Whether to broadcast model buffers. It is generally used when there are registered model buffers. Defaults to False, indicating that model buffers are not broadcast.
+        buffer_max_size (int, optional): The max size of the buffer used to integrate gradients in `os_g`. The larger the size, the more GPU memory will be used. Defaults to 2**23, which means the buffer holds 2**23 elements.
+        segment_size (int, optional): The smallest size of a parameter to be sharded in `p_g_os`. Defaults to 2**20, meaning the minimum segmented parameter size is 2**20 elements.
+        sync_comm (bool, optional): Whether to use synchronous communication; only used in `p_g_os`. Defaults to False, indicating that asynchronous communication is used.
     Returns:
         model: A wrapper for group sharded given model.
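The full usage example in this file is collapsed above the next hunk; below is a minimal training-step sketch consistent with the documented signature, assuming the script is launched on GPUs with paddle.distributed.launch. The layer sizes, loss, and hyperparameters are illustrative only:

    import paddle
    from paddle.distributed import fleet
    from paddle.distributed.sharding import group_sharded_parallel

    # Collective communication must be initialized before wrapping.
    fleet.init(is_collective=True)

    model = paddle.nn.Linear(1000, 1000)
    optimizer = paddle.optimizer.AdamW(
        learning_rate=0.001, parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

    # 'os_g' shards optimizer state and gradients across the default group.
    model, optimizer, scaler = group_sharded_parallel(
        model, optimizer, level="os_g", scaler=scaler)

    img = paddle.randn([4, 1000])
    with paddle.amp.auto_cast():
        loss = model(img).mean()
    scaled = scaler.scale(loss)
    scaled.backward()
    scaler.minimize(optimizer, scaled)
    optimizer.clear_grad()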
@@ -101,7 +102,7 @@ def group_sharded_parallel(model,
     def check_dtype(param):
         return param.dtype == paddle.float16

-    params_fp16 = filter(check_dtype, model.parameters())
+    params_fp16 = list(filter(check_dtype, model.parameters()))
     if scaler is None and len(params_fp16) > 0:
         raise ValueError("Please enter the correct scaler.")
     # convert model/optimizer/scaler
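The list(...) wrapper in this hunk is the actual bug fix: filter() returns a lazy iterator, and iterators have no len(), so the subsequent len(params_fp16) check raised a TypeError whenever it ran. A self-contained sketch of the failure mode, using an illustrative fp32 layer:

    import paddle

    model = paddle.nn.Linear(8, 8)  # fp32 parameters, illustrative

    params_fp16 = filter(lambda p: p.dtype == paddle.float16,
                         model.parameters())
    # len(params_fp16)  # TypeError: object of type 'filter' has no len()

    # Materializing the iterator first, as the fix does, makes len() valid.
    params_fp16 = list(params_fp16)
    print(len(params_fp16))  # 0 here, since the Linear layer is fp32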
@@ -146,10 +147,13 @@ def save_group_sharded_model(model, output, optimizer=None):
"""
Group sharded encapsulated model and optimizer state saving module.
.. note::
If using save_group_sharded_model saves the model. When loading again, you need to set the model or optimizer state before using group_sharded_parallel.
Args:
model (Layer): A wrapper for group sharded given model.
output (str): Save directory.
optimizer (Optimizer, optional): Group sharded encapsulated optimizer. Defaults to None.
optimizer (Optimizer, optional): Group sharded encapsulated optimizer. Defaults to None, indicating that the optimizer state is not saved.
Examples:
.. code-block:: python
@@ -182,7 +186,7 @@ def save_group_sharded_model(model, output, optimizer=None):
         optimizer.clear_grad()

         # save model and optimizer state_dict
-        save_group_sharded_model(model, optimizeroutput=output_dir)
+        save_group_sharded_model(model, optimizer, output=output_dir)
"""
logger_.info(
"==========Begin to save group sharded model and optimizer==========")
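Per the .. note:: added above, a checkpoint written by save_group_sharded_model must be restored onto the raw objects before re-wrapping. A minimal loading sketch; the file names under output_dir are hypothetical, and model / optimizer are assumed to be freshly constructed, unwrapped instances:

    import paddle

    # Hypothetical file names; use whatever save_group_sharded_model
    # actually wrote under your output directory.
    model_state = paddle.load("output_dir/model.pdmodel")
    opt_state = paddle.load("output_dir/model.pdopt")

    # Restore state on the unwrapped objects first...
    model.set_state_dict(model_state)
    optimizer.set_state_dict(opt_state)
    # ...and only then call group_sharded_parallel(model, optimizer, level).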
