
Conversion rule No. 153 #181

Merged — 16 commits, Sep 7, 2023
9 changes: 9 additions & 0 deletions paconvert/api_mapping.json
@@ -3968,6 +3968,15 @@
"other": "y"
}
},
"torch.distributed.all_gather_object": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distributed.all_gather_object",
"args_list": [
"object_list",
"obj",
"group"
]
},
"torch.distributed.broadcast": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distributed.broadcast",
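For illustration, a minimal sketch of what a GenericMatcher-style entry like the one above does: it maps the torch API name to its `paddle_api` and re-emits the keyword arguments declared in `args_list`. This is not PaConvert's actual implementation, just a simplified model of the mapping's effect.

```python
# Simplified model of a GenericMatcher mapping entry (not PaConvert's real code).
mapping = {
    "torch.distributed.all_gather_object": {
        "paddle_api": "paddle.distributed.all_gather_object",
        "args_list": ["object_list", "obj", "group"],
    },
}

def convert_call(torch_api: str, kwargs: dict) -> str:
    """Render the converted call string for a given torch API and its kwargs."""
    entry = mapping[torch_api]
    # Emit arguments in the order declared by args_list, skipping absent ones.
    rendered = ", ".join(
        f"{name}={kwargs[name]}" for name in entry["args_list"] if name in kwargs
    )
    return f"{entry['paddle_api']}({rendered})"

print(convert_call("torch.distributed.all_gather_object",
                   {"object_list": "output", "obj": "gather_objects[0]"}))
# → paddle.distributed.all_gather_object(object_list=output, obj=gather_objects[0])
```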
39 changes: 39 additions & 0 deletions tests/test_distributed_all_gather_object.py
@@ -0,0 +1,39 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.distributed.all_gather_object")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.distributed as dist
Collaborator @zhwesky2010 commented on Sep 4, 2023:
The CI environment only has CPUs; the unit test needs to check whether torch has CUDA available, otherwise it will not pass.

Contributor (author) replied:
Got it.
dist.init_process_group("nccl", init_method='tcp://127.0.0.1:23456', rank=1, world_size=3)
gather_objects = ["foo", 12, {1: 2}] # any picklable object
output = [None for _ in gather_objects]
dist.all_gather_object(output, gather_objects[dist.get_rank()])
result=True
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not support this function temporarily",
Collaborator commented:
torch.distributed.all_gather_object was just implemented above — this test should not be marked as unsupport.

Contributor (author) replied:
    dist.init_process_group("nccl", init_method='tcp://127.0.0.1:23456', rank=1, world_size=3)
This call cannot be run.

Collaborator @zhwesky2010 commented on Aug 28, 2023:
https://github.com/PaddlePaddle/PaConvert/blob/master/paconvert/api_mapping.json#L4022-L4045
This API is supported as well, so why does the converted output still carry a >>> marker? In principle everything should have been transcribed; the test cannot pass with unsupport=True.
)
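The reviewer's point about CPU-only CI could be addressed with a small backend guard: NCCL requires GPUs, while gloo runs on CPU. A minimal sketch of such a guard — the helper name `pick_backend` is hypothetical, not part of the PR:

```python
# Hypothetical helper for the reviewer's suggestion: choose a distributed
# backend that works in the current environment (NCCL needs GPUs, gloo
# runs on CPU-only machines such as the CI workers).
def pick_backend(cuda_available: bool) -> str:
    return "nccl" if cuda_available else "gloo"

# In the test one would then call, for example:
#   dist.init_process_group(pick_backend(torch.cuda.is_available()), ...)
print(pick_backend(False))  # → gloo
```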