diff --git a/mmrazor/core/searcher/evolution_search.py b/mmrazor/core/searcher/evolution_search.py
index 43cfe558a..a4fdee41c 100644
--- a/mmrazor/core/searcher/evolution_search.py
+++ b/mmrazor/core/searcher/evolution_search.py
@@ -135,14 +135,11 @@ def search(self):
 
                     if self.check_constraints():
                         self.candidate_pool.append(candidate)
-
-                broadcast_candidate_pool = self.candidate_pool
             else:
-                broadcast_candidate_pool = [None] * self.candidate_pool_size
-            broadcast_candidate_pool = broadcast_object_list(
-                broadcast_candidate_pool)
+                self.candidate_pool = [None] * self.candidate_pool_size
+            broadcast_object_list(self.candidate_pool)
 
-            for i, candidate in enumerate(broadcast_candidate_pool):
+            for i, candidate in enumerate(self.candidate_pool):
                 self.algorithm.mutator.set_subnet(candidate)
                 outputs = self.test_fn(self.algorithm_for_test, self.dataloader)
 
@@ -213,7 +210,7 @@ def search(self):
             self.logger.info(
                 f'Epoch:[{epoch + 1}/{self.max_epoch}], top1_score: '
                 f'{list(self.top_k_candidates_with_score.keys())[0]}')
-            self.candidate_pool = broadcast_object_list(self.candidate_pool)
+            broadcast_object_list(self.candidate_pool)
 
         if rank == 0:
             final_subnet_dict = list(
diff --git a/mmrazor/core/searcher/greedy_search.py b/mmrazor/core/searcher/greedy_search.py
index 2d1410389..215d0fffe 100644
--- a/mmrazor/core/searcher/greedy_search.py
+++ b/mmrazor/core/searcher/greedy_search.py
@@ -146,7 +146,7 @@ def search(self):
 
                 # Broadcasts scores in broadcast_scores to the whole
                 # group.
-                broadcast_scores = broadcast_object_list(broadcast_scores)
+                broadcast_object_list(broadcast_scores)
                 score = broadcast_scores[0]
                 self.logger.info(
                     f'Slimming group {name}, {self.score_key}: {score}')
diff --git a/mmrazor/core/utils/__init__.py b/mmrazor/core/utils/__init__.py
index 415b5418b..9267327cc 100644
--- a/mmrazor/core/utils/__init__.py
+++ b/mmrazor/core/utils/__init__.py
@@ -1,5 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .broadcast import broadcast_object_list
 from .lr import set_lr
+from .utils import get_backend, get_default_group, get_rank, get_world_size
 
-__all__ = ['broadcast_object_list', 'set_lr']
+__all__ = [
+    'broadcast_object_list', 'set_lr', 'get_world_size', 'get_rank',
+    'get_backend', 'get_default_group'
+]
diff --git a/mmrazor/core/utils/broadcast.py b/mmrazor/core/utils/broadcast.py
index 6326311c0..41c2998b4 100644
--- a/mmrazor/core/utils/broadcast.py
+++ b/mmrazor/core/utils/broadcast.py
@@ -1,49 +1,155 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os.path as osp
-import shutil
-import tempfile
+import pickle
+import warnings
+from typing import Any, List, Optional, Tuple
 
-import mmcv.fileio
 import torch
-import torch.distributed as dist
-from mmcv.runner import get_dist_info
+from mmcv.utils import TORCH_VERSION, digit_version
+from torch import Tensor
+from torch import distributed as dist
+from .utils import get_backend, get_default_group, get_rank, get_world_size
 
 
-def broadcast_object_list(object_list, src=0):
-    """Broadcasts picklable objects in ``object_list`` to the whole group.
-    Note that all objects in ``object_list`` must be picklable in order to be
-    broadcasted.
+def _object_to_tensor(obj: Any) -> Tuple[Tensor, Tensor]:
+    """Serialize picklable python object to tensor."""
+    byte_storage = torch.ByteStorage.from_buffer(pickle.dumps(obj))
+    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor
+    # and specifying dtype. Otherwise, it will cause 100X slowdown.
+    # See: https://github.com/pytorch/pytorch/issues/65696
+    byte_tensor = torch.ByteTensor(byte_storage)
+    local_size = torch.LongTensor([byte_tensor.numel()])
+    return byte_tensor, local_size
 
-    Args:
-        object_list (List[Any]): List of input objects to broadcast.
-            Each object must be picklable. Only objects on the src rank will be
-            broadcast, but each rank must provide lists of equal sizes.
-        src (int): Source rank from which to broadcast ``object_list``.
+
+def _tensor_to_object(tensor: Tensor, tensor_size: int) -> Any:
+    """Deserialize tensor to picklable python object."""
+    buf = tensor.cpu().numpy().tobytes()[:tensor_size]
+    return pickle.loads(buf)
+
+
+def _broadcast_object_list(object_list: List[Any],
+                           src: int = 0,
+                           group: Optional[dist.ProcessGroup] = None) -> None:
+    """Broadcast picklable objects in ``object_list`` to the whole group.
+
+    Similar to :func:`broadcast`, but Python objects can be passed in. Note
+    that all objects in ``object_list`` must be picklable in order to be
+    broadcasted.
     """
-    my_rank, _ = get_dist_info()
+    if dist.distributed_c10d._rank_not_in_group(group):
+        return
 
-    MAX_LEN = 512
-    # 32 is whitespace
-    dir_tensor = torch.full((MAX_LEN, ), 32, dtype=torch.uint8, device='cuda')
-    object_list_return = list()
+    my_rank = get_rank()
+    # Serialize object_list elements to tensors on src rank.
     if my_rank == src:
-        mmcv.mkdir_or_exist('.dist_broadcast')
-        tmpdir = tempfile.mkdtemp(dir='.dist_broadcast')
-        mmcv.dump(object_list, osp.join(tmpdir, 'object_list.pkl'))
-        tmpdir = torch.tensor(
-            bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
-        dir_tensor[:len(tmpdir)] = tmpdir
+        tensor_list, size_list = zip(
+            *[_object_to_tensor(obj) for obj in object_list])
+        object_sizes_tensor = torch.cat(size_list)
+    else:
+        object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
 
-    dist.broadcast(dir_tensor, src)
-    tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+    # Current device selection.
+    # To preserve backwards compatibility, ``device`` is ``None`` by default,
+    # in which case we run the current logic of device selection, i.e.
+    # ``current_device`` is CUDA if the backend is NCCL, otherwise the CPU
+    # device. In the case it is not ``None``, we move the size and object
+    # tensors to be broadcasted to this device.
+    group_backend = get_backend(group)
+    is_nccl_backend = group_backend == dist.Backend.NCCL
+    current_device = torch.device('cpu')
+    if is_nccl_backend:
+        # See note about using torch.cuda.current_device() here in the
+        # docstring. We cannot simply use my_rank since rank == device is
+        # not necessarily true.
+        current_device = torch.device('cuda', torch.cuda.current_device())
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
 
-    if my_rank != src:
-        object_list_return = mmcv.load(osp.join(tmpdir, 'object_list.pkl'))
+    # Broadcast object sizes
+    dist.broadcast(object_sizes_tensor, src=src, group=group)
 
-    dist.barrier()
+    # Concatenate and broadcast serialized object tensors
     if my_rank == src:
-        shutil.rmtree(tmpdir)
-        object_list_return = object_list
+        object_tensor = torch.cat(tensor_list)
+    else:
+        object_tensor = torch.empty(
+            torch.sum(object_sizes_tensor).int().item(),
+            dtype=torch.uint8,
+        )
+
+    if is_nccl_backend:
+        object_tensor = object_tensor.to(current_device)
+    dist.broadcast(object_tensor, src=src, group=group)
+    # Deserialize objects using their stored sizes.
+    offset = 0
+    if my_rank != src:
+        for i, obj_size in enumerate(object_sizes_tensor):
+            obj_view = object_tensor[offset:offset + obj_size]
+            obj_view = obj_view.type(torch.uint8)
+            if obj_view.device != torch.device('cpu'):
+                obj_view = obj_view.cpu()
+            offset += obj_size
+            object_list[i] = _tensor_to_object(obj_view, obj_size)
+
+
+def broadcast_object_list(data: List[Any],
+                          src: int = 0,
+                          group: Optional[dist.ProcessGroup] = None) -> None:
+    """Broadcasts picklable objects in ``data`` to the whole group.
+    Similar to :func:`broadcast`, but Python objects can be passed in. Note
+    that all objects in ``data`` must be picklable in order to be
+    broadcasted.
+    Note:
+        Calling ``broadcast_object_list`` in a non-distributed environment
+        does nothing.
+    Args:
+        data (List[Any]): List of input objects to broadcast.
+            Each object must be picklable. Only objects on the ``src`` rank
+            will be broadcast, but each rank must provide lists of equal sizes.
+        src (int): Source rank from which to broadcast ``data``.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Default is ``None``.
+    Note:
+        For NCCL-based process groups, internal tensor representations of
+        objects must be moved to the GPU device before communication starts.
+        In this case, the used device is given by
+        ``torch.cuda.current_device()`` and it is the user's responsibility to
+        ensure that this is correctly set so that each rank has an individual
+        GPU, via ``torch.cuda.set_device()``.
+    Examples:
+        >>> import torch
+        >>> import mmrazor.core.utils as dist
+        >>> # non-distributed environment
+        >>> data = ['foo', 12, {1: 2}]
+        >>> dist.broadcast_object_list(data)
+        >>> data
+        ['foo', 12, {1: 2}]
+        >>> # distributed environment
+        >>> # We have 2 process groups, 2 ranks.
+        >>> if dist.get_rank() == 0:
+        >>>     # Assumes world_size of 2.
+        >>>     data = ["foo", 12, {1: 2}]  # any picklable object
+        >>> else:
+        >>>     data = [None, None, None]
+        >>> dist.broadcast_object_list(data)
+        >>> data
+        ["foo", 12, {1: 2}]  # Rank 0
+        ["foo", 12, {1: 2}]  # Rank 1
+    """
+    warnings.warn(
+        '`broadcast_object_list` is now without return value, '
+        'and its input parameters are: `data`, `src` and '
+        '`group`, but its function is similar to the old\'s', UserWarning)
+    assert isinstance(data, list)
+
+    if get_world_size(group) > 1:
+        if group is None:
+            group = get_default_group()
-    return object_list_return
+        if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
+            dist.broadcast_object_list(data, src, group)
+        else:
+            _broadcast_object_list(data, src, group)
diff --git a/mmrazor/core/utils/utils.py b/mmrazor/core/utils/utils.py
new file mode 100644
index 000000000..58b87d2e5
--- /dev/null
+++ b/mmrazor/core/utils/utils.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+from torch import distributed as dist
+
+
+def is_distributed() -> bool:
+    """Return True if distributed environment has been initialized."""
+    return dist.is_available() and dist.is_initialized()
+
+
+def get_default_group() -> Optional[dist.ProcessGroup]:
+    """Return default process group."""
+
+    return dist.distributed_c10d._get_default_group()
+
+
+def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
+    """Return the rank of the given process group.
+
+    Rank is a unique identifier assigned to each process within a distributed
+    process group. They are always consecutive integers ranging from 0 to
+    ``world_size``.
+    Note:
+        Calling ``get_rank`` in a non-distributed environment will return 0.
+    Args:
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Defaults to None.
+    Returns:
+        int: Return the rank of the process group if in a distributed
+            environment, otherwise 0.
+    """
+
+    if is_distributed():
+        # handle low versions of torch like 1.5.0 which do not support
+        # passing in None for group argument
+        if group is None:
+            group = get_default_group()
+        return dist.get_rank(group)
+    else:
+        return 0
+
+
+def get_backend(group: Optional[dist.ProcessGroup] = None) -> Optional[str]:
+    """Return the backend of the given process group.
+
+    Note:
+        Calling ``get_backend`` in a non-distributed environment will return
+        None.
+    Args:
+        group (ProcessGroup, optional): The process group to work on. The
+            default is the general main process group. If another specific
+            group is specified, the calling process must be part of
+            :attr:`group`. Defaults to None.
+    Returns:
+        str or None: Return the backend of the given process group as a lower
+            case string if in a distributed environment, otherwise None.
+    """
+    if is_distributed():
+        # handle low versions of torch like 1.5.0 which do not support
+        # passing in None for group argument
+        if group is None:
+            group = get_default_group()
+        return dist.get_backend(group)
+    else:
+        return None
+
+
+def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
+    """Return the number of processes in the given process group.
+
+    Note:
+        Calling ``get_world_size`` in a non-distributed environment will return
+        1.
+    Args:
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Defaults to None.
+    Returns:
+        int: Return the number of processes of the given process group if in
+            a distributed environment, otherwise 1.
+    """
+    if is_distributed():
+        # handle low versions of torch like 1.5.0 which do not support
+        # passing in None for group argument
+        if group is None:
+            group = get_default_group()
+        return dist.get_world_size(group)
+    else:
+        return 1
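Usage note: a minimal sketch of how call sites use the refactored in-place API after this change, mirroring the updated pattern in evolution_search.py. It assumes a distributed environment has already been initialized elsewhere (e.g. via torch.distributed.init_process_group); the example subnet dicts are illustrative and not taken from this patch.

# Hedged usage sketch (assumptions noted above): broadcast_object_list now
# fills the passed-in list in place and returns None, so every rank must
# pre-allocate a list of the same length before the call.
from mmrazor.core.utils import broadcast_object_list, get_rank

if get_rank() == 0:
    candidates = [{'op': 'conv3x3'}, {'op': 'conv5x5'}]  # illustrative subnets on the src rank
else:
    candidates = [None, None]  # placeholders of equal length on the other ranks

broadcast_object_list(candidates)  # mutates `candidates` in place; no return value
# After the call, every rank sees the objects provided by the src rank (rank 0).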