[auto parallel] Add pp lazy init, bug fix for xavier (#60441)
FeixLiu authored Jan 2, 2024
1 parent 1761931 commit c0d6d7d
Showing 5 changed files with 95 additions and 32 deletions.
17 changes: 11 additions & 6 deletions paddle/fluid/pybind/eager_method.cc
@@ -1078,12 +1078,17 @@ static PyObject* tensor__share_underline_tensor_to(TensorObject* self,
EAGER_TRY
paddle::Tensor* src_ptr =
&(reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0))->tensor);
PADDLE_ENFORCE_EQ(self->tensor.initialized(),
true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized! please initialize "
"src tensor before share_buffer_with to other.",
self->tensor.name()));
if (!self->tensor.initialized()) {
PADDLE_ENFORCE(self->tensor.is_dist_tensor() &&
!phi::distributed::IsCurRankInMesh(
static_cast<phi::distributed::DistTensor*>(
self->tensor.impl().get())
->process_mesh()),
platform::errors::InvalidArgument(
"Tensor %s has not been initialized! Please initialize "
"src tensor before share_buffer_with to other.",
self->tensor.name()));
}
src_ptr->set_impl(self->tensor.impl());
RETURN_PY_NONE

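The relaxed check above targets pipeline-style lazy init: on a rank that is not part of a parameter's process mesh, the parameter's local tensor legitimately stays uninitialized, and sharing its underlying tensor should no longer raise. A minimal sketch of that situation, assuming a two-rank launch (e.g. python -m paddle.distributed.launch --devices=0,1 demo.py; the script name is illustrative):

import paddle
import paddle.distributed as dist
from paddle import LazyGuard

# The weight is placed on a mesh that contains only rank 0.
mesh0 = dist.ProcessMesh([0], dim_names=["x"])

with LazyGuard():
    linear = paddle.nn.Linear(10, 10)
    linear.weight = dist.shard_tensor(linear.weight, mesh0, [dist.Replicate()])

# On rank 1 the current rank is not in the weight's process mesh, so the lazy
# init hook leaves the local tensor empty; the relaxed check lets the
# share-underline-tensor step pass instead of raising InvalidArgument.
linear.weight.initialize()
print(dist.get_rank(), linear.weight._is_initialized())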
6 changes: 6 additions & 0 deletions python/paddle/distributed/auto_parallel/api.py
@@ -187,6 +187,12 @@ def shard_tensor(
if isinstance(data, EagerParamBase):

def lazy_init_hook(param, origin_hook):
for placement in param.placements:
assert not placement.is_partial(), (
"Lazy init not support partial reshard. Notice that: shard a param to partial "
"won't save any memory, but will increase the communication cost!"
)

# lazy init hook with randomness controlling
def _init_func(var, block):
# get the unique rng name
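The loop added to lazy_init_hook rejects Partial placements up front: partial-sharding a parameter saves no memory and only adds communication cost. A sketch of the rejected combination, assuming the hook (and therefore the assert) fires when the parameter is materialized via initialize():

import paddle
import paddle.distributed as dist
from paddle import LazyGuard

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

with LazyGuard():
    linear = paddle.nn.Linear(10, 10)
    # A Partial placement under lazy init is the combination the new assert forbids.
    linear.weight = dist.shard_tensor(linear.weight, mesh, [dist.Partial()])

try:
    linear.weight.initialize()  # lazy init hook runs here and hits the assert
except AssertionError as err:
    print("rejected:", err)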
18 changes: 13 additions & 5 deletions python/paddle/nn/initializer/xavier.py
@@ -105,6 +105,11 @@ def forward(self, var, block=None):
if self._seed == 0:
self._seed = block.program.random_seed

out_var_shape = (
var._local_shape
if (isinstance(var, framework.EagerParamBase) and var.is_dist())
else var.shape
)
# to be compatible of fp16 initalizers
if var.dtype == core.VarDesc.VarType.FP16 or (
var.dtype == core.VarDesc.VarType.BF16 and not self._uniform
@@ -114,9 +119,7 @@
name=unique_name.generate(
".".join(['xavier_init', var.name, 'tmp'])
),
shape=var._local_shape
if (isinstance(var, framework.EagerParamBase) and var.is_dist())
else var.shape,
shape=out_var_shape,
dtype=out_dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
@@ -135,7 +138,7 @@ def forward(self, var, block=None):
if self._uniform:
limit = math.sqrt(6.0 / float(fan_in + fan_out))
out_var = _C_ops.uniform(
out_var.shape,
out_var_shape,
out_dtype,
-limit,
limit,
@@ -147,7 +150,12 @@

place = _current_expected_place()
out_var = _C_ops.gaussian(
out_var.shape, 0.0, std, self._seed, out_dtype, place
out_var_shape,
0.0,
std,
self._seed,
out_dtype,
place,
)

if var.dtype == core.VarDesc.VarType.FP16 or (
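The initializer change keeps the Xavier statistics tied to the parameter's logical shape while filling only the local shard of a distributed parameter: fan_in and fan_out (computed above the shown hunk) still describe the whole parameter, but the shape handed to the uniform/gaussian kernels is now var._local_shape when the parameter is a distributed EagerParamBase. A small illustration with made-up shapes (the 2-way column split below is hypothetical):

import math

global_shape = [1024, 4096]  # logical parameter shape, used for fan_in / fan_out
local_shape = [1024, 2048]   # this rank's shard, e.g. columns split over 2 ranks

fan_in, fan_out = global_shape[0], global_shape[1]

# XavierUniform samples U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)).
limit = math.sqrt(6.0 / float(fan_in + fan_out))
# XavierNormal samples N(0, std^2) with std = sqrt(2 / (fan_in + fan_out)).
std = math.sqrt(2.0 / float(fan_in + fan_out))

# The fixed code passes the local shape to the random kernel, mirroring
# _C_ops.gaussian(out_var_shape, 0.0, std, seed, dtype, place) above, so each
# rank materializes only its own shard rather than a full-size tensor.
print(limit, std, local_shape)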
81 changes: 61 additions & 20 deletions test/auto_parallel/semi_auto_parallel_lazy_init.py
@@ -22,40 +22,81 @@
class TestSemiAutoParallelLazyInit:
def __init__(self):
self._backend = os.getenv("backend")
self._placements_type = os.getenv("_placements_type")
self._seed = eval(os.getenv("seed"))
self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
if self._placements_type == "DP":
self._mesh_weight = dist.ProcessMesh([0, 1], dim_names=["x"])
self._mesh_bias = dist.ProcessMesh([0, 1], dim_names=["x"])
self._placements_weight = [dist.Replicate()]
self._placements_bias = [dist.Replicate()]
elif self._placements_type == "PP":
self._mesh_weight = dist.ProcessMesh([0], dim_names=["x"])
self._mesh_bias = dist.ProcessMesh([1], dim_names=["x"])
self._placements_weight = [dist.Replicate()]
self._placements_bias = [dist.Replicate()]

def test_replicate(self):
def test_different_xavier(self):
paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
weight_attr = paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.XavierNormal()
)
bias_attr = paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.XavierUniform()
)
with LazyGuard():
linear = paddle.nn.Linear(
10, 10, weight_attr=weight_attr, bias_attr=bias_attr
)
linear.weight = dist.shard_tensor(
linear.weight, self._mesh_weight, self._placements_weight
)
linear.bias = dist.shard_tensor(
linear.bias, self._mesh_bias, self._placements_bias
)

def test_placements(self):
paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
with LazyGuard():
linear = paddle.nn.Linear(10, 10)
linear.weight = dist.shard_tensor(
linear.weight, self._mesh, [dist.Replicate()]
linear.weight, self._mesh_weight, self._placements_weight
)
linear.bias = dist.shard_tensor(
linear.bias, self._mesh, [dist.Replicate()]
linear.bias, self._mesh_bias, self._placements_bias
)
for param in linear.parameters():
assert not param._is_initialized()
param.initialize()
assert param._is_initialized()

local_weight_md5 = linear.weight._local_value()._md5sum()
mesh0 = dist.ProcessMesh([0], dim_names=["x"])
mesh1 = dist.ProcessMesh([1], dim_names=["x"])
tmp = paddle.distributed.auto_parallel.api.dtensor_from_local(
linear.weight._local_value(),
mesh0 if dist.get_rank() == 0 else mesh1,
[dist.Replicate()],
)
tmp = dist.reshard(
tmp, mesh1 if dist.get_rank() == 0 else mesh0, [dist.Replicate()]
)
tmp_md5 = tmp._local_value()._md5sum()
assert local_weight_md5 == tmp_md5

if self._placements_type == "DP":
assert linear.weight._is_initialized()
assert linear.bias._is_initialized()
local_weight_md5 = linear.weight._local_value()._md5sum()
mesh0 = dist.ProcessMesh([0], dim_names=["x"])
mesh1 = dist.ProcessMesh([1], dim_names=["x"])
tmp = paddle.distributed.auto_parallel.api.dtensor_from_local(
linear.weight._local_value(),
mesh0 if dist.get_rank() == 0 else mesh1,
[dist.Replicate()],
)
tmp = dist.reshard(
tmp,
mesh1 if dist.get_rank() == 0 else mesh0,
[dist.Replicate()],
)
tmp_md5 = tmp._local_value()._md5sum()
assert local_weight_md5 == tmp_md5
elif self._placements_type == "PP":
if dist.get_rank() == 0:
assert linear.weight._is_initialized()
assert not linear.bias._is_initialized()
else:
assert not linear.weight._is_initialized()
assert linear.bias._is_initialized()

def run_test_case(self):
self.test_replicate()
self.test_placements()
self.test_different_xavier()


if __name__ == '__main__':
5 changes: 4 additions & 1 deletion test/auto_parallel/test_semi_auto_parallel_lazy_init.py
@@ -27,7 +27,10 @@ def setUp(self):
"dtype": "float32",
"seed": "2023",
}
self._changeable_envs = {"backend": ["cpu", "gpu"]}
self._changeable_envs = {
"backend": ["cpu", "gpu"],
"_placements_type": ["DP", "PP"],
}

def test_lazy_init(self):
envs_list = test_base.gen_product_envs_list(
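The launcher now crosses both changeable envs, so the lazy-init test runs once per combination: {cpu, gpu} × {DP, PP}. A rough sketch of that expansion (illustrative only; gen_product_envs_list is the test suite's own helper and its actual implementation may differ):

import itertools

default_envs = {"dtype": "float32", "seed": "2023"}
changeable_envs = {"backend": ["cpu", "gpu"], "_placements_type": ["DP", "PP"]}

# Cartesian product of the changeable values, merged with the defaults.
keys = list(changeable_envs)
envs_list = [
    {**default_envs, **dict(zip(keys, values))}
    for values in itertools.product(*(changeable_envs[k] for k in keys))
]
for env in envs_list:
    print(env)  # four env sets: cpu/DP, cpu/PP, gpu/DP, gpu/PP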
