shard_optimizer and ShardOptimizer API
FeixLiu committed Nov 27, 2023
1 parent c7968ac commit e6e123c
Showing 5 changed files with 476 additions and 8 deletions.
4 changes: 4 additions & 0 deletions python/paddle/distributed/__init__.py
@@ -84,6 +84,8 @@
dtensor_from_fn,
reshard,
shard_layer,
shard_optimizer,
ShardOptimizer,
)

from .fleet import BoxPSDataset # noqa: F401
@@ -157,4 +159,6 @@
"Shard",
"Replicate",
"Partial",
"shard_optimizer",
"ShardOptimizer",
]
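
With this change both entry points are reachable from the top-level `paddle.distributed` namespace. A minimal sketch of the import surface (the mesh shape is an illustrative assumption):

import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
# dist.shard_optimizer(...)  - functional form, returns the optimizer with sharded accumulators
# dist.ShardOptimizer(...)   - wrapper class that creates and shards accumulators inside step()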
242 changes: 241 additions & 1 deletion python/paddle/distributed/auto_parallel/api.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import Callable

import paddle
@@ -406,3 +406,243 @@ def replicate_layer_params_and_buffers(
"`paddle.distributed.shard_layer` only supports dynamic graph mode "
"now. It will be supported for static graph mode later."
)


def shard_optimizer(
optimizer: paddle.optimizer.Optimizer,
process_mesh: dist.ProcessMesh,
parameter_list: list = None,
shard_fn: Callable = None,
) -> paddle.optimizer.Optimizer:
"""
Create accumulators for the optimizer and then convert them to DistTensor.
The `shard_fn` should have the following signature:
def shard_fn(accumulator_name, param, accumulator, process_mesh) -> sharded_accumulator
Args:
optimizer (paddle.optimizer.Optimizer): The optimizer to be sharded.
process_mesh (paddle.distributed.ProcessMesh): The `ProcessMesh` on which
the optimizer states will be placed.
parameter_list (optional, list): A list of parameters for which accumulators should be created.
If not specified, accumulators are created for all parameters.
shard_fn (optional, Callable): The function used to shard accumulators across the `process_mesh`. If not
specified, all accumulators are replicated across the `process_mesh` on the unsharded dims.
Returns:
Optimizer: An optimizer whose accumulators have all been sharded.
Examples:
.. code-block:: python
>>> import paddle
>>> import paddle.distributed as dist
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> class MLP(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
... self.fc1 = paddle.nn.Linear(8, 8)
... self.fc2 = paddle.nn.Linear(8, 8)
...
... def forward(self, input):
... return self.fc2(self.fc1(input))
>>> layer = MLP()
>>> opt = paddle.optimizer.AdamW(parameters=layer.parameters())
>>> opt = dist.shard_optimizer(opt, mesh)
>>> # This case needs to be executed in a multi-card environment
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
"""
assert (
paddle.in_dynamic_mode()
), "`paddle.distributed.shard_optimizer` only supports dynamic mode for now."
assert optimizer is not None, "The argument `optimizer` cannot be empty."
assert isinstance(
optimizer, paddle.optimizer.AdamW
), "`paddle.distributed.shard_optimizer` only supports AdamW optimizer for now."
assert (
process_mesh is not None
), "The argument `process_mesh` cannot be empty."
assert isinstance(
process_mesh, dist.ProcessMesh
), "The argument `process_mesh` is not `dist.ProcessMesh` type."

target_block = paddle.base.framework.default_main_program().global_block()
optimizer.helper = paddle.base.layer_helper.LayerHelper(
optimizer.__class__.__name__
)

def shard_accumulator(param: paddle.Tensor) -> None:
# create the accumulators
optimizer._create_accumulators(target_block, [param])

target_name = param.name
if param.name in optimizer._master_weights.keys():
target_name = optimizer._master_weights[param.name].name

# shard the accumulators
for key in optimizer._accumulators.keys():
accumulator = optimizer._accumulators[key][target_name]
if accumulator.is_dist():
continue
if shard_fn is not None:
optimizer._accumulators[key][target_name] = shard_fn(
key, param, accumulator, process_mesh
)
else:
placements = [
dist.Replicate() for _ in range(len(process_mesh.shape))
]
if (
'beta' not in key
and param.is_dist()
and param.process_mesh == process_mesh
):
# If param is a dist tensor and its process_mesh is identical
# to the target process mesh, keep the shard info for all
# accumulators except the beta accumulators.
placements = param.placements
optimizer._accumulators[key][target_name] = shard_tensor(
accumulator, mesh=process_mesh, placements=placements
)

if parameter_list is not None:
for p in parameter_list:
shard_accumulator(p)
else:
if not isinstance(optimizer._parameter_list[0], dict):
for p in optimizer._parameter_list:
shard_accumulator(p)
else:
for param_group in optimizer._param_groups:
params = param_group['params']
for p in params:
shard_accumulator(p)
return optimizer
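
Because the return value of `shard_fn` replaces the accumulator, a hypothetical custom `shard_fn` could look like the sketch below (the function name and the `Shard(0)` placement are assumptions for illustration, continuing the 1-D mesh from the docstring example; they are not part of this commit):

import paddle.distributed as dist

def my_shard_fn(accumulator_name, param, accumulator, process_mesh):
    # Keep the scalar beta1/beta2 power accumulators replicated; shard the
    # moment accumulators along their leading dim on the 1-D mesh.
    if 'beta' in accumulator_name:
        placements = [dist.Replicate()]
    else:
        placements = [dist.Shard(0)]
    return dist.shard_tensor(accumulator, mesh=process_mesh, placements=placements)

opt = dist.shard_optimizer(opt, mesh, shard_fn=my_shard_fn)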


class ShardOptimizer:
"""
Wrap the global-view optimizer into a distributed view.
Args:
optimizer (paddle.optimizer.Optimizer): The optimizer to be sharded.
process_mesh (paddle.distributed.ProcessMesh): The `ProcessMesh` on which
the optimizer states will be placed.
sharding_mesh_axis (optional): The mesh axis to be used for sharding parallelism.
If not specified, no sharding parallelism is applied and only the
parameters' shard status is propagated to the accumulators. The default value is None.
gather_output (optional, boolean): Whether to all-gather the sharded parameters in the sharding-parallel view.
Only valid when sharding parallelism is enabled. The default value is True.
Methods:
step(): same as optimizer.step()
clear_grad(set_to_zero): same as optimizer.clear_grad()
Examples:
.. code-block:: python
>>> import paddle
>>> import paddle.distributed as dist
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> class MLP(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
... self.fc1 = paddle.nn.Linear(8, 8)
... self.fc2 = paddle.nn.Linear(8, 8)
...
... def forward(self, input):
... return self.fc2(self.fc1(input))
>>> layer = MLP()
>>> batch = paddle.rand(shape=[8, 8])
>>> opt = paddle.optimizer.AdamW(parameters=layer.parameters())
>>> opt = dist.ShardOptimizer(opt, mesh)
>>> for _ in range(5):
...     loss = layer(batch)
...     loss.backward()
...     opt.step()
...     opt.clear_grad()
>>> # This case needs to be executed in a multi-card environment
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
"""

def __init__(
self,
optimizer,
process_mesh,
sharding_mesh_axis=None,
gather_output=True,
):
assert (
paddle.in_dynamic_mode()
), "`paddle.distributed.ShardOptimizer` only supports dynamic mode for now."
assert (
optimizer is not None
), "The argument `optimizer` cannot be empty."
assert isinstance(
optimizer, paddle.optimizer.AdamW
), "`paddle.distributed.ShardOptimizer` only supports AdamW optimizer for now."
assert (
process_mesh is not None
), "The argument `process_mesh` cannot be empty."
assert isinstance(
process_mesh, dist.ProcessMesh
), "The argument `process_mesh` is not `dist.ProcessMesh` type."

# TODO(Yuang Liu): support sharding parallel
assert sharding_mesh_axis is None
# if shard_dims_name is not None:
# assert isinstance(
# shard_dims_name, str
# ), "The argument `shard_dims_name` is not `str` type."
# assert (
# shard_dims_name in process_mesh.dim_names
# ), "The `shard_dims_name` should in `process_mesh`."

self.optimizer = optimizer
self.process_mesh = process_mesh
self.gather_output = gather_output
self.sharding_mesh_axis = sharding_mesh_axis

def clear_grad(self, set_to_zero=True):
self.optimizer.clear_grad(set_to_zero)

def _shard_fn(self, accumulator_name, param, accumulator, process_mesh):
pass

def step(self):
shard_fn = (
self._shard_fn if self.sharding_mesh_axis is not None else None
)
if not isinstance(self.optimizer._parameter_list[0], dict):
params_grads = []
parameter_list = []
for param in self.optimizer._parameter_list:
if param.stop_gradient:
continue
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
parameter_list.append(param)
params_grads.append((param, grad_var))
self.optimizer = shard_optimizer(
optimizer=self.optimizer,
process_mesh=self.process_mesh,
parameter_list=parameter_list,
shard_fn=shard_fn,
)
self.optimizer._apply_optimize(
loss=None, startup_program=None, params_grads=params_grads
)
else:
for param_group in self.optimizer._param_groups:
params_grads = defaultdict(lambda: [])
parameter_list = []
for param in param_group['params']:
if param.stop_gradient:
continue
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
parameter_list.append(param)
params_grads['params'].append((param, grad_var))
params_grads.update(
{k: v for k, v in param_group.items() if k != 'params'}
)
self.optimizer = shard_optimizer(
optimizer=self.optimizer,
process_mesh=self.process_mesh,
parameter_list=parameter_list,
shard_fn=shard_fn,
)
self.optimizer._apply_optimize(
loss=None, startup_program=None, params_grads=params_grads
)
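
ShardOptimizer.step() also covers optimizers built with parameter groups (the `_param_groups` branch above); a minimal sketch of that path, with illustrative group hyper-parameters and the usual multi-card launch assumed:

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
layer = paddle.nn.Linear(8, 8)
opt = paddle.optimizer.AdamW(
    parameters=[
        {"params": [layer.weight], "weight_decay": 0.01},
        {"params": [layer.bias], "weight_decay": 0.0},
    ]
)
opt = dist.ShardOptimizer(opt, mesh)

batch = paddle.rand(shape=[4, 8])
loss = layer(batch).mean()
loss.backward()
opt.step()        # accumulators are created and sharded group by group here
opt.clear_grad()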
12 changes: 5 additions & 7 deletions test/auto_parallel/semi_auto_parallel_shard_optimizer.py
@@ -79,14 +79,12 @@ def test_adamw_mp(self):
opt.clear_grad()
for key in opt._accumulators.keys():
for k, v in opt._accumulators[key].items():
-                if 'momentum' in key:
+                if 'moment' in key:
                     assert opt._accumulators[key][k].is_dist()
-                    if 'w' in k:
-                        assert opt._accumulators[key][k].shape == [10, 10]
-                        assert opt._accumulators[key][k]._local_shape == [10, 5]
-                    else:
-                        assert opt._accumulators[key][k].shape == [10]
-                        assert opt._accumulators[key][k]._local_shape == [5]
+                    assert (
+                        opt._accumulators[key][k].shape[-1]
+                        == opt._accumulators[key][k]._local_shape[-1] * 2
+                    )
self.check_tensor_eq(self.weight, linear.weight.numpy())
self.check_tensor_eq(self.bias, linear.bias.numpy())
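
For context: AdamW registers its accumulators under keys such as 'moment1', 'moment2', 'beta1_pow_acc' and 'beta2_pow_acc', so the old 'momentum' match never fired; switching to 'moment' makes the assertions actually run. The relaxed check encodes the same invariant as the removed hard-coded shapes, namely that sharding over two ranks halves the accumulator's last dim locally. A hypothetical standalone form of that check:

def assert_last_dim_sharded(acc, degree=2):
    # e.g. weight moment: shape [10, 10] vs _local_shape [10, 5]
    #      bias moment:   shape [10]     vs _local_shape [5]
    assert acc.shape[-1] == acc._local_shape[-1] * degree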

