add all to all op #6283
@@ -34,17 +34,17 @@ def all_reduce(tensor):
     >>> # We have 1 process groups, 2 ranks.
     >>> import oneflow as flow

-    >>> input = flow.tensor([[1, 2], [3, 4]], device="cuda") + flow.env.get_local_rank()
-    >>> # input on rank0
-    >>> input # doctest: +ONLY_CHECK_RANK_0
+    >>> tensor = flow.tensor([[1, 2], [3, 4]], device="cuda") + flow.env.get_local_rank()
+    >>> # tensor on rank0
+    >>> tensor # doctest: +ONLY_CHECK_RANK_0
     tensor([[1, 2],
             [3, 4]], device='cuda:0', dtype=oneflow.int64)
-    >>> # input on rank1
-    >>> input # doctest: +ONLY_CHECK_RANK_1
+    >>> # tensor on rank1
+    >>> tensor # doctest: +ONLY_CHECK_RANK_1
     tensor([[2, 3],
             [4, 5]], device='cuda:1', dtype=oneflow.int64)
-    >>> out = flow.comm.all_reduce(input)
-    >>> out.numpy()
+    >>> flow.comm.all_reduce(tensor)
+    >>> tensor.numpy()
     array([[3, 5],
            [7, 9]])

@@ -54,11 +54,11 @@ def all_reduce(tensor):
     assert tensor.is_local
     device_type = tensor.device.type
     placement = flow.env.all_device_placement(device_type)
-    tensor = tensor.to_consistent(
+    result = tensor.to_consistent(
         placement=placement, sbp=flow.sbp.partial_sum
     ).to_consistent(placement=placement, sbp=flow.sbp.broadcast)

-    return tensor.to_local()
+    tensor.data = result.to_local()


 def all_gather(tensor_list, tensor):
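The hunk above makes all_reduce write its result back through tensor.data instead of returning a new tensor, and the sum itself comes from the partial_sum -> broadcast SBP round trip. A minimal standalone sketch of that mechanism, assuming two ranks launched with OneFlow's multi-process launcher (e.g. python3 -m oneflow.distributed.launch --nproc_per_node 2); the launcher invocation is illustrative, not from the PR:

    # Sketch of the partial_sum -> broadcast trick used by all_reduce (assumes 2 ranks).
    import oneflow as flow

    local = flow.tensor([[1, 2], [3, 4]], device="cuda") + flow.env.get_local_rank()
    placement = flow.env.all_device_placement("cuda")
    # Interpreting each rank's local tensor as a partial_sum component and then asking
    # for a broadcast view forces the runtime to sum the components across ranks.
    summed = local.to_consistent(placement=placement, sbp=flow.sbp.partial_sum).to_consistent(
        placement=placement, sbp=flow.sbp.broadcast
    )
    print(summed.to_local().numpy())  # both ranks print [[3, 5], [7, 9]]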
@@ -107,10 +107,10 @@ def all_gather(tensor_list, tensor):
     assert tensor.is_local
     tensor = tensor.expand(*([1] + list(tensor.shape)))
     device_type = tensor.device.type
+    placement = flow.env.all_device_placement(device_type)
     tensor = tensor.to_consistent(
-        placement=flow.env.all_device_placement(device_type), sbp=flow.sbp.split(0)
-    )
-    tensor = tensor.to_consistent(sbp=flow.sbp.broadcast)
+        placement=placement, sbp=flow.sbp.split(0)
+    ).to_consistent(placement=placement, sbp=flow.sbp.broadcast)
     assert len(tensor_list) == flow.env.get_world_size()
     for i in range(tensor.shape[0]):
         tensor_list[i] = tensor[i].to_local()
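For context, a hedged usage sketch of flow.comm.all_gather after this refactor (not part of the diff). Two ranks are assumed, and pre-filling the output list with flow.zeros placeholders is an assumption about typical call sites, not something this hunk shows:

    # Usage sketch (assumes 2 ranks): every rank ends up with every rank's tensor, ordered by rank.
    import oneflow as flow

    tensor = flow.tensor([1, 2], device="cuda") + flow.env.get_local_rank()
    tensor_list = [
        flow.zeros(2, dtype=flow.int64, device="cuda")  # placeholder entries, one per rank (assumption)
        for _ in range(flow.env.get_world_size())
    ]
    flow.comm.all_gather(tensor_list, tensor)
    # on both ranks: tensor_list holds [tensor([1, 2]), tensor([2, 3])]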
@@ -203,9 +203,58 @@ def reduce(tensor, dst):
     assert isinstance(tensor, flow._oneflow_internal.Tensor)
     assert tensor.is_local
     assert isinstance(dst, int)
-    result = flow.comm.all_reduce(tensor)
-    if flow.env.get_rank() == dst:
-        tensor.data = result
+    original_tensor = flow._C.identity(tensor)
+    flow.comm.all_reduce(tensor)
+    if flow.env.get_rank() != dst:
+        tensor.data = original_tensor
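In other words, reduce now runs an in-place all_reduce on every rank and then restores the saved copy on the non-destination ranks, so only dst keeps the reduced value. A hedged sketch of the observable behaviour (two ranks assumed; not part of the diff):

    # Behaviour sketch for flow.comm.reduce (assumes 2 ranks).
    import oneflow as flow

    tensor = flow.tensor([1, 2], device="cuda") + flow.env.get_local_rank()
    flow.comm.reduce(tensor, dst=0)
    # rank 0: tensor is now [3, 5] (the element-wise sum)
    # rank 1: tensor still holds its original value [2, 3]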
+def all_to_all(output_tensor_list, input_tensor_list):
+    """
+    Each process scatters list of input tensors to all processes in a group and
+    return gathered list of tensors in output list.
+
+    Args:
+        output_tensor_list (list[Tensor]): List of tensors to be gathered one
+            per rank.
+        input_tensor_list (list[Tensor]): List of tensors to scatter one per rank.
+
+    """
+

[review comment] I see that torch also has a group parameter?
[author reply] The underlying implementation does not yet support aligning with that parameter.
+    def _check_list(tensor_list):
+        assert isinstance(tensor_list, list)
+        assert len(tensor_list) == flow.env.get_world_size()
+        shape = tensor_list[0].shape
+        dtype = tensor_list[0].dtype
+        device = tensor_list[0].device
+        for tensor in tensor_list:
+            assert isinstance(tensor, flow._oneflow_internal.Tensor)
+            assert tensor.is_local
+            assert shape == tensor.shape
+            assert dtype == tensor.dtype
+            assert device == tensor.device
+
+    _check_list(output_tensor_list)
+    _check_list(input_tensor_list)
+
+    assert input_tensor_list[0].shape == output_tensor_list[0].shape
+    assert input_tensor_list[0].dtype == output_tensor_list[0].dtype
+    assert input_tensor_list[0].device == output_tensor_list[0].device
+
+    for i in range(flow.env.get_world_size()):
+        flow.comm.scatter(
+            output_tensor_list[i],
+            input_tensor_list if i == flow.env.get_rank() else [],
+            src=i,
+        )

[review comment] Is the scatter here aligned with PyTorch?
[author reply] It has been aligned (added).
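Since all_to_all is the op this PR adds, a hedged usage sketch may help. It assumes two ranks and builds the lists with flow.tensor/flow.zeros, which is illustrative rather than taken from the PR:

    # Usage sketch (assumes world size 2): rank i's input_tensor_list[j] ends up in
    # rank j's output_tensor_list[i].
    import oneflow as flow

    rank = flow.env.get_rank()
    world_size = flow.env.get_world_size()

    # rank 0 prepares [0], [1]; rank 1 prepares [10], [11] -- one tensor per destination rank
    input_tensor_list = [
        flow.tensor([rank * 10 + i], device="cuda") for i in range(world_size)
    ]
    output_tensor_list = [
        flow.zeros(1, dtype=flow.int64, device="cuda") for _ in range(world_size)
    ]

    flow.comm.all_to_all(output_tensor_list, input_tensor_list)
    # rank 0 now holds [tensor([0]), tensor([10])]; rank 1 holds [tensor([1]), tensor([11])]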
 def barrier():
     """
     Synchronizes all processes.

     """
     oneflow._oneflow_internal.eager.multi_client.Sync()


def reduce_scatter(output, input_list):
@@ -266,7 +315,5 @@ def gather(tensor, gather_list=None, dst=0):
     assert gather_list is not None
     assert isinstance(gather_list, list)
     assert len(gather_list) == flow.env.get_world_size()
-    # "to_consistent(placement=flow.env.all_device_placement("cuda/cpu"), sbp=flow.sbp.broadcast)"
-    # after here will fail, if do getitem on some a rank
     for i in range(tensor.shape[0]):
         gather_list[i] = tensor[i].to_local()
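A hedged usage sketch for gather under the cleaned-up code path above. Two ranks are assumed, and passing a pre-allocated gather_list only on the destination rank mirrors the PyTorch convention; that convention, and the flow.zeros placeholders, are assumptions rather than something this hunk shows:

    # Usage sketch (assumes 2 ranks): collect every rank's tensor onto rank 0.
    import oneflow as flow

    tensor = flow.tensor([1, 2], device="cuda") + flow.env.get_local_rank()
    if flow.env.get_rank() == 0:
        gather_list = [
            flow.zeros(2, dtype=flow.int64, device="cuda")  # placeholder entries (assumption)
            for _ in range(flow.env.get_world_size())
        ]
        flow.comm.gather(tensor, gather_list=gather_list, dst=0)
        # gather_list now holds [tensor([1, 2]), tensor([2, 3])]
    else:
        flow.comm.gather(tensor, dst=0)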
[review comment] Are the API semantics aligned with PyTorch?
[author reply] Yes, they are aligned.
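For reference on the alignment question, the PyTorch counterpart is torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False); group and async_op are the arguments this PR does not expose. A rough comparison sketch in PyTorch (assuming a two-GPU torchrun launch; not part of the PR):

    # PyTorch comparison sketch (assumes 2 ranks launched via torchrun, one GPU each).
    import torch
    import torch.distributed as dist

    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    inputs = [torch.tensor([rank * 10 + i], device="cuda") for i in range(world_size)]
    outputs = [torch.zeros(1, dtype=torch.int64, device="cuda") for _ in range(world_size)]
    # group=None means the default (world) process group -- the parameter flow.comm.all_to_all omits.
    dist.all_to_all(outputs, inputs, group=None)
    # rank 0 now holds [tensor([0]), tensor([10])]; rank 1 holds [tensor([1]), tensor([11])]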