support sharding stage 2 (#62486)
LiYuRio authored Mar 11, 2024
1 parent f8fbbb5 commit ce5a3a8
Showing 6 changed files with 353 additions and 15 deletions.
2 changes: 2 additions & 0 deletions python/paddle/distributed/__init__.py
@@ -87,6 +87,7 @@
shard_optimizer,
shard_scaler,
ShardingStage1,
ShardingStage2,
ShardingStage3,
to_static,
Strategy,
@@ -174,6 +175,7 @@
"shard_optimizer",
"shard_scaler",
"ShardingStage1",
"ShardingStage2",
"ShardingStage3",
"to_static",
"Strategy",
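A minimal usage sketch of the new export (assumes this Paddle build; not part of the commit itself): with the two added lines above, ShardingStage2 becomes importable from paddle.distributed next to the existing stage classes.

import paddle.distributed as dist

# The new symbol sits alongside the existing stage 1 and stage 3 shard functions.
print(dist.ShardingStage1, dist.ShardingStage2, dist.ShardingStage3)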
132 changes: 117 additions & 15 deletions python/paddle/distributed/auto_parallel/api.py
@@ -584,13 +584,14 @@ def get_placement_with_sharding(param, sharding_mesh_axis):
# for example, [Shard(0), Shard(1)], assert here in case
assert (
shard_axis == -1
), "The parameter can't be shard twice even in different mesh now."
), "The parameter can't be shard twice with sharding strategy even in different mesh now."
shard_axis = placement.get_dim()

placement_with_sharding = None
for dim in range(param.ndim):
if dim != shard_axis:
placement_with_sharding = dist.Shard(dim)
break

new_placements = param.placements
if placement_with_sharding is not None:
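A minimal sketch of what this placement selection does, assuming a 4-card launch and that the elided tail of the function writes the chosen Shard into position sharding_mesh_axis of the returned placements:

import paddle
import paddle.distributed as dist
from paddle.distributed.auto_parallel.api import get_placement_with_sharding

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
# A parameter already sharded on tensor dim 0 along mesh axis "x".
w = dist.shard_tensor(
    paddle.rand([8, 8]), mesh, [dist.Shard(0), dist.Replicate()]
)
# The loop above skips the existing shard axis 0 and picks Shard(1), so the
# placements returned for the optimizer accumulator should become
# [Shard(dim=0), Shard(dim=1)] when sharding happens along mesh axis "y".
print(get_placement_with_sharding(w, sharding_mesh_axis=1))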
@@ -626,10 +627,17 @@ def __init__(self, optimizer, shard_fn=None):
self._sharding_mesh_axis = None
self._sharding_degree = None

if isinstance(self._shard_fn, (ShardingStage1, ShardingStage3)):
if isinstance(
self._shard_fn, (ShardingStage1, ShardingStage2, ShardingStage3)
):
self._set_and_check_sharding_prop_from_param()
self._shard_fn._set_sharding_mesh_axis(self._sharding_mesh_axis)

# Invoke register hook for sharding stage 2 strategy
if isinstance(self._shard_fn, ShardingStage2):
for param in self._inner_opt._parameter_list:
self._shard_fn._register_hook_for_param_grad(param)

# Invoke shard_parameter in sharding stage 3 strategy
if isinstance(self._shard_fn, ShardingStage3):
for param in self._inner_opt._parameter_list:
@@ -835,10 +843,22 @@ def __getattr__(self, item):
return getattr(self._inner_opt, item)


class ShardingStage1:
class _ShardingStageBase:
def __init__(self, mesh):
self._mesh = mesh
self._sharding_mesh_axis = None

def _set_sharding_mesh_axis(self, sharding_mesh_axis):
self._sharding_mesh_axis = sharding_mesh_axis


class ShardingStage1(_ShardingStageBase):
"""
A builtin shard_fn for shard_optimizer interface, users can pass it to shard_optimizer to implement sharding optimization with stage 1.
Args:
mesh(paddle.distributed.ProcessMesh): The `ProcessMesh` object describes the Cartesian topology of the used processes.
Examples:
.. code-block:: python
@@ -860,7 +880,7 @@ class ShardingStage1:
>>> layer = MLP()
>>> batch = paddle.rand(shape=[8, 8])
>>> opt = paddle.optimizer.AdamW(parameters=layer.parameters())
>>> opt = dist.shard_optimizer(opt, dist.ShardingStage1())
>>> opt = dist.shard_optimizer(opt, dist.ShardingStage1(mesh))
>>> for _ in range(5):
>>> loss = layer(batch)
>>> loss.backward()
Expand All @@ -871,8 +891,7 @@ class ShardingStage1:
"""

def __init__(self, mesh):
self._mesh = mesh
self._sharding_mesh_axis = None
super().__init__(mesh)

def __call__(self, key, param, accumulator):
if param.is_dist():
@@ -893,11 +912,94 @@ def __call__(self, key, param, accumulator):
)
return accumulator

def _set_sharding_mesh_axis(self, sharding_mesh_axis):
self._sharding_mesh_axis = sharding_mesh_axis

class ShardingStage2(_ShardingStageBase):
"""
A builtin shard_fn for shard_optimizer interface, users can pass it to shard_optimizer to implement sharding optimization with stage 2.
Args:
mesh(paddle.distributed.ProcessMesh): The `ProcessMesh` object describes the Cartesian topology of the used processes.
Examples:
.. code-block:: python
>>> import paddle
>>> import paddle.distributed as dist
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> class MLP(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
... self.fc1 = paddle.nn.Linear(8, 8)
... self.fc2 = paddle.nn.Linear(8, 8)
...
... def forward(self, input):
... return self.fc2(self.fc1(input))
>>> # doctest: +REQUIRES(env:DISTRIBUTED)
>>> layer = MLP()
>>> batch = paddle.rand(shape=[8, 8])
>>> opt = paddle.optimizer.AdamW(parameters=layer.parameters())
>>> opt = dist.shard_optimizer(opt, dist.ShardingStage2(mesh))
>>> for _ in range(5):
>>> loss = layer(batch)
>>> loss.backward()
>>> opt.step()
>>> opt.clear_grad()
>>> # This case need to be executed in multi-card environment
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
"""

def __init__(self, mesh):
super().__init__(mesh)

def __call__(self, key, param, accumulator):
if param.is_dist():
# Only deal with momentum in optimizer, beta should be replicated across param's mesh
if 'beta' not in key:
placements = get_placement_with_sharding(
param, self._sharding_mesh_axis
)
else:
placements = [
dist.Replicate()
for _ in range(len(param.process_mesh.shape))
]
return shard_tensor(
accumulator,
mesh=param.process_mesh,
placements=placements,
)
return accumulator

@staticmethod
def _grad_hook(grad):
# do reshard only if the grad is dist tensor and in partial status
if grad.is_dist():
partial_mesh_axis = None
for mesh_axis, placement in enumerate(grad.placements):
if isinstance(placement, dist.Partial):
partial_mesh_axis = mesh_axis
if partial_mesh_axis is not None:
new_placements = get_placement_with_sharding(
grad, partial_mesh_axis
)
return reshard(grad, grad.process_mesh, new_placements)

return grad

def _register_hook_for_param_grad(self, param):
if param.is_dense():
placements = []
for _ in range(len(self._mesh.shape)):
placements.append(dist.Replicate())
param._to_dist_(placements, self._mesh)

param.register_hook(ShardingStage2._grad_hook)


class ShardingStage3(_ShardingStageBase):
"""
A builtin shard_fn for shard_optimizer interface, users can pass it to shard_optimizer to implement sharding optimization with stage 3.
@@ -936,11 +1038,7 @@ class ShardingStage3:
"""

def __init__(self, mesh):
self._mesh = mesh
self._sharding_mesh_axis = None

def _set_sharding_mesh_axis(self, sharding_mesh_axis):
self._sharding_mesh_axis = sharding_mesh_axis
super().__init__(mesh)

def _shard_parameter(self, param):
if param.is_dense():
@@ -2000,6 +2098,10 @@ def to_static(
strategy.sharding.enable = True
strategy.sharding.stage = 1
strategy.sharding.degree = sharding_degree
elif isinstance(shard_fn, ShardingStage2):
strategy.sharding.enable = True
strategy.sharding.stage = 2
strategy.sharding.degree = sharding_degree
elif isinstance(shard_fn, ShardingStage3):
strategy.sharding.enable = True
strategy.sharding.stage = 3
@@ -2008,7 +2110,7 @@
shard_fn._unshard_parameter(param)
else:
raise NotImplementedError(
"Only sharding stage 1 and 3 can to_static for now. User-defined shard_fn and sharding stage 2 will be supported later."
"Only sharding stage 1, 2 and 3 can to_static for now. User-defined shard_fn will be supported later."
)

dist_model = DistModel(layer, loader, loss, optimizer, strategy)
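A usage-level sketch of the stage 2 behaviour added above, modelled on the ShardingStage2 docstring (run with python -m paddle.distributed.launch --gpus=0,1 {script}.py; the placement printout at the end is only for illustration):

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
layer = paddle.nn.Linear(8, 8)
opt = paddle.optimizer.AdamW(parameters=layer.parameters())
# Wrapping with ShardingStage2 registers _grad_hook on every parameter, so a
# gradient that arrives in Partial status is resharded to Shard along the
# sharding mesh axis and each rank keeps only its own gradient slice.
opt = dist.shard_optimizer(opt, dist.ShardingStage2(mesh))

loss = layer(paddle.rand([8, 8])).mean()
loss.backward()
for p in layer.parameters():
    # Expect a Shard placement here rather than Partial once the hook has run.
    print(p.name, p.grad.placements)
opt.step()
opt.clear_grad()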
@@ -0,0 +1,114 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np
from auto_parallel.semi_auto_parallel_dist_to_static_api import (
DemoNet,
create_data_loader,
)

import paddle
import paddle.distributed as dist
from paddle import nn


class TestSemiAutoParallelShardingStage2:
def __init__(self):
self._backend = os.getenv("backend")
self._seed = eval(os.getenv("seed"))
self._mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])

def check_tensor_eq(self, a, b, rtol=1e-05, atol=0, verbose=True):
np.testing.assert_allclose(a, b, rtol=rtol, atol=atol, verbose=verbose)

def shard_layer_fn(self, layer_name, layer, process_mesh):
layer.weight = dist.shard_tensor(
layer.weight, process_mesh, [dist.Shard(1)]
)
layer.bias = dist.shard_tensor(
layer.bias, process_mesh, [dist.Shard(0)]
)

def get_single_card_rst(self):
paddle.seed(self._seed)
linear = paddle.nn.Linear(10, 10)
batch = paddle.rand(shape=[10, 10])
opt = paddle.optimizer.AdamW(parameters=linear.parameters())
for _ in range(5):
loss = linear(batch)
loss.backward()
opt.step()
opt.clear_grad()
self.weight = linear.weight.numpy()
self.bias = linear.bias.numpy()

def test_sharding_stage_2_with_mp(self):
paddle.seed(self._seed)
linear = paddle.nn.Linear(10, 10)
linear = dist.shard_layer(linear, self._mesh, self.shard_layer_fn)
batch = paddle.rand(shape=[10, 10])
# shard the input by sharding degree
batch = dist.shard_tensor(batch, self._mesh, [dist.Shard(0)])
# shard optimizer with stage 2 fn
opt = paddle.optimizer.AdamW(parameters=linear.parameters())
opt = dist.shard_optimizer(opt, dist.ShardingStage2(self._mesh))
for _ in range(5):
loss = linear(batch)
loss.backward()
opt.step()
opt.clear_grad()
self.check_tensor_eq(self.weight, linear.weight.numpy())
self.check_tensor_eq(self.bias, linear.bias.numpy())

def test_sharding_stage_2_with_mp_to_static(self):
data_loader = create_data_loader()
layer = DemoNet(
self._mesh, "sharding_with_mp_demonet", shard_weight=True
)
opt = paddle.optimizer.SGD(
learning_rate=0.1, parameters=layer.parameters()
)
opt = dist.shard_optimizer(opt, dist.ShardingStage2(self._mesh))
loss_fn = nn.MSELoss()

dist_loader = dist.shard_dataloader(
dataloader=data_loader,
meshes=[self._mesh],
shard_dims=0,
)

dist_model = dist.to_static(layer, dist_loader, loss_fn, opt)

dist_model.train()
for epoch in range(2):
for batch_id, (image, label) in enumerate(dist_loader()):
loss = dist_model(image, label)

def run_test_case(self):
if self._backend == "cpu":
paddle.set_device("cpu")
elif self._backend == "gpu":
paddle.set_device("gpu:" + str(dist.get_rank()))
else:
raise ValueError("Only support cpu or gpu backend.")

self.get_single_card_rst()
self.test_sharding_stage_2_with_mp()
self.test_sharding_stage_2_with_mp_to_static()


if __name__ == '__main__':
TestSemiAutoParallelShardingStage2().run_test_case()
@@ -41,6 +41,16 @@ def test_sharding_stage_1_strategy(self):
user_defined_envs=envs,
)

def test_sharding_stage_2_strategy(self):
envs_list = test_base.gen_product_envs_list(
self._default_envs, self._changeable_envs
)
for envs in envs_list:
self.run_test_case(
"semi_auto_parallel_sharding_stage_2.py",
user_defined_envs=envs,
)

def test_sharding_stage_3_strategy(self):
envs_list = test_base.gen_product_envs_list(
self._default_envs, self._changeable_envs