diff --git a/ignite/distributed/comp_models/base.py b/ignite/distributed/comp_models/base.py
index f31b074c398e..d3a12dc8c17a 100644
--- a/ignite/distributed/comp_models/base.py
+++ b/ignite/distributed/comp_models/base.py
@@ -136,7 +136,7 @@ def _apply_op(
 
     def _collective_op(
         self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args: Any, **kwargs: Any
-    ) -> Union[torch.Tensor, Number, List[str]]:
+    ) -> Union[torch.Tensor, Number, List[Number], List[str]]:
         tensor_to_number = tensor_to_str = False
         device = self.device()
         if isinstance(tensor, Number):
@@ -148,8 +148,11 @@ def _collective_op(
 
         tensor = self._apply_op(tensor, device, fn, *args, **kwargs)
 
-        if tensor_to_number and tensor.numel() == 1:
-            return cast(Number, tensor.item())
+        if tensor_to_number:
+            if tensor.numel() == 1:
+                return cast(Number, tensor.item())
+            else:
+                return tensor.tolist()
         elif tensor_to_str:
             return self._decode_str(tensor)
         return tensor
@@ -160,7 +163,9 @@ def all_reduce(self, tensor: Union[torch.Tensor, Number], op: str = "sum") -> Un
 
         return cast(Union[torch.Tensor, Number], self._collective_op(tensor, self._do_all_reduce, op))
 
-    def all_gather(self, tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[str]]:
+    def all_gather(
+        self, tensor: Union[torch.Tensor, Number, str]
+    ) -> Union[torch.Tensor, Number, List[Number], List[str]]:
         if not isinstance(tensor, (torch.Tensor, Number, str)):
             raise TypeError("Unhandled input type {}".format(type(tensor)))
 
@@ -271,8 +276,12 @@ def spawn(*args: Any, **kwargs: Any) -> None:
     def all_reduce(self, tensor: Union[torch.Tensor, Number], op: str = "sum") -> Union[torch.Tensor, Number]:
         return tensor
 
-    def all_gather(self, tensor: Union[torch.Tensor, Number]) -> Union[torch.Tensor, Number]:  # type: ignore
-        return tensor
+    def all_gather(
+        self, tensor: Union[torch.Tensor, Number, str]
+    ) -> Union[torch.Tensor, Number, List[Number], List[str]]:
+        if isinstance(tensor, torch.Tensor):
+            return tensor
+        return cast(Union[List[Number], List[str]], [tensor])
 
     def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]:
         return tensor
diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py
index ec17dbe9a29b..81467b4eca1f 100644
--- a/ignite/distributed/utils.py
+++ b/ignite/distributed/utils.py
@@ -334,7 +334,7 @@ def all_reduce(tensor: Union[torch.Tensor, Number], op: str = "SUM") -> Union[to
     return _model.all_reduce(tensor, op)
 
 
-def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[str]]:
+def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[Number], List[str]]:
    """Helper method to perform all gather operation.
 
     Args:
@@ -349,7 +349,7 @@ def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor,
     if _need_to_sync and isinstance(_model, _SerialModel):
         sync(temporary=True)
 
-    return _model.all_gather(tensor)  # type: ignore[arg-type]
+    return _model.all_gather(tensor)
 
 
 def broadcast(tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]:
diff --git a/tests/ignite/distributed/utils/__init__.py b/tests/ignite/distributed/utils/__init__.py
index f109bce006ed..3c16da42c506 100644
--- a/tests/ignite/distributed/utils/__init__.py
+++ b/tests/ignite/distributed/utils/__init__.py
@@ -83,7 +83,7 @@ def _test_distrib_all_reduce(device):
 
 
 def _test_distrib_all_gather(device):
-    res = idist.all_gather(10)
+    res = torch.tensor(idist.all_gather(10), device=device)
     true_res = torch.tensor([10,] * idist.get_world_size(), device=device)
     assert (res == true_res).all()
 
diff --git a/tests/ignite/distributed/utils/test_serial.py b/tests/ignite/distributed/utils/test_serial.py
index 3a91baa13fc1..22824169ddd7 100644
--- a/tests/ignite/distributed/utils/test_serial.py
+++ b/tests/ignite/distributed/utils/test_serial.py
@@ -54,4 +54,5 @@ def test_idist_all_reduce_no_dist():
 
 
 def test_idist_all_gather_no_dist():
-    assert idist.all_gather(10) == 10
+    assert idist.all_gather(10) == [10]
+    assert (idist.all_gather(torch.tensor(10)) == torch.tensor(10)).all()