[Enhancement] Support broadcast_object_list on multiple machines & support Searcher running on a single GPU #153
Merged
Changes from all commits (2 commits)
@@ -1,5 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .broadcast import broadcast_object_list
 from .lr import set_lr
+from .utils import get_backend, get_default_group, get_rank, get_world_size
 
-__all__ = ['broadcast_object_list', 'set_lr']
+__all__ = [
+    'broadcast_object_list', 'set_lr', 'get_world_size', 'get_rank',
+    'get_backend', 'get_default_group'
+]
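With the expanded `__all__` above, the distributed helpers can be imported alongside `broadcast_object_list` from the same package. A minimal sketch, assuming the package path `mmrazor.core.utils` that appears in the docstring example later in this diff:

from mmrazor.core.utils import (broadcast_object_list, get_backend,
                                get_default_group, get_rank, get_world_size)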
@@ -1,49 +1,155 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os.path as osp
-import shutil
-import tempfile
+import pickle
+import warnings
+from typing import Any, List, Optional, Tuple
 
-import mmcv.fileio
 import torch
-import torch.distributed as dist
-from mmcv.runner import get_dist_info
+from mmcv.utils import TORCH_VERSION, digit_version
+from torch import Tensor
+from torch import distributed as dist
 
+from .utils import get_backend, get_default_group, get_rank, get_world_size
 
-def broadcast_object_list(object_list, src=0):
-    """Broadcasts picklable objects in ``object_list`` to the whole group.
-
-    Note that all objects in ``object_list`` must be picklable in order to be
-    broadcasted.
+def _object_to_tensor(obj: Any) -> Tuple[Tensor, Tensor]:
+    """Serialize picklable python object to tensor."""
+    byte_storage = torch.ByteStorage.from_buffer(pickle.dumps(obj))
+    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor
+    # and specifying dtype. Otherwise, it will cause 100X slowdown.
+    # See: https://github.com/pytorch/pytorch/issues/65696
+    byte_tensor = torch.ByteTensor(byte_storage)
+    local_size = torch.LongTensor([byte_tensor.numel()])
+    return byte_tensor, local_size
 
-    Args:
-        object_list (List[Any]): List of input objects to broadcast.
-            Each object must be picklable. Only objects on the src rank will be
-            broadcast, but each rank must provide lists of equal sizes.
-        src (int): Source rank from which to broadcast ``object_list``.
+
+def _tensor_to_object(tensor: Tensor, tensor_size: int) -> Any:
+    """Deserialize tensor to picklable python object."""
+    buf = tensor.cpu().numpy().tobytes()[:tensor_size]
+    return pickle.loads(buf)
+
+
+def _broadcast_object_list(object_list: List[Any],
+                           src: int = 0,
+                           group: Optional[dist.ProcessGroup] = None) -> None:
+    """Broadcast picklable objects in ``object_list`` to the whole group.
+
+    Similar to :func:`broadcast`, but Python objects can be passed in. Note
+    that all objects in ``object_list`` must be picklable in order to be
+    broadcasted.
     """
-    my_rank, _ = get_dist_info()
+    if dist.distributed_c10d._rank_not_in_group(group):
+        return
 
-    MAX_LEN = 512
-    # 32 is whitespace
-    dir_tensor = torch.full((MAX_LEN, ), 32, dtype=torch.uint8, device='cuda')
-    object_list_return = list()
+    my_rank = get_rank()
+    # Serialize object_list elements to tensors on src rank.
     if my_rank == src:
-        mmcv.mkdir_or_exist('.dist_broadcast')
-        tmpdir = tempfile.mkdtemp(dir='.dist_broadcast')
-        mmcv.dump(object_list, osp.join(tmpdir, 'object_list.pkl'))
-        tmpdir = torch.tensor(
-            bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
-        dir_tensor[:len(tmpdir)] = tmpdir
+        tensor_list, size_list = zip(
+            *[_object_to_tensor(obj) for obj in object_list])
+        object_sizes_tensor = torch.cat(size_list)
+    else:
+        object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
 
-    dist.broadcast(dir_tensor, src)
-    tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+    # Current device selection.
+    # To preserve backwards compatibility, ``device`` is ``None`` by default.
+    # in which case we run current logic of device selection, i.e.
+    # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In
+    # the case it is not ``None`` we move the size and object tensors to be
+    # broadcasted to this device.
+    group_backend = get_backend(group)
+    is_nccl_backend = group_backend == dist.Backend.NCCL
+    current_device = torch.device('cpu')
+    if is_nccl_backend:
+        # See note about using torch.cuda.current_device() here in
+        # docstring. We cannot simply use my_rank since rank == device is
+        # not necessarily true.
+        current_device = torch.device('cuda', torch.cuda.current_device())
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
 
-    if my_rank != src:
-        object_list_return = mmcv.load(osp.join(tmpdir, 'object_list.pkl'))
+    # Broadcast object sizes
+    dist.broadcast(object_sizes_tensor, src=src, group=group)
 
-    dist.barrier()
+    # Concatenate and broadcast serialized object tensors
    if my_rank == src:
-        shutil.rmtree(tmpdir)
-        object_list_return = object_list
+        object_tensor = torch.cat(tensor_list)
+    else:
+        object_tensor = torch.empty(
+            torch.sum(object_sizes_tensor).int().item(),
+            dtype=torch.uint8,
+        )
 
+    if is_nccl_backend:
+        object_tensor = object_tensor.to(current_device)
+    dist.broadcast(object_tensor, src=src, group=group)
+    # Deserialize objects using their stored sizes.
+    offset = 0
+    if my_rank != src:
+        for i, obj_size in enumerate(object_sizes_tensor):
+            obj_view = object_tensor[offset:offset + obj_size]
+            obj_view = obj_view.type(torch.uint8)
+            if obj_view.device != torch.device('cpu'):
+                obj_view = obj_view.cpu()
+            offset += obj_size
+            object_list[i] = _tensor_to_object(obj_view, obj_size)
+
+
+def broadcast_object_list(data: List[Any],
+                          src: int = 0,
+                          group: Optional[dist.ProcessGroup] = None) -> None:
+    """Broadcasts picklable objects in ``object_list`` to the whole group.
+    Similar to :func:`broadcast`, but Python objects can be passed in. Note
+    that all objects in ``object_list`` must be picklable in order to be
+    broadcasted.
+    Note:
+        Calling ``broadcast_object_list`` in non-distributed environment does
+        nothing.
+    Args:
+        data (List[Any]): List of input objects to broadcast.
+            Each object must be picklable. Only objects on the ``src`` rank
+            will be broadcast, but each rank must provide lists of equal sizes.
+        src (int): Source rank from which to broadcast ``object_list``.
+        group: (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Default is ``None``.
+        device (``torch.device``, optional): If not None, the objects are
+            serialized and converted to tensors which are moved to the
+            ``device`` before broadcasting. Default is ``None``.
+    Note:
+        For NCCL-based process groups, internal tensor representations of
+        objects must be moved to the GPU device before communication starts.
+        In this case, the used device is given by
+        ``torch.cuda.current_device()`` and it is the user's responsibility to
+        ensure that this is correctly set so that each rank has an individual
+        GPU, via ``torch.cuda.set_device()``.
+    Examples:
+        >>> import torch
+        >>> import mmrazor.core.utils as dist
+        >>> # non-distributed environment
+        >>> data = ['foo', 12, {1: 2}]
+        >>> dist.broadcast_object_list(data)
+        >>> data
+        ['foo', 12, {1: 2}]
+        >>> # distributed environment
+        >>> # We have 2 process groups, 2 ranks.
+        >>> if dist.get_rank() == 0:
+        >>>     # Assumes world_size of 3.
+        >>>     data = ["foo", 12, {1: 2}]  # any picklable object
+        >>> else:
+        >>>     data = [None, None, None]
+        >>> dist.broadcast_object_list(data)
+        >>> data
+        ["foo", 12, {1: 2}]  # Rank 0
+        ["foo", 12, {1: 2}]  # Rank 1
+    """
+    warnings.warn(
+        '`broadcast_object_list` is now without return value, '
+        'and it\'s input parameters are: `data`,`src` and '
+        '`group`, but its function is similar to the old\'s', UserWarning)
+    assert isinstance(data, list)
 
+    if get_world_size(group) > 1:
+        if group is None:
+            group = get_default_group()
 
-    return object_list_return
+        if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
+            dist.broadcast_object_list(data, src, group)
+        else:
+            _broadcast_object_list(data, src, group)
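The fallback `_broadcast_object_list` above pickles each object into a uint8 tensor, broadcasts the per-object sizes first and then the concatenated payload, and unpickles on the non-source ranks; on PyTorch >= 1.8.0 the public wrapper delegates to the built-in `torch.distributed.broadcast_object_list` instead. A minimal single-process sketch of the serialization round trip those helpers rely on, assuming this file is importable as `mmrazor.core.utils.broadcast` (the module path is inferred from the `from .broadcast import ...` line in the first diff, not stated here, and the object contents are purely illustrative):

from mmrazor.core.utils.broadcast import _object_to_tensor, _tensor_to_object

obj = {'subnet': [3, 5, 7], 'score': 0.42}  # any picklable object
byte_tensor, size_tensor = _object_to_tensor(obj)  # pickle -> uint8 tensor + element count
restored = _tensor_to_object(byte_tensor, size_tensor.item())
assert restored == obj

Because the sizes are broadcast before the payload, every rank must pass a list of the same length, matching the requirement stated in the docstring.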
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+from torch import distributed as dist
+
+
+def is_distributed() -> bool:
+    """Return True if distributed environment has been initialized."""
+    return dist.is_available() and dist.is_initialized()
+
+
+def get_default_group() -> Optional[dist.ProcessGroup]:
+    """Return default process group."""
+
+    return dist.distributed_c10d._get_default_group()
+
+
+def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
+    """Return the rank of the given process group.
+
+    Rank is a unique identifier assigned to each process within a distributed
+    process group. They are always consecutive integers ranging from 0 to
+    ``world_size``.
+    Note:
+        Calling ``get_rank`` in non-distributed environment will return 0.
+    Args:
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Defaults to None.
+    Returns:
+        int: Return the rank of the process group if in distributed
+        environment, otherwise 0.
+    """
+
+    if is_distributed():
+        # handle low versions of torch like 1.5.0 which does not support
+        # passing in None for group argument
+        if group is None:
+            group = get_default_group()
+        return dist.get_rank(group)
+    else:
+        return 0
+
+
+def get_backend(group: Optional[dist.ProcessGroup] = None) -> Optional[str]:
+    """Return the backend of the given process group.
+
+    Note:
+        Calling ``get_backend`` in non-distributed environment will return
+        None.
+    Args:
+        group (ProcessGroup, optional): The process group to work on. The
+            default is the general main process group. If another specific
+            group is specified, the calling process must be part of
+            :attr:`group`. Defaults to None.
+    Returns:
+        str or None: Return the backend of the given process group as a lower
+        case string if in distributed environment, otherwise None.
+    """
+    if is_distributed():
+        # handle low versions of torch like 1.5.0 which does not support
+        # passing in None for group argument
+        if group is None:
+            group = get_default_group()
+        return dist.get_backend(group)
+    else:
+        return None
+
+
+def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
+    """Return the number of the given process group.
+
+    Note:
+        Calling ``get_world_size`` in non-distributed environment will return
+        1.
+    Args:
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Defaults to None.
+    Returns:
+        int: Return the number of processes of the given process group if in
+        distributed environment, otherwise 1.
+    """
+    if is_distributed():
+        # handle low versions of torch like 1.5.0 which does not support
+        # passing in None for group argument
+        if group is None:
+            group = get_default_group()
+        return dist.get_world_size(group)
+    else:
+        return 1
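A quick sanity sketch of the non-distributed fallback documented above: when `torch.distributed` has not been initialized, the helpers return single-process defaults instead of raising (the `mmrazor.core.utils` import path is again assumed from the docstring example):

from mmrazor.core.utils import get_backend, get_rank, get_world_size

assert get_rank() == 0        # rank defaults to 0 outside a process group
assert get_world_size() == 1  # a single process counts as world size 1
assert get_backend() is None  # no backend when distributed is uninitialized

Together with the `get_world_size(group) > 1` guard in `broadcast_object_list`, this appears to be what allows the Searcher mentioned in the PR title to run on a single GPU without initializing a process group.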
A warning needs to be added here
done