[AutoParallel] Add paddle.distributed.dtensor_from_fn api #56565
First changed file (the auto-parallel API module that defines DistAttr and shard_tensor):
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.distributed.auto_parallel.interface import (
    shard_tensor as shard_tensor_static,
@@ -68,15 +69,13 @@ def __init__(self, mesh, sharding_specs):

        self.process_mesh = mesh
        self.dims_mapping = dims_mapping

        self.mark_annotated("process_mesh")
        self.mark_annotated("dims_mapping")

    @property
    def sharding_specs(self):
        """
        Get sharding_specs of the dist_attr

        Returns:
            list[str]: sharding_specs
        """
@@ -142,3 +141,32 @@ def shard_tensor(
    return shard_tensor_static(
        data, dist_attr.process_mesh, dist_attr.sharding_specs
    )


def dtensor_from_fn(fn, dist_attr, *args, **kwargs):
    """
    Construct a Distributed Tensor from a function of arguments.

    Args:
        fn (callable): A callable function that takes arguments of Distributed Tensor and returns tensor.
Review comment: This description is not quite accurate. Translated, it reads "a callable function that takes a Distributed Tensor as an argument and returns a tensor" — the function should not be taking a Distributed Tensor as an argument, should it?
Review reply: Change it to: fn (callable): A paddle api function that takes arguments of *args, **kwargs and returns tensor.
Review follow-up: Has this been updated here?
        dist_attr (paddle.distributed.DistAttr): Specify how tensors are distributed or sliced on ProcessMesh.
        *args (tuple): A tuple of arguments to be passed to the ``fn`` function.
        **kwargs (dict): A dict of arguments to be passed to the ``fn`` function.

    Returns:
        Tensor: A Tensor constructed from ``fn`` with distributed attributes.

    Examples:

        .. code-block:: python
            >>> import paddle
Review comment: Example code usually does not include the `>>>` prompt, does it?
Review reply: I checked with Zhong Kai; the current requirement is to include it.
Review comment (suggested change): Add a blank line after the code-block directive, otherwise the code block will not render in the preview.
Review reply: done, thx
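For clarity, the layout the reviewer is asking for (a blank line between the directive and the first prompt line) would look roughly like this:

```
    Examples:

        .. code-block:: python

            >>> import paddle
            >>> import paddle.distributed as dist
```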
            >>> import paddle.distributed as dist
            >>> # Create a distributed attribute
            >>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
            >>> dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None])
            >>> # Call the function dtensor_from_fn with dist_attr parameter
            >>> d_tensor = dist.dtensor_from_fn(paddle.ones, dist_attr=dist_attr, shape=[1])
            >>> print(d_tensor)
    """
    tensor = fn(*args, **kwargs)
    return shard_tensor(tensor, dist_attr=dist_attr)
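The body of `dtensor_from_fn` is a thin wrapper: it calls `fn(*args, **kwargs)` to build an ordinary tensor and then annotates it with `shard_tensor`. A minimal sketch of that equivalence, reusing the docstring example above — it assumes a two-rank launch so the `[0, 1]` mesh is valid, and that `shard_tensor` is exposed as `paddle.distributed.shard_tensor` with the `dist_attr` keyword used in this PR:

```python
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None])

# One step: build the tensor and shard it in a single call.
d_tensor = dist.dtensor_from_fn(paddle.ones, dist_attr=dist_attr, shape=[16])

# Equivalent two-step spelling of what the wrapper does internally.
dense = paddle.ones(shape=[16])
d_tensor_manual = dist.shard_tensor(dense, dist_attr=dist_attr)
```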
Second changed file (the unit test for distributed tensor creation):
@@ -53,6 +53,73 @@ def test_dist_tensor_creation(self):
        self.assertEqual(dist_tensor_with_tensor.dist_attr, dist_attr)


class TestDistTensorFromFn(unittest.TestCase):
    def run_dtensor_from_fn(self):
        # Create a distributed attribute
        mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
        dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None])
Review comment: Should we also add a test where sharding_specs is not None?
Review reply: There was originally a version with a non-None sharding_specs here, but Yurui later suggested changing it to None, so I changed it.
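For reference, a sketch of the non-None variant discussed in this thread — illustrative only, not part of the diff, and it assumes the same two-rank mesh so that sharding the single tensor dimension along "x" is valid:

```python
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

# Hypothetical variant: shard dimension 0 of the created tensor along mesh axis "x".
sharded_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x"])
sharded = dist.dtensor_from_fn(paddle.ones, dist_attr=sharded_attr, shape=[16])
# Conceptually, each of the two ranks then holds one shard of the global [16] tensor.
```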

        # Call the function dtensor_from_fn with dist_attr parameter
        result = dist.dtensor_from_fn(
            paddle.ones, dist_attr=dist_attr, shape=[16]
        )
        # Verify the result
        if paddle.in_dynamic_mode():
            dist_attr.dynamic_dims = []
            self.assertIsInstance(result, paddle.Tensor)
            self.assertEqual(result.shape, [16])
            self.assertEqual(result.dist_attr, dist_attr)
        else:
            dist_attr.dynamic_dims = [0]
            self.assertIsInstance(result, paddle.static.Variable)
            self.assertEqual(result.shape, (16,))
            self.assertEqual(result.dist_attr, dist_attr)
        result_zeros = dist.dtensor_from_fn(
            paddle.zeros, dist_attr=dist_attr, shape=[16]
        )
        if paddle.in_dynamic_mode():
            dist_attr.dynamic_dims = []
            self.assertIsInstance(result_zeros, paddle.Tensor)
            self.assertEqual(result_zeros.shape, [16])
            self.assertEqual(result_zeros.dist_attr, dist_attr)
        else:
            dist_attr.dynamic_dims = [0]
            self.assertIsInstance(result_zeros, paddle.static.Variable)
            self.assertEqual(result_zeros.shape, (16,))
            self.assertEqual(result_zeros.dist_attr, dist_attr)

        result_random = dist.dtensor_from_fn(
            paddle.rand, dist_attr=dist_attr, shape=[16]
        )
        if paddle.in_dynamic_mode():
            dist_attr.dynamic_dims = []
            self.assertIsInstance(result_random, paddle.Tensor)
            self.assertEqual(result_random.shape, [16])
            self.assertEqual(result_random.dist_attr, dist_attr)
        else:
            dist_attr.dynamic_dims = [0]
            self.assertIsInstance(result_random, paddle.static.Variable)
            self.assertEqual(result_random.shape, (16,))
            self.assertEqual(result_random.dist_attr, dist_attr)

        # Test with invalid sharding_specs length
        with self.assertRaises(AssertionError):
            invalid_dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x'])
            dist.dtensor_from_fn(
                paddle.ones, dist_attr=invalid_dist_attr, shape=[2, 3]
            )

    def test_dynamic_mode(self):
        self.run_dtensor_from_fn()

    # Test exceptions when running in static mode
    def test_static_mode(self):
        paddle.enable_static()
        self.run_dtensor_from_fn()
        paddle.disable_static()


class TestDistTensorForDygraphAPI(unittest.TestCase):
    def check_tensor_eq(self, a, b):
        np1 = a.numpy()
Review comment: Should the docstring emphasize that fn is a paddle api function, and that not just any arbitrary function will work?
Review reply: done, thx
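Picking up the reviewer's point, a hedged sketch of the intended call pattern: `fn` is expected to be a Paddle tensor-creation API (the tests above use `paddle.ones`, `paddle.zeros`, and `paddle.rand`), its own arguments are simply forwarded through `*args`/`**kwargs`, and it never receives a distributed tensor as input. `paddle.full` below is my own illustrative choice, not one used in the PR:

```python
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None])

# fn is a Paddle creation API; its arguments are forwarded unchanged.
d_rand = dist.dtensor_from_fn(paddle.rand, dist_attr=dist_attr, shape=[16])
d_full = dist.dtensor_from_fn(paddle.full, dist_attr=dist_attr, shape=[16], fill_value=2.0)
```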