-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Typing][C-21,C-22,C-23,C-24] Add type annotations for python/paddle/distributed/communication/{reduce_scatter.py, scatter.py, send.py, all_gather.py}
#66864
Changes from 7 commits
13e76d7
feb3ba9
f7dac7a
b4d048e
0e4d12c
8c7e446
bb40649
4c0c198
f488c7d
3d4abf3
9f1f05f
3832021
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -12,6 +12,14 @@ | |||||
# See the License for the specific language governing permissions and | ||||||
# limitations under the License. | ||||||
|
||||||
from __future__ import annotations | ||||||
|
||||||
from typing import TYPE_CHECKING, Sequence | ||||||
|
||||||
if TYPE_CHECKING: | ||||||
from paddle import Tensor | ||||||
from paddle.distributed.communication.group import Group | ||||||
|
||||||
import numpy as np | ||||||
|
||||||
import paddle | ||||||
|
@@ -25,7 +33,13 @@ | |||||
) | ||||||
|
||||||
|
||||||
def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True): | ||||||
def scatter( | ||||||
tensor: Tensor, | ||||||
tensor_list: Sequence[Tensor] | None = None, | ||||||
src: int = 0, | ||||||
group: Group | None = None, | ||||||
sync_op: bool = True, | ||||||
) -> None: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
""" | ||||||
|
||||||
Scatter a tensor to all participators. As shown below, one process is started with a GPU and the source of the scatter | ||||||
|
@@ -72,8 +86,11 @@ def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True): | |||||
|
||||||
|
||||||
def scatter_object_list( | ||||||
out_object_list, in_object_list=None, src=0, group=None | ||||||
): | ||||||
out_object_list: list[object], | ||||||
in_object_list: list[object] | None = None, | ||||||
Lans1ot marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
src: int = 0, | ||||||
group: Group | None = None, | ||||||
) -> None: | ||||||
""" | ||||||
|
||||||
Scatter picklable objects from the source to all others. Similiar to scatter(), but python object can be passed in. | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -12,10 +12,24 @@ | |||||
# See the License for the specific language governing permissions and | ||||||
# limitations under the License. | ||||||
|
||||||
from __future__ import annotations | ||||||
|
||||||
from typing import TYPE_CHECKING | ||||||
|
||||||
from paddle.distributed.communication import stream | ||||||
|
||||||
if TYPE_CHECKING: | ||||||
from paddle import Tensor | ||||||
from paddle.base.core import task | ||||||
from paddle.distributed.communication.group import Group | ||||||
|
||||||
|
||||||
def send(tensor, dst=0, group=None, sync_op=True): | ||||||
def send( | ||||||
tensor: Tensor, | ||||||
dst: int = 0, | ||||||
group: Group | None = None, | ||||||
sync_op: bool = True, | ||||||
) -> task: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
""" | ||||||
Send a tensor to the receiver. | ||||||
|
||||||
|
@@ -51,7 +65,7 @@ def send(tensor, dst=0, group=None, sync_op=True): | |||||
) | ||||||
|
||||||
|
||||||
def isend(tensor, dst, group=None): | ||||||
def isend(tensor: Tensor, dst: int, group: Group | None = None) -> task: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
""" | ||||||
Send tensor asynchronously | ||||||
|
||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -12,16 +12,29 @@ | |||||
# See the License for the specific language governing permissions and | ||||||
# limitations under the License. | ||||||
|
||||||
from __future__ import annotations | ||||||
|
||||||
from typing import TYPE_CHECKING | ||||||
|
||||||
import paddle | ||||||
import paddle.distributed as dist | ||||||
from paddle import framework | ||||||
from paddle.base import data_feeder | ||||||
from paddle.distributed.communication.group import _get_global_group | ||||||
|
||||||
if TYPE_CHECKING: | ||||||
from paddle import Tensor | ||||||
from paddle.base.core import task | ||||||
from paddle.distributed.communication.group import Group | ||||||
|
||||||
|
||||||
def _all_gather_into_tensor_in_dygraph( | ||||||
out_tensor, in_tensor, group, sync_op, use_calc_stream | ||||||
): | ||||||
out_tensor: Tensor, | ||||||
in_tensor: Tensor, | ||||||
group: Group, | ||||||
sync_op: bool, | ||||||
use_calc_stream: bool, | ||||||
) -> task: | ||||||
group = _get_global_group() if group is None else group | ||||||
|
||||||
if use_calc_stream: | ||||||
|
@@ -40,8 +53,12 @@ def _all_gather_into_tensor_in_dygraph( | |||||
|
||||||
|
||||||
def _all_gather_in_dygraph( | ||||||
tensor_list, tensor, group, sync_op, use_calc_stream | ||||||
): | ||||||
tensor_list: list[Tensor], | ||||||
tensor: Tensor, | ||||||
group: Group, | ||||||
sync_op: bool, | ||||||
use_calc_stream: bool, | ||||||
) -> task: | ||||||
group = _get_global_group() if group is None else group | ||||||
|
||||||
if len(tensor_list) == 0: | ||||||
|
@@ -59,7 +76,9 @@ def _all_gather_in_dygraph( | |||||
return task | ||||||
|
||||||
|
||||||
def _all_gather_in_static_mode(tensor_list, tensor, group, sync_op): | ||||||
def _all_gather_in_static_mode( | ||||||
tensor_list: list[Tensor], tensor: Tensor, group: Group, sync_op: bool | ||||||
) -> task: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
op_type = 'all_gather' | ||||||
helper = framework.LayerHelper(op_type, **locals()) | ||||||
out = helper.create_variable_for_type_inference(dtype=tensor.dtype) | ||||||
|
@@ -121,12 +140,12 @@ def _all_gather_in_static_mode(tensor_list, tensor, group, sync_op): | |||||
|
||||||
|
||||||
def all_gather( | ||||||
tensor_or_tensor_list, | ||||||
tensor, | ||||||
group=None, | ||||||
sync_op=True, | ||||||
use_calc_stream=False, | ||||||
): | ||||||
tensor_or_tensor_list: Tensor | list[Tensor], | ||||||
tensor: Tensor, | ||||||
group: Group | None = None, | ||||||
sync_op: bool = True, | ||||||
use_calc_stream: bool = False, | ||||||
) -> task: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
""" | ||||||
|
||||||
Gather tensors across devices to a correctly-sized tensor or a tensor list. | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里还有下面的检查代码是为了解决什么呢?如果没什么特别需要解决的问题最好还是不要加,因为本任务只涉及类型提示修改,不涉及运行时逻辑修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这是因为我之前做其他任务的时候,有类似的情况。
就是参数只能在特定值中间进行选择,不然就会抛出一个错误
所以我这里就根据
python/paddle/distributed/communication/batch_isend_irecv.py
仿写了一段