[AutoParallel] Add paddle.distributed.dtensor_from_fn api #56565
Changes from 11 commits
Changed file 1: auto-parallel API module (adds dtensor_from_fn)

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.distributed.auto_parallel.interface import (
    shard_tensor as shard_tensor_static,
@@ -69,19 +70,6 @@ def __init__(self, mesh, sharding_specs):
        self.process_mesh = mesh
        self.dims_mapping = dims_mapping

        self.mark_annotated("process_mesh")
        self.mark_annotated("dims_mapping")

    @property
    def sharding_specs(self):
        """
        Get sharding_specs of the dist_attr

        Returns:
            list[str]: sharding_specs
        """
        return self._sharding_specs


def shard_tensor(
    data, dtype=None, place=None, stop_gradient=True, dist_attr=None
@@ -142,3 +130,36 @@ def shard_tensor(
    return shard_tensor_static(
        data, dist_attr.process_mesh, dist_attr.sharding_specs
    )


def dtensor_from_fn(fn, dist_attr, *args, **kwargs):
    """
    Construct a Distributed Tensor from a paddle api function and its arguments.

    Args:
        fn (callable): A paddle api function that takes arguments of *args, **kwargs and returns tensor.
        dist_attr(paddle.distributed.DistAttr): Specify how tensors are distributed or sliced on ProcessMesh.
        *args: A list of arguments to be passed to the ``fn`` function.

[Review] Suggest adding parentheses here as well to keep the format consistent.
[Review] These two parameters are a tuple and a dict respectively, not lists; please distinguish them in the description. [Author] done

        **kwargs: A list of arguments to be passed to the ``fn`` function.

    Returns:
        Tensor: A Tensor constructed from ``fn`` with distributed attributes.

    Examples:
        .. code-block:: python

            >>> import paddle

[Review] Should this example code be indented by 4 spaces? [Author] done

            >>> import paddle.distributed as dist

            >>> # Create a distributed attribute
            >>> mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
            >>> dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])

            >>> # Call the function dtensor_from_fn with dist_attr parameter
            >>> d_tensor = dist.dtensor_from_fn(paddle.ones, dist_attr=dist_attr, shape=[2, 3])

            >>> print(d_tensor)
    """
    tensor = fn(*args, **kwargs)
    return shard_tensor(tensor, dist_attr=dist_attr)
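For orientation, below is a minimal dynamic-graph usage sketch of the new API. It mirrors the docstring example above; the 2x3 process mesh and sharding specs are taken from the PR and assume a matching multi-rank launch, so treat it as an illustration rather than a single-process recipe.

```python
# Usage sketch for dtensor_from_fn, mirroring the docstring example above.
# Assumption: launched with enough ranks to cover the 2x3 process mesh
# (e.g. via paddle.distributed.launch); the mesh values come from the PR.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", "y"])

# dtensor_from_fn evaluates fn(*args, **kwargs) and then shards the result,
# so any tensor-creation API works: paddle.ones, paddle.zeros, paddle.rand, ...
d_tensor = dist.dtensor_from_fn(paddle.ones, dist_attr=dist_attr, shape=[2, 3])
print(d_tensor)
```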
Changed file 2: unit tests for distributed tensor creation
@@ -53,6 +53,56 @@ def test_dist_tensor_creation(self):
        self.assertEqual(dist_tensor_with_tensor.dist_attr, dist_attr)


class TestDistributedTensor(unittest.TestCase):

[Review] The test class name should match what it tests; rename it to TestDistTensorFromFn? [Author] done

    def test_dtensor_from_fn(self):

[Review] Static graph mode is still not covered; please add a test for it. [Author] Added a test that runs under static graph mode (enable_static()).

        # Create a distributed attribute
        mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
        dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])

        # Test with generate_tensor_ones()

[Review] The function this comment refers to no longer exists and the link to the code below is unclear; consider changing it, e.g. to "Test with paddle.ones". [Author] Decided to delete the mismatched comments instead.

        # Call the function dtensor_from_fn with dist_attr parameter
        result = dist.dtensor_from_fn(
            paddle.ones, dist_attr=dist_attr, shape=[2, 3]
        )

        # Verify the result
        self.assertIsInstance(result, paddle.Tensor)
        self.assertEqual(result.shape, [2, 3])
        self.assertEqual(result.dist_attr, dist_attr)

        # Test with generate_tensor_zeros()

[Review] This comment still needs to be removed. [Author] done

        result_zeros = dist.dtensor_from_fn(
            paddle.zeros, dist_attr=dist_attr, shape=[2, 3]
        )
        self.assertIsInstance(result_zeros, paddle.Tensor)
        self.assertEqual(result_zeros.shape, [2, 3])
        self.assertEqual(result_zeros.dist_attr, dist_attr)

        # Test with generate_tensor_random()

[Review] Same for this comment. [Author] done

        result_random = dist.dtensor_from_fn(
            paddle.rand, dist_attr=dist_attr, shape=[2, 3]
        )
        self.assertIsInstance(result_random, paddle.Tensor)
        self.assertEqual(result_random.shape, [2, 3])
        self.assertEqual(result_random.dist_attr, dist_attr)

[Review] Should a few exception cases be added to cover the error paths?
[Review] Should static graph mode also be tested?
[Review] Yes.
[Author] done, thx

        # Test with invalid sharding_specs length
        with self.assertRaises(AssertionError):
            invalid_dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x'])
            dist.dtensor_from_fn(
                paddle.ones, dist_attr=invalid_dist_attr, shape=[2, 3]
            )

        # Test exceptions when running in static mode
        paddle.enable_static()
        with self.assertRaises(NotImplementedError):
            with paddle.static.program_guard(paddle.static.Program()):
                dist.dtensor_from_fn(
                    paddle.ones, dist_attr=dist_attr, shape=[2, 3]
                )
        paddle.disable_static()


class TestDistTensorForDygraphAPI(unittest.TestCase):
    def check_tensor_eq(self, a, b):
        np1 = a.numpy()
[Review] There is a space between ``fn`` and ``(callable``; suggest also adding a space between ``dist_attr`` and its parenthesis for consistency. [Author] done
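Taken together, the docstring comments above ask for consistent spacing before the type parentheses and for describing ``*args``/``**kwargs`` as a tuple and a dict rather than lists. A revised signature section might look like the sketch below; it only illustrates the requested changes and is not necessarily the text that was finally committed.

```python
def dtensor_from_fn(fn, dist_attr, *args, **kwargs):
    """
    Construct a Distributed Tensor from a paddle api function and its arguments.

    Args:
        fn (callable): A paddle api function that takes arguments of *args,
            **kwargs and returns tensor.
        dist_attr (paddle.distributed.DistAttr): Specify how tensors are
            distributed or sliced on ProcessMesh.
        *args (tuple): A tuple of arguments passed to the ``fn`` function.
        **kwargs (dict): A dict of keyword arguments passed to the ``fn`` function.

    Returns:
        Tensor: A Tensor constructed from ``fn`` with distributed attributes.
    """
    tensor = fn(*args, **kwargs)
    return shard_tensor(tensor, dist_attr=dist_attr)
```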