
[AutoParallel] Add paddle.distributed.dtensor_from_fn api #56565

Merged · 22 commits · Sep 12, 2023

Commits (the diff below reflects changes from 18 of the 22 commits)
bfca775 - def dtensor_from_fn first edition (yangxiaoyu14, Aug 23, 2023)
d20a032 - dtensor_from_fn first edition (yangxiaoyu14, Aug 23, 2023)
9070579 - Merge branch 'develop' of https://github.com/yangxiaoyu14/Paddle into… (yangxiaoyu14, Aug 23, 2023)
03f62dc - Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (yangxiaoyu14, Aug 23, 2023)
95015fa - Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (yangxiaoyu14, Aug 23, 2023)
3902f7d - Delete file /home/Paddle/build/test/auto_parallel/test_dist_tensor.py (yangxiaoyu14, Aug 24, 2023)
672ba4d - Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (yangxiaoyu14, Aug 24, 2023)
2d3fc3a - Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (yangxiaoyu14, Aug 25, 2023)
50fd5a9 - polish code format (yangxiaoyu14, Aug 28, 2023)
3953a5e - fix sample code formatting issues (yangxiaoyu14, Aug 29, 2023)
1d510ea - change sample codes ' >>>' to '>>> ' (yangxiaoyu14, Aug 30, 2023)
11e76c7 - Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (yangxiaoyu14, Aug 30, 2023)
221c03b - Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (yangxiaoyu14, Aug 30, 2023)
47564d2 - Add static graph unit test (yangxiaoyu14, Aug 31, 2023)
dd6c04f - modify the indent of sample code (yangxiaoyu14, Sep 1, 2023)
ff9928d - complete the sample code modification according to ZhongKai's suggestion (yangxiaoyu14, Sep 1, 2023)
a9dd254 - modify according to the review (yangxiaoyu14, Sep 4, 2023)
832d36b - change fluid.Variable to static.Variable (yangxiaoyu14, Sep 6, 2023)
bf0598b - modify according to ZhongKai's review (yangxiaoyu14, Sep 7, 2023)
71ea731 - Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (yangxiaoyu14, Sep 11, 2023)
d545aed - According to Yifan's suggestion, pull the latest code to resolve conf… (yangxiaoyu14, Sep 11, 2023)
f064d24 - remove paddle/fluid/ir/dialect/paddle_dialect/ir/generated/pd_ops_bac… (yangxiaoyu14, Sep 11, 2023)
2 changes: 2 additions & 0 deletions python/paddle/distributed/__init__.py
@@ -66,6 +66,7 @@

from .auto_parallel import shard_op # noqa: F401
from .auto_parallel.api import shard_tensor # noqa: F401
from .auto_parallel.api import dtensor_from_fn # noqa: F401

from .fleet import BoxPSDataset # noqa: F401

@@ -126,4 +127,5 @@
"ProcessMesh",
"DistAttr",
"shard_tensor",
"dtensor_from_fn",
]
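
With this export in place, the new API is reachable from the top-level distributed namespace. A quick smoke check of the import path (a sketch, assuming a Paddle build that includes this change):

    import paddle.distributed as dist

    # The new API should now be importable from the package root.
    assert callable(dist.dtensor_from_fn)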
32 changes: 30 additions & 2 deletions python/paddle/distributed/auto_parallel/api.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import paddle
from paddle.distributed.auto_parallel.interface import (
    shard_tensor as shard_tensor_static,
@@ -68,15 +69,13 @@ def __init__(self, mesh, sharding_specs):

self.process_mesh = mesh
self.dims_mapping = dims_mapping

self.mark_annotated("process_mesh")
self.mark_annotated("dims_mapping")

    @property
    def sharding_specs(self):
        """
        Get sharding_specs of the dist_attr

        Returns:
            list[str]: sharding_specs
        """
@@ -142,3 +141,32 @@ def shard_tensor(
    return shard_tensor_static(
        data, dist_attr.process_mesh, dist_attr.sharding_specs
    )


def dtensor_from_fn(fn, dist_attr, *args, **kwargs):
    """
    Construct a Distributed Tensor from a function of arguments.
Contributor: Should we emphasize here that fn must be a paddle API function, and that not just any function will work?

Contributor Author: done, thx


    Args:
        fn (callable): A callable function that takes arguments of Distributed Tensor and returns tensor.
Contributor: This description isn't quite accurate. Translated, it reads "a callable function that takes Distributed Tensors as arguments and returns a tensor". It shouldn't be taking Distributed Tensors as arguments, right?

Contributor Author: Changed to: fn (callable): A paddle api function that takes arguments of *args, **kwargs and returns tensor.

Contributor: Has this been changed yet?

        dist_attr (paddle.distributed.DistAttr): Specify how tensors are distributed or sliced on ProcessMesh.
        *args (tuple): A tuple of arguments to be passed to the ``fn`` function.
        **kwargs (dict): A dict of arguments to be passed to the ``fn`` function.

    Returns:
        Tensor: A Tensor constructed from ``fn`` with distributed attributes.

    Examples:

        .. code-block:: python
            >>> import paddle
Contributor: Sample code usually doesn't include ">>>", does it?

Contributor Author: I checked with ZhongKai; the current requirement is to include it.

Contributor: Suggested change:

    .. code-block:: python
    >>> import paddle

should become:

    .. code-block:: python

    >>> import paddle

Add a blank line after code-block, otherwise the code block won't render correctly in the preview.

Contributor Author: done, thx

            >>> import paddle.distributed as dist
            >>> # Create a distributed attribute
            >>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
            >>> dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None])
            >>> # Call the function dtensor_from_fn with dist_attr parameter
            >>> d_tensor = dist.dtensor_from_fn(paddle.ones, dist_attr=dist_attr, shape=[1])
            >>> print(d_tensor)
"""
tensor = fn(*args, **kwargs)
return shard_tensor(tensor, dist_attr=dist_attr)
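
As a quick orientation for readers, here is a minimal sketch of how the merged API composes with existing tensor-creation functions. This is illustrative only: the mesh, shapes, and fill value are this editor's choices, and a multi-rank launch (for example via paddle.distributed.launch) is assumed for the two-card mesh.

    import paddle
    import paddle.distributed as dist

    # Build a 1-D process mesh over two ranks and a replicated dist_attr.
    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
    dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None])

    # Any paddle tensor-creation API can serve as fn; extra kwargs are forwarded to it.
    d_ones = dist.dtensor_from_fn(paddle.ones, dist_attr=dist_attr, shape=[16])
    d_full = dist.dtensor_from_fn(paddle.full, dist_attr=dist_attr, shape=[16], fill_value=2.0)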
67 changes: 67 additions & 0 deletions test/auto_parallel/test_dist_tensor.py
@@ -53,6 +53,73 @@ def test_dist_tensor_creation(self):
        self.assertEqual(dist_tensor_with_tensor.dist_attr, dist_attr)


class TestDistTensorFromFn(unittest.TestCase):
    def run_dtensor_from_fn(self):
        # Create a distributed attribute
        mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
        dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None])
Contributor: Should we add a test case where sharding_specs is not None?

Contributor Author: There was originally a version with non-None sharding_specs, but YuRui later suggested changing it to None, so I changed it.
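
For reference, a minimal sketch of what such a non-None case could look like. This is hypothetical code, not part of this PR; it assumes the same mesh and test class as above, and shards the single tensor dimension along the "x" mesh axis so the specs length matches the tensor rank:

    # Hypothetical non-None case: shard dim 0 of a 1-D tensor along mesh axis "x".
    sharded_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x"])
    sharded = dist.dtensor_from_fn(paddle.ones, dist_attr=sharded_attr, shape=[16])
    if paddle.in_dynamic_mode():
        self.assertIsInstance(sharded, paddle.Tensor)
        self.assertEqual(sharded.shape, [16])
        self.assertEqual(sharded.dist_attr, sharded_attr)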


        # Call the function dtensor_from_fn with dist_attr parameter
        result = dist.dtensor_from_fn(
            paddle.ones, dist_attr=dist_attr, shape=[16]
        )
        # Verify the result
        if paddle.in_dynamic_mode():
            dist_attr.dynamic_dims = []
            self.assertIsInstance(result, paddle.Tensor)
            self.assertEqual(result.shape, [16])
            self.assertEqual(result.dist_attr, dist_attr)
        else:
            dist_attr.dynamic_dims = [0]
            self.assertIsInstance(result, paddle.static.Variable)
            self.assertEqual(result.shape, (16,))
            self.assertEqual(result.dist_attr, dist_attr)

        result_zeros = dist.dtensor_from_fn(
            paddle.zeros, dist_attr=dist_attr, shape=[16]
        )
        if paddle.in_dynamic_mode():
            dist_attr.dynamic_dims = []
            self.assertIsInstance(result_zeros, paddle.Tensor)
            self.assertEqual(result_zeros.shape, [16])
            self.assertEqual(result_zeros.dist_attr, dist_attr)
        else:
            dist_attr.dynamic_dims = [0]
            self.assertIsInstance(result_zeros, paddle.static.Variable)
            self.assertEqual(result_zeros.shape, (16,))
            self.assertEqual(result_zeros.dist_attr, dist_attr)

        result_random = dist.dtensor_from_fn(
            paddle.rand, dist_attr=dist_attr, shape=[16]
        )
        if paddle.in_dynamic_mode():
            dist_attr.dynamic_dims = []
            self.assertIsInstance(result_random, paddle.Tensor)
            self.assertEqual(result_random.shape, [16])
            self.assertEqual(result_random.dist_attr, dist_attr)
        else:
            dist_attr.dynamic_dims = [0]
            self.assertIsInstance(result_random, paddle.static.Variable)
            self.assertEqual(result_random.shape, (16,))
            self.assertEqual(result_random.dist_attr, dist_attr)

        # Test with invalid sharding_specs length
        with self.assertRaises(AssertionError):
            invalid_dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x'])
            dist.dtensor_from_fn(
                paddle.ones, dist_attr=invalid_dist_attr, shape=[2, 3]
            )

    def test_dynamic_mode(self):
        self.run_dtensor_from_fn()

    # Test exceptions when running in static mode
    def test_static_mode(self):
        paddle.enable_static()
        self.run_dtensor_from_fn()
        paddle.disable_static()


class TestDistTensorForDygraphAPI(unittest.TestCase):
    def check_tensor_eq(self, a, b):
        np1 = a.numpy()