[Typing][C-25,C-26] Add type annotations for python/paddle/distributed/communication/stream/{all_reduce,all_to_all}.py #67112

Merged: 2 commits, Aug 13, 2024
38 changes: 32 additions & 6 deletions python/paddle/distributed/communication/stream/all_reduce.py
@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from paddle import _C_ops, framework
from paddle.base import data_feeder
@@ -25,8 +28,21 @@
)
from paddle.framework import in_pir_mode

if TYPE_CHECKING:
from paddle import Tensor
from paddle.base.core import task
from paddle.distributed.communication.group import Group

from ..all_reduce import _ReduceOp


def _all_reduce_in_dygraph(tensor, op, group, sync_op, use_calc_stream):
def _all_reduce_in_dygraph(
tensor: Tensor,
op: _ReduceOp,
group: Group,
sync_op: bool,
use_calc_stream: bool,
) -> task:
op_type = _get_reduce_op(op, "allreduce")

if use_calc_stream:
@@ -39,7 +55,13 @@ def _all_reduce_in_dygraph(tensor, op, group, sync_op, use_calc_stream):
return task


def _all_reduce_in_static_mode(tensor, op, group, sync_op, use_calc_stream):
def _all_reduce_in_static_mode(
tensor: Tensor,
op: _ReduceOp,
group: Group,
sync_op: bool,
use_calc_stream: bool,
) -> None:
data_feeder.check_variable_and_dtype(
tensor,
'tensor',
@@ -80,8 +102,12 @@ def _all_reduce_in_static_mode(tensor, op, group, sync_op, use_calc_stream):


def all_reduce(
tensor, op=ReduceOp.SUM, group=None, sync_op=True, use_calc_stream=False
):
tensor: Tensor,
op: _ReduceOp = ReduceOp.SUM,
group: Group | None = None,
sync_op: bool = True,
use_calc_stream: bool = False,
) -> task | None:
"""

Perform a specific reduction (for example, sum or max) on inputs across devices.
@@ -90,7 +116,7 @@
tensor (Tensor): The input tensor on each rank. The result will overwrite this tensor after communication. Support
float16, float32, float64, int32, int64, int8, uint8 or bool as the input data type.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The reduction used. If none is given, use ReduceOp.SUM as default.
group (Group, optional): Communicate in which group. If none is given, use the global group as default.
group (Group|None, optional): Communicate in which group. If none is given, use the global group as default.
sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default.
use_calc_stream (bool, optional): Indicate whether the communication is done on calculation stream. If none is given, use false as default. This
option is designed for high-performance workloads; be careful about turning it on unless you clearly understand its meaning.
@@ -113,7 +139,7 @@
>>> else:
... data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
>>> task = dist.stream.all_reduce(data, sync_op=False)
>>> task.wait()
>>> task.wait() # type: ignore[union-attr]
>>> out = data
>>> print(out)
[[5, 7, 9], [5, 7, 9]]
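
A note on the pattern: both files add "from __future__ import annotations" and pull Tensor, task, and Group in under an "if TYPE_CHECKING:" guard, so the names are visible to static checkers but the imports never run at runtime (the annotations stay plain strings under PEP 563, avoiding import cost and circular imports). A minimal sketch of that pattern, with a hypothetical scale_on_group function used purely for illustration:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy, pyright); never executed
    # at runtime, so there is no import cost and no circular-import risk.
    from paddle import Tensor
    from paddle.distributed.communication.group import Group


def scale_on_group(tensor: Tensor, group: Group | None = None) -> Tensor:
    # With `from __future__ import annotations`, the annotations above stay
    # plain strings at runtime, so the guarded imports are sufficient.
    return tensor * 2
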
96 changes: 57 additions & 39 deletions python/paddle/distributed/communication/stream/all_to_all.py
@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

import paddle
import paddle.distributed as dist
@@ -21,10 +24,21 @@
_warn_cur_rank_not_in_group,
)

if TYPE_CHECKING:
from collections.abc import Sequence

from paddle import Tensor
from paddle.base.core import task
from paddle.distributed.communication.group import Group


def _all_to_all_tensor_in_dygraph(
out_tensor, in_tensor, group, sync_op, use_calc_stream
):
out_tensor: Tensor,
in_tensor: Tensor,
group: Group,
sync_op: bool,
use_calc_stream: bool,
) -> task:
if use_calc_stream:
return group.process_group.all_to_all_tensor_on_calc_stream(
in_tensor, out_tensor
@@ -38,8 +52,12 @@ def _all_to_all_tensor_in_dygraph(


def _all_to_all_in_dygraph(
out_tensor_list, in_tensor_list, group, sync_op, use_calc_stream
):
out_tensor_list: list[Tensor],
in_tensor_list: list[Tensor],
group: Group,
sync_op: bool,
use_calc_stream: bool,
) -> task:
if len(in_tensor_list) == 0:
raise RuntimeError("The input tensor_list should not be empty.")

@@ -63,12 +81,12 @@


def _all_to_all_in_static_mode(
out_tensor_or_tensor_list,
in_tensor_or_tensor_list,
group,
sync_op,
use_calc_stream,
):
out_tensor_or_tensor_list: Tensor | Sequence[Tensor],
in_tensor_or_tensor_list: Tensor | Sequence[Tensor],
group: Group,
sync_op: bool,
use_calc_stream: bool,
) -> None:
op_type = 'alltoall'
ring_id = 0 if group is None else group.id
nranks = dist.get_world_size()
@@ -125,12 +143,12 @@ def _all_to_all_in_static_mode(


def alltoall(
out_tensor_or_tensor_list,
in_tensor_or_tensor_list,
group=None,
sync_op=True,
use_calc_stream=False,
):
out_tensor_or_tensor_list: Tensor | Sequence[Tensor],
in_tensor_or_tensor_list: Tensor | Sequence[Tensor],
group: Group | None = None,
sync_op: bool = True,
use_calc_stream: bool = False,
) -> task | None:
"""

Scatter a tensor (or a tensor list) across devices and gather outputs to another tensor (or a tensor list, respectively).
@@ -141,7 +159,7 @@
in_tensor_or_tensor_list (Union[Tensor, List[Tensor]]): The input to scatter (must be specified on the source rank).
If it is a tensor, it should be correctly-sized. If it is a list, it should contain correctly-sized tensors. Support
float16, float32, float64, int32, int64, int8, uint8 or bool as the input data type.
group (Group, optional): Communicate in which group. If none is given, use the global group as default.
group (Group|None, optional): Communicate in which group. If none is given, use the global group as default.
sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default.
use_calc_stream (bool, optional): Indicate whether the communication is done on calculation stream. If none is given, use false as default. This
option is designed for high-performance workloads; be careful about turning it on unless you clearly understand its meaning.
@@ -165,7 +183,7 @@
... data1 = paddle.to_tensor([[13, 14, 15], [16, 17, 18]])
... data2 = paddle.to_tensor([[19, 20, 21], [22, 23, 24]])
>>> task = dist.stream.alltoall(out_tensor_list, [data1, data2], sync_op=False)
>>> task.wait()
>>> task.wait() # type: ignore[union-attr]
>>> print(out_tensor_list)
>>> # [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]] (2 GPUs, out for rank 0)
>>> # [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]] (2 GPUs, out for rank 1)
@@ -221,14 +239,14 @@ def alltoall(


def _alltoall_single_in_dygraph(
out_tensor,
in_tensor,
out_split_sizes,
in_split_sizes,
group,
sync_op,
use_calc_stream,
):
out_tensor: Tensor,
in_tensor: Tensor,
out_split_sizes: list[int] | None,
in_split_sizes: list[int] | None,
group: Group,
sync_op: bool,
use_calc_stream: bool,
) -> task:
world_size = dist.get_world_size(group)
if out_split_sizes is None:
out_split_sizes = [
@@ -254,26 +272,26 @@ def _alltoall_single_in_dygraph(


def alltoall_single(
out_tensor,
in_tensor,
out_split_sizes=None,
in_split_sizes=None,
group=None,
sync_op=True,
use_calc_stream=False,
):
out_tensor: Tensor,
in_tensor: Tensor,
out_split_sizes: list[int] | None = None,
in_split_sizes: list[int] | None = None,
group: Group | None = None,
sync_op: bool = True,
use_calc_stream: bool = False,
) -> task | None:
"""

Split and scatter the split input tensor to the output tensor across devices.

Args:
out_tensor (Tensor): The output tensor. Its data type should be the same as the input.
in_tensor (Tensor): The input tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
out_split_sizes (List[int], optional): Split sizes of out_tensor for dim[0]. If not given, dim[0] of out_tensor must be divisible
out_split_sizes (List[int]|None, optional): Split sizes of out_tensor for dim[0]. If not given, dim[0] of out_tensor must be divisible
by the group size and out_tensor will be gathered evenly from all participants. If none is given, use an empty list as default.
in_split_sizes (List[int], optional): Split sizes of in_tensor for dim[0]. If not given, dim[0] of in_tensor must be divisible
in_split_sizes (List[int]|None, optional): Split sizes of in_tensor for dim[0]. If not given, dim[0] of in_tensor must be divisible
by the group size and in_tensor will be scattered evenly to all participants. If none is given, use an empty list as default.
group (Group, optional): Communicate in which group. If none is given, use the global group as default.
group (Group|None, optional): Communicate in which group. If none is given, use the global group as default.
sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default.
use_calc_stream (bool, optional): Indicate whether the communication is done on calculation stream. If none is given, use false as default. This
option is designed for high-performance workloads; be careful about turning it on unless you clearly understand its meaning.
@@ -301,7 +319,7 @@ def alltoall_single(
>>> else:
... data = paddle.to_tensor([2, 3])
>>> task = dist.stream.alltoall_single(output, data, sync_op=False)
>>> task.wait()
>>> task.wait() # type: ignore[union-attr]
>>> out = output.numpy()
>>> print(out)
>>> # [0, 2] (2 GPUs, out for rank 0)
@@ -321,7 +339,7 @@ def alltoall_single(
... out_split_sizes,
... in_split_sizes,
... sync_op=False)
>>> task.wait()
>>> task.wait() # type: ignore[union-attr]
>>> out = output.numpy()
>>> print(out)
>>> # [[0., 0.], [1., 1.]] (2 GPUs, out for rank 0)
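
all_reduce and alltoall are now annotated as returning task | None, since a handle only comes back when the call is asynchronous and the current rank participates in the group; that is why the doctests above add # type: ignore[union-attr] before calling .wait(). Code checked with mypy or pyright can narrow the optional explicitly instead of suppressing it. A minimal sketch based on the all_reduce doctest above, assuming the usual 2-GPU launch via paddle.distributed.launch:

import paddle
import paddle.distributed as dist

dist.init_parallel_env()

if dist.get_rank() == 0:
    data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
    data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])

# all_reduce is annotated `-> task | None`, so narrow the result rather
# than silencing the checker with `# type: ignore[union-attr]`.
maybe_task = dist.stream.all_reduce(data, sync_op=False)
if maybe_task is not None:
    maybe_task.wait()

print(data)  # [[5, 7, 9], [5, 7, 9]] on both ranks (2 GPUs)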