Skip to content

Commit

Permalink
[Typing][C-19] Add type annotations for `python/paddle/distributed/co…
Browse files Browse the repository at this point in the history
…mmunication/recv.py` (#66694)
  • Loading branch information
Lans1ot authored Jul 31, 2024
1 parent 8ddfb0a commit 220ffec
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
8 changes: 5 additions & 3 deletions python/paddle/distributed/communication/batch_isend_irecv.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Callable, Generator, Sequence

import paddle.distributed as dist
from paddle import framework
Expand Down Expand Up @@ -96,7 +96,9 @@ def __init__(


@contextlib.contextmanager
def _coalescing_manager(group, tasks=None):
def _coalescing_manager(
group: Group, tasks: task | None = None
) -> Generator[None, None, None]:
group = _get_global_group() if group is None else group
pg = group.process_group
pg._start_coalescing()
Expand All @@ -109,7 +111,7 @@ def _coalescing_manager(group, tasks=None):
pg._end_coalescing(tasks)


def _check_p2p_op_list(p2p_op_list):
def _check_p2p_op_list(p2p_op_list: Sequence[P2POp]) -> None:
"""
Helper to check that the ``p2p_op_list`` is a list of P2POp instances and
all ops use the same backend.
Expand Down
20 changes: 18 additions & 2 deletions python/paddle/distributed/communication/recv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

from paddle.distributed.communication import stream

if TYPE_CHECKING:
from paddle import Tensor
from paddle.base.core import task
from paddle.distributed.communication.group import Group


def recv(tensor, src=0, group=None, sync_op=True):
def recv(
tensor: Tensor,
src: int = 0,
group: Group | None = None,
sync_op: bool = True,
) -> task:
"""
Receive a tensor to the sender.
Expand Down Expand Up @@ -51,7 +65,9 @@ def recv(tensor, src=0, group=None, sync_op=True):
)


def irecv(tensor, src=None, group=None):
def irecv(
tensor: Tensor, src: int | None = None, group: Group | None = None
) -> task:
"""
Receive a tensor to the sender.
Expand Down

0 comments on commit 220ffec

Please sign in to comment.