From ce91dc737b542440cf717c81d7da5a3c233ead35 Mon Sep 17 00:00:00 2001 From: Lans1ot <47025645+Lans1ot@users.noreply.github.com> Date: Wed, 31 Jul 2024 09:20:41 +0800 Subject: [PATCH] [Typing][C-19] Add type annotations for `python/paddle/distributed/communication/recv.py` (#66694) --- .../communication/batch_isend_irecv.py | 8 +++++--- .../paddle/distributed/communication/recv.py | 20 +++++++++++++++++-- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/python/paddle/distributed/communication/batch_isend_irecv.py b/python/paddle/distributed/communication/batch_isend_irecv.py index 5a54c3d9efd2b..d1d2258b74653 100644 --- a/python/paddle/distributed/communication/batch_isend_irecv.py +++ b/python/paddle/distributed/communication/batch_isend_irecv.py @@ -14,7 +14,7 @@ from __future__ import annotations import contextlib -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, Generator, Sequence import paddle.distributed as dist from paddle import framework @@ -96,7 +96,9 @@ def __init__( @contextlib.contextmanager -def _coalescing_manager(group, tasks=None): +def _coalescing_manager( + group: Group, tasks: task | None = None +) -> Generator[None, None, None]: group = _get_global_group() if group is None else group pg = group.process_group pg._start_coalescing() @@ -109,7 +111,7 @@ def _coalescing_manager(group, tasks=None): pg._end_coalescing(tasks) -def _check_p2p_op_list(p2p_op_list): +def _check_p2p_op_list(p2p_op_list: Sequence[P2POp]) -> None: """ Helper to check that the ``p2p_op_list`` is a list of P2POp instances and all ops use the same backend. diff --git a/python/paddle/distributed/communication/recv.py b/python/paddle/distributed/communication/recv.py index e7e0315b7dd51..4952dfd10a8c9 100644 --- a/python/paddle/distributed/communication/recv.py +++ b/python/paddle/distributed/communication/recv.py @@ -12,10 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + from paddle.distributed.communication import stream +if TYPE_CHECKING: + from paddle import Tensor + from paddle.base.core import task + from paddle.distributed.communication.group import Group + -def recv(tensor, src=0, group=None, sync_op=True): +def recv( + tensor: Tensor, + src: int = 0, + group: Group | None = None, + sync_op: bool = True, +) -> task: """ Receive a tensor to the sender. @@ -51,7 +65,9 @@ def recv(tensor, src=0, group=None, sync_op=True): ) -def irecv(tensor, src=None, group=None): +def irecv( + tensor: Tensor, src: int | None = None, group: Group | None = None +) -> task: """ Receive a tensor to the sender.