diff --git a/python/paddle/base/core.pyi b/python/paddle/base/core.pyi index 12a1620540a5d0..499e1c87fde35a 100644 --- a/python/paddle/base/core.pyi +++ b/python/paddle/base/core.pyi @@ -96,3 +96,9 @@ class iinfo: def is_compiled_with_cuda() -> bool: ... def set_nan_inf_debug_path(arg0: str) -> None: ... + +class task: + def is_completed(self) -> bool: ... + def is_sync(self) -> bool: ... + def synchronize(self) -> None: ... + def wait(self, timeout: int = 0) -> bool: ... diff --git a/python/paddle/distributed/communication/all_gather.py b/python/paddle/distributed/communication/all_gather.py index ab7c074a16d1e0..01a486f05d808d 100644 --- a/python/paddle/distributed/communication/all_gather.py +++ b/python/paddle/distributed/communication/all_gather.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeVar + import numpy as np import paddle @@ -23,8 +27,20 @@ convert_tensor_to_object, ) +if TYPE_CHECKING: + from paddle import Tensor + from paddle.base.core import task + from paddle.distributed.communication.group import Group + + _T = TypeVar("_T") + -def all_gather(tensor_list, tensor, group=None, sync_op=True): +def all_gather( + tensor_list: list[Tensor], + tensor: Tensor, + group: Group | None = None, + sync_op: bool = True, +) -> task | None: """ Gather tensors from all participators and all get the result. As shown @@ -42,7 +58,7 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True): should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128. tensor (Tensor): The Tensor to send. Its data type should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128. - group (Group, optional): The group instance return by new_group or None for global default group. + group (Group|None, optional): The group instance return by new_group or None for global default group. sync_op (bool, optional): Whether this op is a sync op. The default value is True. Returns: @@ -68,7 +84,9 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True): return stream.all_gather(tensor_list, tensor, group, sync_op) -def all_gather_object(object_list, obj, group=None): +def all_gather_object( + object_list: list[_T], obj: _T, group: Group = None +) -> None: """ Gather picklable objects from all participators and all get the result. Similar to all_gather(), but python object can be passed in.