-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* gather with doc * resolve comment * polish * polish * code style * polish doc * add_test * polish * polish * add test check * add test check * polish * polish * polish * polish * fix_time_out * polish * fix timeout * fix_timeout * polish * polish * polish * polish * polish
- Loading branch information
1 parent
20ee0d7
commit 77d2485
Showing
20 changed files
with
523 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from paddle import framework | ||
from paddle.distributed.communication import stream | ||
|
||
|
||
def gather(tensor, gather_list=None, dst=0, group=None, sync_op=True): | ||
""" | ||
Gather tensors from all participators. | ||
Args: | ||
tensor (Tensor): The input Tensor. Its data type | ||
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. | ||
gather_list (list): A list of Tensors to hold the gathered tensors. Every element in the list must be a Tensor whose data type | ||
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. Default value is None. | ||
dst (int): The dst rank id. Default value is 0. | ||
group (Group, optional): The group instance return by new_group or None for global default group. | ||
sync_op (bool, optional): Whether this op is a sync op. The default value is True. | ||
Returns: | ||
Async work handle,which can be wait on, if async_op is set to True. | ||
None, if not async_op | ||
Examples: | ||
.. code-block:: python | ||
# required: distributed | ||
import paddle | ||
import paddle.distributed as dist | ||
dist.init_parallel_env() | ||
gather_list = [] | ||
if dist.get_rank() == 0: | ||
data = paddle.to_tensor([1, 2, 3]) | ||
dist.gather(data, gather_list, dst=0) | ||
else: | ||
data = paddle.to_tensor([4, 5, 6]) | ||
dist.gather(data1, gather_list, dst=0) | ||
print(gather_list) | ||
# [[1, 2, 3], [4, 5, 6]] (2 GPUs, out for rank 0) | ||
# [] (2 GPUs, out for rank 1) | ||
""" | ||
assert ( | ||
framework.in_dygraph_mode() | ||
), "gather doesn't support static graph mode yet." | ||
return stream.gather(tensor, gather_list, dst, group, sync_op) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.