[Typing][C-21,C-22,C-23,C-24] Add type annotations for python/paddle/distributed/communication/{reduce_scatter.py, scatter.py, send.py, all_gather.py} #66864

Merged · 12 commits · Aug 6, 2024
40 changes: 36 additions & 4 deletions python/paddle/distributed/communication/reduce_scatter.py
@@ -12,17 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

import paddle
from paddle.distributed.communication import stream
from paddle.distributed.communication.reduce import ReduceOp
from paddle.distributed.communication.stream.reduce_scatter import (
_reduce_scatter_base as _reduce_scatter_base_stream,
)

if TYPE_CHECKING:
from paddle import Tensor
from paddle.base.core import task
from paddle.distributed.communication.group import Group
from paddle.distributed.communication.reduce import _ReduceOp


def reduce_scatter(
tensor, tensor_list, op=ReduceOp.SUM, group=None, sync_op=True
):
tensor: Tensor,
tensor_list: list[Tensor],
op: _ReduceOp = ReduceOp.SUM,
group: Group | None = None,
sync_op: bool = True,
) -> task:
"""
Reduces, then scatters a list of tensors to all processes in a group

@@ -62,6 +76,16 @@ def reduce_scatter(
>>> # [8, 10] (2 GPUs, out for rank 1)

"""
if op not in [
ReduceOp.AVG,
ReduceOp.MAX,
ReduceOp.MIN,
ReduceOp.PROD,
ReduceOp.SUM,
]:
raise RuntimeError(
"Invalid ``op`` function. Expected ``op`` to be of type ``ReduceOp.SUM``, ``ReduceOp.Max``, ``ReduceOp.MIN``, ``ReduceOp.PROD`` or ``ReduceOp.AVG``."
)
Comment on lines +79 to +88

Member:
What are this check and the one below meant to solve? If there is no specific problem to fix, it is better not to add them, since this task only covers type-hint changes, not runtime-logic changes.

Contributor (Author):
I ran into a similar situation in an earlier task: the parameter may only be chosen from a fixed set of values, otherwise an error is raised. So I modeled this snippet on python/paddle/distributed/communication/batch_isend_irecv.py.

# AVG is only supported when nccl >= 2.10
if op == ReduceOp.AVG and paddle.base.core.nccl_version() < 21000:
group = (
@@ -89,8 +113,12 @@ def reduce_scatter(


def _reduce_scatter_base(
output, input, op=ReduceOp.SUM, group=None, sync_op=True
):
output: Tensor,
input: Tensor,
op: _ReduceOp = ReduceOp.SUM,
group: Group | None = None,
sync_op: bool = True,
) -> task | None:
"""
Reduces, then scatters a flattened tensor to all processes in a group.

@@ -126,6 +154,10 @@ def _reduce_scatter_base(
>>> # [5, 7] (2 GPUs, out for rank 1)

"""
if op not in [ReduceOp.MAX, ReduceOp.MIN, ReduceOp.PROD, ReduceOp.SUM]:
raise RuntimeError(
"Invalid ``op`` function. Expected ``op`` to be of type ``ReduceOp.SUM``, ``ReduceOp.Max``, ``ReduceOp.MIN`` or ``ReduceOp.PROD``."
)
return _reduce_scatter_base_stream(
output,
input,
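For context on the annotated signature, a minimal usage sketch of reduce_scatter on two GPUs follows. It mirrors the docstring fragment above ([8, 10] for rank 1); the launch setup and tensor values are illustrative assumptions, not part of this diff.

# Minimal sketch, assuming a 2-GPU launch via paddle.distributed.launch.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
if dist.get_rank() == 0:
    data1 = paddle.to_tensor([0, 1])
    data2 = paddle.to_tensor([2, 3])
else:
    data1 = paddle.to_tensor([4, 5])
    data2 = paddle.to_tensor([6, 7])
# Element-wise sum across ranks, then scatter one reduced chunk per rank:
# rank 0 keeps the reduction of the first list entry, rank 1 the second.
dist.reduce_scatter(data1, [data1, data2])
print(data1)
# rank 0: [4, 6]; rank 1: [8, 10]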
23 changes: 20 additions & 3 deletions python/paddle/distributed/communication/scatter.py
@@ -12,6 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Sequence

if TYPE_CHECKING:
from paddle import Tensor
from paddle.distributed.communication.group import Group

import numpy as np

import paddle
@@ -25,7 +33,13 @@
)


def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):
def scatter(
tensor: Tensor,
tensor_list: Sequence[Tensor] | None = None,
src: int = 0,
group: Group | None = None,
sync_op: bool = True,
) -> None:
Member:
Suggested change:
) -> None:
) -> task | None:

Remember to import ``task``.

"""

Scatter a tensor to all participators. As shown below, one process is started with a GPU and the source of the scatter
@@ -72,8 +86,11 @@ def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):


def scatter_object_list(
out_object_list, in_object_list=None, src=0, group=None
):
out_object_list: list[object],
in_object_list: list[object] | None = None,
src: int = 0,
group: Group | None = None,
) -> None:
"""

Scatter picklable objects from the source to all others. Similar to scatter(), but Python objects can be passed in.
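For the annotated scatter_object_list signature, a rough two-rank sketch follows; the payload objects are made up for illustration, and only the source rank's in_object_list is read (hence the None default).

# Sketch, assuming a 2-GPU run: scatter picklable objects from rank 0.
import paddle.distributed as dist

dist.init_parallel_env()
out_object_list = []  # filled in place with the object routed to this rank
if dist.get_rank() == 0:
    in_object_list = [{'foo': [1, 2, 3]}, {'bar': 4}]
else:
    in_object_list = None  # non-source ranks do not need to provide input
dist.scatter_object_list(out_object_list, in_object_list, src=0)
print(out_object_list)
# rank 0: [{'foo': [1, 2, 3]}]; rank 1: [{'bar': 4}]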
18 changes: 16 additions & 2 deletions python/paddle/distributed/communication/send.py
@@ -12,10 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

from paddle.distributed.communication import stream

if TYPE_CHECKING:
from paddle import Tensor
from paddle.base.core import task
from paddle.distributed.communication.group import Group


def send(tensor, dst=0, group=None, sync_op=True):
def send(
tensor: Tensor,
dst: int = 0,
group: Group | None = None,
sync_op: bool = True,
) -> task:
Member:
Suggested change:
) -> task:
) -> task | None:

"""
Send a tensor to the receiver.

@@ -51,7 +65,7 @@ def send(tensor, dst=0, group=None, sync_op=True):
)


def isend(tensor, dst, group=None):
def isend(tensor: Tensor, dst: int, group: Group | None = None) -> task:
Member:
Suggested change:
def isend(tensor: Tensor, dst: int, group: Group | None = None) -> task:
def isend(tensor: Tensor, dst: int, group: Group | None = None) -> task | None:

"""
Send tensor asynchronously

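To illustrate the reviewer's point that the return annotation should allow None, here is a rough point-to-point sketch pairing isend with irecv; the ranks and values are assumptions, and the guard reflects that some execution paths return no task object.

# Sketch, assuming a 2-GPU run: asynchronous send/receive with task handling.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
if dist.get_rank() == 0:
    data = paddle.to_tensor([7, 8, 9])
    task = dist.isend(data, dst=1)
else:
    data = paddle.to_tensor([0, 0, 0])
    task = dist.irecv(data, src=0)
if task is not None:  # some modes return None instead of a waitable task
    task.wait()
print(data)  # both ranks end up holding [7, 8, 9]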
41 changes: 30 additions & 11 deletions python/paddle/distributed/communication/stream/all_gather.py
@@ -12,16 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

import paddle
import paddle.distributed as dist
from paddle import framework
from paddle.base import data_feeder
from paddle.distributed.communication.group import _get_global_group

if TYPE_CHECKING:
from paddle import Tensor
from paddle.base.core import task
from paddle.distributed.communication.group import Group


def _all_gather_into_tensor_in_dygraph(
out_tensor, in_tensor, group, sync_op, use_calc_stream
):
out_tensor: Tensor,
in_tensor: Tensor,
group: Group,
sync_op: bool,
use_calc_stream: bool,
) -> task:
group = _get_global_group() if group is None else group

if use_calc_stream:
@@ -40,8 +53,12 @@


def _all_gather_in_dygraph(
tensor_list, tensor, group, sync_op, use_calc_stream
):
tensor_list: list[Tensor],
tensor: Tensor,
group: Group,
sync_op: bool,
use_calc_stream: bool,
) -> task:
group = _get_global_group() if group is None else group

if len(tensor_list) == 0:
@@ -59,7 +76,9 @@
return task


def _all_gather_in_static_mode(tensor_list, tensor, group, sync_op):
def _all_gather_in_static_mode(
tensor_list: list[Tensor], tensor: Tensor, group: Group, sync_op: bool
) -> task:
Member:
Suggested change:
) -> task:
) -> None:

op_type = 'all_gather'
helper = framework.LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
@@ -121,12 +140,12 @@ def _all_gather_in_static_mode(tensor_list, tensor, group, sync_op):


def all_gather(
tensor_or_tensor_list,
tensor,
group=None,
sync_op=True,
use_calc_stream=False,
):
tensor_or_tensor_list: Tensor | list[Tensor],
tensor: Tensor,
group: Group | None = None,
sync_op: bool = True,
use_calc_stream: bool = False,
) -> task:
Member:
Suggested change:
) -> task:
) -> task | None:

"""

Gather tensors across devices to a correctly-sized tensor or a tensor list.
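Finally, a minimal sketch of the stream-level all_gather whose signature is annotated above; note that the output argument (tensor_or_tensor_list) comes first. The tensor values and the sync_op=False choice are illustrative assumptions.

# Sketch, assuming a 2-GPU run: gather a tensor from every rank into a list.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
tensor_list = []
if dist.get_rank() == 0:
    data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
    data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
task = dist.stream.all_gather(tensor_list, data, sync_op=False)
if task is not None:  # matches the suggested ``task | None`` annotation
    task.wait()
print(tensor_list)  # on both ranks: [rank-0 tensor, rank-1 tensor]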