diff --git a/python/paddle/base/framework.py b/python/paddle/base/framework.py
index 1225eba4e4242..f8d689c1e578e 100644
--- a/python/paddle/base/framework.py
+++ b/python/paddle/base/framework.py
@@ -7598,10 +7598,12 @@ def from_tensor(cls, tensor, **kwargs):
         mesh = kwargs.get("process_mesh", None)
         placements = kwargs.get("placements", None)
         src_tensor = tensor
+
         if mesh is not None and placements is not None:
             src_tensor = core.eager.Tensor(
                 tensor, process_mesh=mesh, placements=placements
             )
+            param.name = tensor.name + ".dist"
 
         # 3. set param data
         param._set_impl(src_tensor)
diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
index 7e734bd95b1b1..1604a8bea63a5 100644
--- a/python/paddle/distributed/auto_parallel/api.py
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -169,15 +169,16 @@ def shard_tensor(
     place = paddle.framework._get_paddle_place(place)
 
     # 1. create dense tensor
-    # `paddle.to_tensor` supports both dynamic and static mode
     if stop_gradient is None:
         stop_gradient = getattr(data, "stop_gradient", True)
+
     if isinstance(data, EagerParamBase) and not data._is_initialized():
         assert (
             data._init_func is not None
         ), "Get an uninitialized param with an unregistered init_func."
         tensor = data
     else:
+        # `paddle.to_tensor` supports both dynamic and static mode
         tensor = paddle.to_tensor(
             data, dtype=dtype, place=place, stop_gradient=stop_gradient
         )
diff --git a/test/auto_parallel/test_dist_tensor.py b/test/auto_parallel/test_dist_tensor.py
index f1e882e013e66..ad2fa4a040c39 100644
--- a/test/auto_parallel/test_dist_tensor.py
+++ b/test/auto_parallel/test_dist_tensor.py
@@ -52,6 +52,17 @@ def test_dist_tensor_creation(self):
         self.assertEqual(dist_tensor_with_numpy.placements, placements)
         self.assertEqual(dist_tensor_with_tensor.placements, placements)
 
+    def test_dist_parameter(self):
+        mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
+        placements = [Replicate(), Replicate()]
+
+        dense_param = paddle.create_parameter(
+            [10, 5], name="linear_1.weight", dtype='float32'
+        )
+        dist_param = dist.shard_tensor(dense_param, mesh, placements)
+
+        self.assertEqual(dense_param.name + ".dist", dist_param.name)
+
 
 class TestDistTensorFromFn(unittest.TestCase):
     def run_dtensor_from_fn(self):
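
For reference, a minimal standalone sketch (not part of the patch) of the behavior the new test asserts: sharding a dense parameter with `dist.shard_tensor` produces a distributed parameter whose name gains a `.dist` suffix. It assumes an environment with enough devices to cover the 2x2 process mesh, e.g. a 4-card run launched with `paddle.distributed.launch`.

```python
# Sketch of the renaming behavior exercised by test_dist_parameter.
# Assumes 4 ranks are available for the 2x2 process mesh.
import paddle
import paddle.distributed as dist
from paddle.distributed import Replicate

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
placements = [Replicate(), Replicate()]

dense_param = paddle.create_parameter(
    [10, 5], name="linear_1.weight", dtype="float32"
)
dist_param = dist.shard_tensor(dense_param, mesh, placements)

print(dense_param.name)  # linear_1.weight
print(dist_param.name)   # linear_1.weight.dist
```

The suffix comes from the `EagerParamBase.from_tensor` change above: when a process mesh and placements are supplied, the converted distributed parameter is renamed `tensor.name + ".dist"` instead of reusing the dense parameter's name.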