Skip to content

Commit

Permalink
[Typing][C-12] Add type annotations for `python/paddle/distributed/co…
Browse files Browse the repository at this point in the history
…mmunication/all_gather.py` (PaddlePaddle#66051)
  • Loading branch information
megemini authored and lixcli committed Jul 22, 2024
1 parent a88809b commit b022ab4
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
6 changes: 6 additions & 0 deletions python/paddle/base/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,9 @@ class iinfo:

def is_compiled_with_cuda() -> bool: ...
def set_nan_inf_debug_path(arg0: str) -> None: ...

class task:
def is_completed(self) -> bool: ...
def is_sync(self) -> bool: ...
def synchronize(self) -> None: ...
def wait(self, timeout: int = 0) -> bool: ...
24 changes: 21 additions & 3 deletions python/paddle/distributed/communication/all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, TypeVar

import numpy as np

import paddle
Expand All @@ -23,8 +27,20 @@
convert_tensor_to_object,
)

if TYPE_CHECKING:
from paddle import Tensor
from paddle.base.core import task
from paddle.distributed.communication.group import Group

_T = TypeVar("_T")


def all_gather(tensor_list, tensor, group=None, sync_op=True):
def all_gather(
tensor_list: list[Tensor],
tensor: Tensor,
group: Group | None = None,
sync_op: bool = True,
) -> task | None:
"""
Gather tensors from all participators and all get the result. As shown
Expand All @@ -42,7 +58,7 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True):
should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
group (Group, optional): The group instance return by new_group or None for global default group.
group (Group|None, optional): The group instance return by new_group or None for global default group.
sync_op (bool, optional): Whether this op is a sync op. The default value is True.
Returns:
Expand All @@ -68,7 +84,9 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True):
return stream.all_gather(tensor_list, tensor, group, sync_op)


def all_gather_object(object_list, obj, group=None):
def all_gather_object(
object_list: list[_T], obj: _T, group: Group = None
) -> None:
"""
Gather picklable objects from all participators and all get the result. Similar to all_gather(), but python object can be passed in.
Expand Down

0 comments on commit b022ab4

Please sign in to comment.