
[SemiAuto] add static branch for shard_tensor #56561

Merged: 4 commits, Aug 24, 2023
Changes from 3 commits
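Summary of the change: paddle.distributed.shard_tensor previously raised NotImplementedError under static graph mode; this PR routes the static-mode case to the legacy auto-parallel annotation API (paddle.distributed.auto_parallel.interface.shard_tensor). A minimal, hedged sketch of the newly supported usage, with the mesh layout and tensor shape taken loosely from the new test below:

import paddle
import paddle.distributed as dist

paddle.enable_static()

# Illustrative 2x4 process mesh with named axes "x" and "y".
mesh = dist.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None, None])

# Under static mode, shard_tensor now annotates the Variable in the default
# distributed context instead of raising NotImplementedError.
input = paddle.static.data(name="input", shape=[4, 1024, 512], dtype="float32")
d_tensor = dist.shard_tensor(input, dist_attr=dist_attr)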
43 changes: 39 additions & 4 deletions python/paddle/distributed/auto_parallel/api.py
@@ -13,6 +13,9 @@
# limitations under the License.

import paddle
from paddle.distributed.auto_parallel.interface import (
shard_tensor as shard_tensor_static,
)
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.framework import core

@@ -55,15 +58,46 @@ def __init__(self, mesh, sharding_specs):
for dim_name in sharding_specs
), 'The dimension name in sharding_specs must be an instance of str.'

self._sharding_specs = sharding_specs
dims_mapping = [
mesh.dim_names.index(dim_name) if dim_name is not None else -1
for dim_name in sharding_specs
]

# 2. init core.TensorDistAttr
core.TensorDistAttr.__init__(self)
self.process_mesh = mesh
self.dims_mapping = dims_mapping
self._process_mesh = mesh
Reviewer (Contributor): The process_mesh field in the C++ TensorDistAttr might remain unchanged?

Author (Contributor): Fixed.
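A minimal illustration of the reviewer's concern, using a hypothetical pure-Python stand-in rather than the real pybind-backed core.TensorDistAttr: assigning only a Python-side shadow attribute leaves the underlying field untouched.

class CppBackedAttr:
    # Hypothetical stand-in for the pybind-exposed C++ TensorDistAttr.
    def __init__(self):
        self.process_mesh = None  # plays the role of the C++-backed field

class PyDistAttr(CppBackedAttr):
    def __init__(self, mesh):
        CppBackedAttr.__init__(self)
        self._process_mesh = mesh  # Python-only shadow; the base field stays None

attr = PyDistAttr(mesh="mesh-0")
print(attr.process_mesh)  # None -> the C++ side would not see the mesh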

self._dims_mapping = dims_mapping

@property
def process_mesh(self):
"""
Get process_mesh of the dist_attr

Returns:
paddle.distributed.ProcessMesh: process_mesh
Reviewer (Contributor): Should a usage example be added here as well?

Author (Contributor): I think it is fine without one, since this is a fairly simple property getter. If needed, I will add one later.
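For reference, a hedged sketch of the kind of usage example being discussed (mesh and sharding_specs values are illustrative; dims_mapping follows the rule from __init__ above, where each named axis maps to its index in mesh.dim_names and None maps to -1):

import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None])

print(dist_attr.process_mesh)    # the ProcessMesh passed at construction
print(dist_attr.dims_mapping)    # [0, -1]: "x" is mesh axis 0, None maps to -1
print(dist_attr.sharding_specs)  # ["x", None]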

"""
return self._process_mesh

@property
def dims_mapping(self):
"""
Get dims_mapping of the dist_attr

Returns:
list[int]: dims_mapping
"""
return self._dims_mapping

@property
def sharding_specs(self):
"""
Get sharding_specs of the dist_attr

Returns:
list[str]: sharding_specs
"""
return self._sharding_specs


def shard_tensor(
@@ -121,6 +155,7 @@ def shard_tensor(
if paddle.in_dynamic_mode():
return paddle.Tensor(data, dist_attr=dist_attr)
else:
raise NotImplementedError(
"The `paddle.distributed.shard_tensor` for static mode will be implemented later."
# TODO(zhiqiu): we need to refine the static shard_tensor
return shard_tensor_static(
data, dist_attr.process_mesh, dist_attr.sharding_specs
)
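For clarity, the resulting static/dynamic branch in shard_tensor reads roughly as follows (a condensed paraphrase of the change above, not the exact Paddle source):

import paddle
from paddle.distributed.auto_parallel.interface import (
    shard_tensor as shard_tensor_static,
)

def shard_tensor_sketch(data, dist_attr):
    if paddle.in_dynamic_mode():
        # Dynamic graph: construct a distributed Tensor directly.
        return paddle.Tensor(data, dist_attr=dist_attr)
    # Static graph: delegate to the legacy annotation API, which records the
    # process_mesh and sharding_specs on the Variable's dist attr.
    return shard_tensor_static(
        data, dist_attr.process_mesh, dist_attr.sharding_specs
    )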
70 changes: 53 additions & 17 deletions test/auto_parallel/test_shard_tensor_api.py
@@ -16,6 +16,10 @@

import paddle
import paddle.distributed as dist
from paddle.distributed.auto_parallel.static.dist_context import (
get_default_distributed_context,
)
from paddle.fluid.dygraph.base import switch_to_static_graph


class TestDistAttrBasic(unittest.TestCase):
@@ -51,27 +55,59 @@ def test_sharding_specs_argument_error(self):
self.assertIsNotNone(exception)


class TestShardTensorBasic(unittest.TestCase):
# remove this test after static mode is supported
def test_static_mode_unimplemented(self):
exception = None
try:
paddle.enable_static()
class TestShardTensorStatic(unittest.TestCase):
def setUp(self):
self.mesh = dist.ProcessMesh(
[[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
)

@switch_to_static_graph
def test_static_mode(self):
dist_attr = dist.DistAttr(
mesh=self.mesh, sharding_specs=['x', None, None]
)

input = paddle.static.data(
name="input",
shape=[4, 1024, 512],
dtype='float32',
)
d_tensor = dist.shard_tensor(input, dist_attr=dist_attr)

default_dist_context = get_default_distributed_context()
dist_input = default_dist_context.get_dist_tensor_for_program(input)
self.assertEqual(dist_input.dist_attr.process_mesh, self.mesh)
self.assertEqual(dist_input.dist_attr.dims_mapping, [0, -1, -1])
self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh"))
self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping"))


class TestShardTensorStaticDy2Static(unittest.TestCase):
def test_dy2static(self):
@paddle.jit.to_static
def func():
mesh = dist.ProcessMesh(
[[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"]
[[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
)
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
a = paddle.to_tensor([[1, 2, 3], [5, 6, 7]])
d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
except NotImplementedError as ex:
self.assertIn(
"The `paddle.distributed.shard_tensor` for static mode will be implemented later",
str(ex),
dist_attr = dist.DistAttr(
mesh=mesh, sharding_specs=['x', None, None]
)
exception = ex
paddle.disable_static()

self.assertIsNotNone(exception)
input = paddle.rand([4, 1024, 512])
d_tensor = dist.shard_tensor(input, dist_attr=dist_attr)
return input, mesh

dy_tensor, mesh = func()
static_tensor = func.outputs[0] # get the inputs of static program

default_dist_context = get_default_distributed_context()
dist_input = default_dist_context.get_dist_tensor_for_program(
static_tensor
)
self.assertEqual(dist_input.dist_attr.process_mesh, mesh)
self.assertEqual(dist_input.dist_attr.dims_mapping, [0, -1, -1])
self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh"))
self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping"))


if __name__ == "__main__":