
Commit

[Typing][C-20] Add type annotations for `python/paddle/distributed/communication/reduce.py` (PaddlePaddle#66803)
megemini authored and Lans1ot committed Aug 5, 2024
1 parent 65bb318 commit 589c1ce
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions python/paddle/distributed/communication/reduce.py
@@ -12,10 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar, Literal

import paddle
from paddle import framework
from paddle.distributed.communication import stream

if TYPE_CHECKING:
from typing_extensions import TypeAlias

from paddle import Tensor
from paddle.base.core import task
from paddle.distributed.communication.group import Group

_ReduceOp: TypeAlias = Literal[0, 1, 2, 3, 4]


class ReduceOp:
"""
@@ -48,11 +61,11 @@ class ReduceOp:
>>> # [[5, 7, 9], [5, 7, 9]] (2 GPUs)
"""

SUM = 0
MAX = 1
MIN = 2
PROD = 3
AVG = 4
SUM: ClassVar[Literal[0]] = 0
MAX: ClassVar[Literal[1]] = 1
MIN: ClassVar[Literal[2]] = 2
PROD: ClassVar[Literal[3]] = 3
AVG: ClassVar[Literal[4]] = 4


def _get_reduce_op(reduce_op, func_name):
@@ -86,7 +99,13 @@ def _to_inplace_op(op_name):
return f"{op_name}_"


def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
def reduce(
tensor: Tensor,
dst: int,
op: _ReduceOp = ReduceOp.SUM,
group: Group | None = None,
sync_op: bool = True,
) -> task:
"""
Reduce a tensor to the destination from all others. As shown below, one process is started with a GPU and the data of this process is represented
@@ -103,7 +122,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
dst (int): The destination rank id.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD|ReduceOp.AVG, optional): The operation used. Default value is ReduceOp.SUM.
group (Group, optional): The group instance returned by new_group or None for the global default group.
group (Group|None, optional): The group instance returned by new_group or None for the global default group.
sync_op (bool, optional): Whether this op is a sync op. The default value is True.
Returns:
@@ -207,7 +226,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
raise ValueError(f"Unknown parameter: {op}.")


def is_avg_reduce_op_supported():
def is_avg_reduce_op_supported() -> bool:
if paddle.is_compiled_with_cuda():
return paddle.base.core.nccl_version() >= 21000
else:
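For readers unfamiliar with the pattern the diff applies, here is a minimal, self-contained sketch of the same idea: typing-only imports live behind `TYPE_CHECKING`, the accepted op codes are gathered into a `Literal` alias, and the class constants are marked `ClassVar`. The `_check_op` helper below is hypothetical and only illustrates how a static checker consumes the alias.

```python
from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar, Literal

if TYPE_CHECKING:
    # Evaluated only by static type checkers, so typing_extensions is not
    # a run-time dependency.
    from typing_extensions import TypeAlias

    _ReduceOp: TypeAlias = Literal[0, 1, 2, 3, 4]


class ReduceOp:
    # ClassVar marks these as class-level constants rather than instance fields.
    SUM: ClassVar[Literal[0]] = 0
    MAX: ClassVar[Literal[1]] = 1
    MIN: ClassVar[Literal[2]] = 2
    PROD: ClassVar[Literal[3]] = 3
    AVG: ClassVar[Literal[4]] = 4


def _check_op(op: _ReduceOp = ReduceOp.SUM) -> None:
    # A checker accepts _check_op(ReduceOp.MAX) but flags _check_op(7); at run
    # time the annotation stays a string thanks to the __future__ import.
    pass
```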

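A hedged usage sketch of the annotated API follows, mirroring the two-GPU example referenced in the docstring; it assumes a multi-process launch (for example via `paddle.distributed.launch`) and two visible GPUs, so the exact output depends on the environment.

```python
# Run under a collective launcher, e.g.:
#   python -m paddle.distributed.launch --gpus=0,1 reduce_demo.py
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
if dist.get_rank() == 0:
    data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
    data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])

# Sum the tensors from all ranks into rank 0; other ranks keep their input.
dist.reduce(data, dst=0, op=dist.ReduceOp.SUM, sync_op=True)
print(data)
# Expected on rank 0 with 2 GPUs: [[5, 7, 9], [5, 7, 9]]
```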