[AutoParallel] Add param name for dist_tensor parameter #60574

Merged: 1 commit, Jan 9, 2024
python/paddle/base/framework.py (2 additions, 0 deletions)

@@ -7598,10 +7598,12 @@ def from_tensor(cls, tensor, **kwargs):
        mesh = kwargs.get("process_mesh", None)
        placements = kwargs.get("placements", None)
        src_tensor = tensor

        if mesh is not None and placements is not None:
            src_tensor = core.eager.Tensor(
                tensor, process_mesh=mesh, placements=placements
            )
+           param.name = tensor.name + ".dist"

        # 3. set param data
        param._set_impl(src_tensor)
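In practice, the effect of this change is that a dense parameter keeps its original name while its distributed copy gets a `.dist` suffix, so the two can be told apart by name (for example when inspecting parameters while debugging). A minimal sketch of the behavior, assuming a two-rank 1-D mesh (the mesh, dim name, and parameter name below are illustrative):

```python
import paddle
import paddle.distributed as dist
from paddle.distributed import Replicate

# Illustrative 1-D mesh over two ranks.
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

# A dense parameter with an explicit name.
w = paddle.create_parameter([10, 5], dtype="float32", name="linear_1.weight")

# shard_tensor goes through EagerParamBase.from_tensor, which now
# appends ".dist" to the distributed parameter's name.
w_dist = dist.shard_tensor(w, mesh, [Replicate()])

print(w.name)       # linear_1.weight
print(w_dist.name)  # linear_1.weight.dist
```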
python/paddle/distributed/auto_parallel/api.py (2 additions, 1 deletion)

@@ -169,15 +169,16 @@ def shard_tensor(
    place = paddle.framework._get_paddle_place(place)

    # 1. create dense tensor
-   # `paddle.to_tensor` supports both dynamic and static mode
    if stop_gradient is None:
        stop_gradient = getattr(data, "stop_gradient", True)

    if isinstance(data, EagerParamBase) and not data._is_initialized():
        assert (
            data._init_func is not None
        ), "Get an uninitialized param with an unregistered init_func."
        tensor = data
    else:
+       # `paddle.to_tensor` supports both dynamic and static mode
        tensor = paddle.to_tensor(
            data, dtype=dtype, place=place, stop_gradient=stop_gradient
        )
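The relocated comment makes clear that `paddle.to_tensor` is only reached on the else branch: an uninitialized `EagerParamBase` (one whose `_init_func` is registered but has not yet run) is passed through as-is instead of being copied. A minimal sketch of that lazy-initialization path, assuming `paddle.LazyGuard` is used to defer parameter initialization (the mesh and layer below are illustrative):

```python
import paddle
import paddle.distributed as dist
from paddle.distributed import Replicate

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

# Parameters created under LazyGuard are uninitialized EagerParamBase
# objects carrying a registered init function instead of allocated
# storage.
with paddle.LazyGuard():
    linear = paddle.nn.Linear(8, 8)

# shard_tensor takes the uninitialized-parameter branch shown above:
# the parameter is used directly (no paddle.to_tensor copy) and is
# materialized later when its init function runs.
w_dist = dist.shard_tensor(linear.weight, mesh, [Replicate()])
```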
test/auto_parallel/test_dist_tensor.py (11 additions, 0 deletions)

@@ -52,6 +52,17 @@ def test_dist_tensor_creation(self):
        self.assertEqual(dist_tensor_with_numpy.placements, placements)
        self.assertEqual(dist_tensor_with_tensor.placements, placements)

+   def test_dist_parameter(self):
+       mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
+       placements = [Replicate(), Replicate()]
+
+       dense_param = paddle.create_parameter(
+           [10, 5], name="linear_1.weight", dtype='float32'
+       )
+       dist_param = dist.shard_tensor(dense_param, mesh, placements)
+
+       self.assertEqual(dense_param.name + ".dist", dist_param.name)

class TestDistTensorFromFn(unittest.TestCase):
    def run_dtensor_from_fn(self):