Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add all to all op #6283

Merged
merged 37 commits into from
Dec 11, 2021
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
292107a
add all to all op
liufengwei0103 Sep 14, 2021
ac67f64
add barrier
liufengwei0103 Sep 14, 2021
8d30b4c
format
liufengwei0103 Sep 14, 2021
9934500
add import
liufengwei0103 Sep 14, 2021
06e289a
fix test
liufengwei0103 Sep 14, 2021
f3f9439
fix conflicts
liufengwei0103 Sep 16, 2021
7ddf79a
delete barrier
liufengwei0103 Sep 16, 2021
aa397ea
delete barrier
liufengwei0103 Sep 16, 2021
c21dab0
Revert "delete barrier"
liufengwei0103 Sep 16, 2021
ab5c0fd
Revert "delete barrier"
liufengwei0103 Sep 16, 2021
038fc4d
check tensor meta between ranks
liufengwei0103 Sep 16, 2021
7bb0a3a
add more assert
liufengwei0103 Sep 17, 2021
df9a929
all_reduce operate in place
liufengwei0103 Sep 17, 2021
b897560
all_reduce operate in place
liufengwei0103 Sep 17, 2021
df59555
fix bug
liufengwei0103 Sep 22, 2021
5c19442
assert tensor.is_local
liufengwei0103 Sep 22, 2021
a820137
fix bug in scatter
liufengwei0103 Sep 22, 2021
c58b0be
add more assert
liufengwei0103 Sep 22, 2021
6a126bd
Merge branch 'master' of github.com:Oneflow-Inc/oneflow into comm_all…
liufengwei0103 Sep 22, 2021
916a4be
delete meta check
liufengwei0103 Sep 22, 2021
31f437c
add pytorch comparison test
liufengwei0103 Sep 22, 2021
848ae80
add pytorch comparison test
liufengwei0103 Sep 23, 2021
ecb2012
refine
liufengwei0103 Sep 23, 2021
9c10ea0
fix conflicts
liufengwei0103 Oct 9, 2021
140f70a
Merge branch 'master' into comm_all_to_all
oneflow-ci-bot Dec 10, 2021
9e45995
Merge branch 'master' into comm_all_to_all
oneflow-ci-bot Dec 10, 2021
6c8c8fa
add ONEFLOW_TEST_CPU_ONLY
liufengwei0103 Dec 10, 2021
36b4ae2
Merge branch 'comm_all_to_all' of github.com:Oneflow-Inc/oneflow into…
liufengwei0103 Dec 10, 2021
4cd371d
Merge branch 'master' into comm_all_to_all
oneflow-ci-bot Dec 10, 2021
664de09
Merge branch 'master' into comm_all_to_all
oneflow-ci-bot Dec 10, 2021
5498f2b
Merge branch 'master' into comm_all_to_all
oneflow-ci-bot Dec 10, 2021
9e4157c
Merge branch 'master' into comm_all_to_all
oneflow-ci-bot Dec 10, 2021
e02d5c7
fix bug from torch gloo
liufengwei0103 Dec 11, 2021
542c441
Merge branch 'master' into comm_all_to_all
oneflow-ci-bot Dec 11, 2021
1a79b9b
Merge branch 'master' into comm_all_to_all
oneflow-ci-bot Dec 11, 2021
a3ffeb3
Merge branch 'master' into comm_all_to_all
oneflow-ci-bot Dec 11, 2021
4b35b2f
Merge branch 'master' into comm_all_to_all
oneflow-ci-bot Dec 11, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/comm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ oneflow communication function
all_gather,
broadcast,
scatter,
all_to_all,
reduce,
gather,
reduce_scatter,
send,
recv,
barrier,
2 changes: 2 additions & 0 deletions python/oneflow/comm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from oneflow.comm.comm_ops import broadcast
from oneflow.comm.comm_ops import scatter
from oneflow.comm.comm_ops import reduce
from oneflow.comm.comm_ops import all_to_all
from oneflow.comm.comm_ops import barrier
from oneflow.comm.comm_ops import reduce_scatter
from oneflow.comm.comm_ops import gather
from oneflow._C import send, recv
81 changes: 64 additions & 17 deletions python/oneflow/comm/comm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,17 @@ def all_reduce(tensor):
>>> # We have 1 process groups, 2 ranks.
>>> import oneflow as flow

>>> input = flow.tensor([[1, 2], [3, 4]], device="cuda") + flow.env.get_local_rank()
>>> # input on rank0
>>> input # doctest: +ONLY_CHECK_RANK_0
>>> tensor = flow.tensor([[1, 2], [3, 4]], device="cuda") + flow.env.get_local_rank()
>>> # tensor on rank0
>>> tensor # doctest: +ONLY_CHECK_RANK_0
tensor([[1, 2],
[3, 4]], device='cuda:0', dtype=oneflow.int64)
>>> # input on rank1
>>> input # doctest: +ONLY_CHECK_RANK_1
>>> # tensor on rank1
>>> tensor # doctest: +ONLY_CHECK_RANK_1
tensor([[2, 3],
[4, 5]], device='cuda:1', dtype=oneflow.int64)
>>> out = flow.comm.all_reduce(input)
>>> out.numpy()
>>> flow.comm.all_reduce(tensor)
>>> tensor.numpy()
array([[3, 5],
[7, 9]])

Expand All @@ -54,11 +54,11 @@ def all_reduce(tensor):
assert tensor.is_local
device_type = tensor.device.type
placement = flow.env.all_device_placement(device_type)
tensor = tensor.to_consistent(
result = tensor.to_consistent(
placement=placement, sbp=flow.sbp.partial_sum
).to_consistent(placement=placement, sbp=flow.sbp.broadcast)

return tensor.to_local()
tensor.data = result.to_local()


def all_gather(tensor_list, tensor):
Expand Down Expand Up @@ -107,10 +107,10 @@ def all_gather(tensor_list, tensor):
assert tensor.is_local
tensor = tensor.expand([1] + list(tensor.shape))
device_type = tensor.device.type
placement = flow.env.all_device_placement(device_type)
tensor = tensor.to_consistent(
placement=flow.env.all_device_placement(device_type), sbp=flow.sbp.split(0)
)
tensor = tensor.to_consistent(sbp=flow.sbp.broadcast)
placement=placement, sbp=flow.sbp.split(0)
).to_consistent(placement=placement, sbp=flow.sbp.broadcast)
assert len(tensor_list) == flow.env.get_world_size()
for i in range(tensor.shape[0]):
tensor_list[i] = tensor[i].to_local()
Expand Down Expand Up @@ -203,9 +203,58 @@ def reduce(tensor, dst):
assert isinstance(tensor, flow._oneflow_internal.Tensor)
assert tensor.is_local
assert isinstance(dst, int)
result = flow.comm.all_reduce(tensor)
if flow.env.get_rank() == dst:
tensor.data = result
original_tensor = flow._C.identity(tensor)
flow.comm.all_reduce(tensor)
if flow.env.get_rank() != dst:
tensor.data = original_tensor


def all_to_all(output_tensor_list, input_tensor_list):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

api的语义和pytorch对齐了吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

api的语义和pytorch对齐了吗?

对齐了

"""
Each process scatters list of input tensors to all processes in a group and
return gathered list of tensors in output list.

Args:
output_tensor_list (list[Tensor]): List of tensors to be gathered one
per rank.
input_tensor_list (list[Tensor]): List of tensors to scatter one per rank.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我看torch还有group参数?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我看torch还有group参数?

底层暂时不支持对齐这个参数。


"""

def _check_list(tensor_list):
assert isinstance(tensor_list, list)
assert len(tensor_list) == flow.env.get_world_size()
shape = tensor_list[0].shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
for tensor in tensor_list:
assert isinstance(tensor, flow._oneflow_internal.Tensor)
assert tensor.is_local
assert shape == tensor.shape
assert dtype == tensor.dtype
assert device == tensor.device

_check_list(output_tensor_list)
_check_list(input_tensor_list)

assert input_tensor_list[0].shape == output_tensor_list[0].shape
assert input_tensor_list[0].dtype == output_tensor_list[0].dtype
assert input_tensor_list[0].device == output_tensor_list[0].device

for i in range(flow.env.get_world_size()):
flow.comm.scatter(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的scatter和pytorch对齐了吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的scatter和pytorch对齐了吗?

已加

output_tensor_list[i],
input_tensor_list if i == flow.env.get_rank() else [],
src=i,
)


def barrier():
"""
Synchronizes all processes.

"""
oneflow._oneflow_internal.eager.multi_client.Sync()


def reduce_scatter(output, input_list):
Expand Down Expand Up @@ -266,7 +315,5 @@ def gather(tensor, gather_list=None, dst=0):
assert gather_list is not None
assert isinstance(gather_list, list)
assert len(gather_list) == flow.env.get_world_size()
# "to_consistent(placement=flow.env.all_device_placement("cuda/cpu"), sbp=flow.sbp.broadcast)"
# after here will fail, if do getitem on some a rank
for i in range(tensor.shape[0]):
gather_list[i] = tensor[i].to_local()
Loading