[AutoParallel] Add paddle.distributed.dtensor_from_fn api #56565

Merged Sep 12, 2023 (22 commits).

Changes shown below are from the first 11 commits.

Commits (22), all authored by yangxiaoyu14:
- bfca775: def dtensor_from_fn first edition (Aug 23, 2023)
- d20a032: dtensor_from_fn first edition (Aug 23, 2023)
- 9070579: Merge branch 'develop' of https://github.com/yangxiaoyu14/Paddle into… (Aug 23, 2023)
- 03f62dc: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Aug 23, 2023)
- 95015fa: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Aug 23, 2023)
- 3902f7d: Delete file /home/Paddle/build/test/auto_parallel/test_dist_tensor.py (Aug 24, 2023)
- 672ba4d: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Aug 24, 2023)
- 2d3fc3a: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Aug 25, 2023)
- 50fd5a9: polish code format (Aug 28, 2023)
- 3953a5e: fix sample code formatting issues (Aug 29, 2023)
- 1d510ea: change sample codes ' >>>' to '>>> ' (Aug 30, 2023)
- 11e76c7: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Aug 30, 2023)
- 221c03b: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Aug 30, 2023)
- 47564d2: Add static image single measurement (Aug 31, 2023)
- dd6c04f: modify the Indent of Sample Code (Sep 1, 2023)
- ff9928d: complete the sample code modification according to ZhongKai's suggestion (Sep 1, 2023)
- a9dd254: modify according to the review (Sep 4, 2023)
- 832d36b: change fluid.Variable to static.Variable (Sep 6, 2023)
- bf0598b: modify according to zhongkai's review (Sep 7, 2023)
- 71ea731: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Sep 11, 2023)
- d545aed: According to Yifan's suggestion, pull the latest code to resolve conf… (Sep 11, 2023)
- f064d24: remove paddle/fluid/ir/dialect/paddle_dialect/ir/generated/pd_ops_bac… (Sep 11, 2023)
python/paddle/distributed/__init__.py (2 additions, 0 deletions)

@@ -66,6 +66,7 @@
 from .auto_parallel import shard_op  # noqa: F401
 from .auto_parallel.api import shard_tensor  # noqa: F401
+from .auto_parallel.api import dtensor_from_fn  # noqa: F401

 from .fleet import BoxPSDataset  # noqa: F401

@@ -126,4 +127,5 @@
     "ProcessMesh",
     "DistAttr",
     "shard_tensor",
+    "dtensor_from_fn",
 ]
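With this export, the new API is reachable from the package root. A minimal smoke check, as a sketch: it assumes the list shown above is the module's __all__ and a Paddle build that includes this PR.

    import paddle.distributed as dist

    # dtensor_from_fn is re-exported at the package root and listed in __all__.
    assert "dtensor_from_fn" in dist.__all__
    print(dist.dtensor_from_fn)  # the function object defined in auto_parallel/api.py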
python/paddle/distributed/auto_parallel/api.py (34 additions, 13 deletions)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+
 import paddle
 from paddle.distributed.auto_parallel.interface import (
     shard_tensor as shard_tensor_static,

@@ -69,19 +70,6 @@ def __init__(self, mesh, sharding_specs):
         self.process_mesh = mesh
         self.dims_mapping = dims_mapping
-
-        self.mark_annotated("process_mesh")
-        self.mark_annotated("dims_mapping")
-
-    @property
-    def sharding_specs(self):
-        """
-        Get sharding_specs of the dist_attr
-
-        Returns:
-            list[str]: sharding_specs
-        """
-        return self._sharding_specs


 def shard_tensor(
     data, dtype=None, place=None, stop_gradient=True, dist_attr=None

@@ -142,3 +130,36 @@ def shard_tensor(
     return shard_tensor_static(
         data, dist_attr.process_mesh, dist_attr.sharding_specs
     )
+
+
+def dtensor_from_fn(fn, dist_attr, *args, **kwargs):
+    """
+    Construct a Distributed Tensor from a paddle api function and the given arguments.
+
+    Args:
+        fn (callable): A paddle api function that takes arguments of *args, **kwargs and returns tensor.
+        dist_attr(paddle.distributed.DistAttr): Specify how tensors are distributed or sliced on ProcessMesh.
Review comment (Contributor): There is a space between `fn` and `(callable`; it is recommended to also add a space between `dist_attr` and its parenthesis.
Reply (Author): done

+        *args: A list of arguments to be passed to the ``fn`` function.
Review comment (Contributor): Suggest adding parenthesized types here as well, for consistent formatting:
    *args (tuple):
    **kwargs (dict):
Review comment (Contributor): Of these two parameters, one is a tuple and one is a dict, not lists; please distinguish them.
Reply (Author): done

+        **kwargs: A list of arguments to be passed to the ``fn`` function.
+
+    Returns:
+        Tensor: A Tensor constructed from ``fn`` with distributed attributes.
+
+    Examples:
+
+        .. code-block:: python
+
+            >>> import paddle
Review comment (Contributor): Shouldn't this code be indented by four spaces?
Reply (Author): done

+            >>> import paddle.distributed as dist
+
+            >>> # Create a distributed attribute
+            >>> mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
+            >>> dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
+
+            >>> # Call the function dtensor_from_fn with dist_attr parameter
+            >>> d_tensor = dist.dtensor_from_fn(paddle.ones, dist_attr=dist_attr, shape=[2, 3])
+
+            >>> print(d_tensor)
+    """
+    tensor = fn(*args, **kwargs)
+    return shard_tensor(tensor, dist_attr=dist_attr)
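The implementation is deliberately thin: it calls ``fn`` eagerly and then re-shards the returned dense tensor via ``shard_tensor``, so any tensor factory whose keyword arguments can be forwarded works the same way. A minimal usage sketch assembled from this PR's docstring and tests (dynamic mode only; the static-mode error path is exercised in the test below):

    import paddle
    import paddle.distributed as dist

    # Describe how tensors are laid out across a 2 x 3 process mesh.
    mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
    dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", "y"])

    # Keyword arguments are forwarded verbatim to the factory function.
    d_ones = dist.dtensor_from_fn(paddle.ones, dist_attr=dist_attr, shape=[2, 3])
    d_rand = dist.dtensor_from_fn(paddle.rand, dist_attr=dist_attr, shape=[2, 3])

    print(d_ones.dist_attr)  # carries the mesh and sharding_specs defined above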
test/auto_parallel/test_dist_tensor.py (50 additions, 0 deletions)

@@ -53,6 +53,56 @@ def test_dist_tensor_creation(self):
         self.assertEqual(dist_tensor_with_tensor.dist_attr, dist_attr)
+
+
+class TestDistributedTensor(unittest.TestCase):
Review comment (Contributor): The test class name should match what is being tested; rename it to TestDistTensorFromFn?
Reply (Author): done

+    def test_dtensor_from_fn(self):
Review comment (Contributor): Static graph mode is still not tested; please add coverage for it.
Reply (Author): Added a test that runs under static graph mode (enable_static()).

+        # Create a distributed attribute
+        mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
+        dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
+
+        # Test with generate_tensor_ones()
Review comment (Contributor): The function this comment refers to no longer exists, and its connection to the code below is unclear; please update it, e.g. to "Test with paddle.ones".
Reply (Author): Removed the comments that no longer matched the code.

+        # Call the function dtensor_from_fn with dist_attr parameter
+        result = dist.dtensor_from_fn(
+            paddle.ones, dist_attr=dist_attr, shape=[2, 3]
+        )
+
+        # Verify the result
+        self.assertIsInstance(result, paddle.Tensor)
+        self.assertEqual(result.shape, [2, 3])
+        self.assertEqual(result.dist_attr, dist_attr)
+
+        # Test with generate_tensor_zeros()
Review comment (Contributor): This comment also still needs to be removed.
Reply (Author): done

+        result_zeros = dist.dtensor_from_fn(
+            paddle.zeros, dist_attr=dist_attr, shape=[2, 3]
+        )
+        self.assertIsInstance(result_zeros, paddle.Tensor)
+        self.assertEqual(result_zeros.shape, [2, 3])
+        self.assertEqual(result_zeros.dist_attr, dist_attr)
+
+        # Test with generate_tensor_random()
Review comment (Contributor): Same for this comment.
Reply (Author): done

+        result_random = dist.dtensor_from_fn(
+            paddle.rand, dist_attr=dist_attr, shape=[2, 3]
+        )
+        self.assertIsInstance(result_random, paddle.Tensor)
+        self.assertEqual(result_random.shape, [2, 3])
+        self.assertEqual(result_random.dist_attr, dist_attr)
Review comment (Contributor): Should a few failure cases be added, to exercise the error paths?
Review comment (Contributor): Should static graph mode be tested as well?
Review comment (Contributor): Yes.
Reply (Author): done, thx


+        # Test with invalid sharding_specs length
+        with self.assertRaises(AssertionError):
+            invalid_dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x'])
+            dist.dtensor_from_fn(
+                paddle.ones, dist_attr=invalid_dist_attr, shape=[2, 3]
+            )
+
+        # Test exceptions when running in static mode
+        paddle.enable_static()
+        with self.assertRaises(NotImplementedError):
+            with paddle.static.program_guard(paddle.static.Program()):
+                dist.dtensor_from_fn(
+                    paddle.ones, dist_attr=dist_attr, shape=[2, 3]
+                )
+        paddle.disable_static()


 class TestDistTensorForDygraphAPI(unittest.TestCase):
     def check_tensor_eq(self, a, b):
         np1 = a.numpy()