
Commit 8ed0434

Removed size limitation for str on collective ops (#1702)
* Removed the 1024-character size limitation for str on collective ops
* Added MIN, MAX, PRODUCT options for horovod all_reduce (fixes #1697)
* Fixed a failing test and added more tests, plus minor consistency fixes
1 parent d09ed01 commit 8ed0434
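In user-facing terms, a minimal sketch of what this commit enables, assuming an initialized ignite.distributed backend (gloo, nccl, or horovod); all names below come from this repo's API:

    import torch
    import ignite.distributed as idist

    # Strings longer than 1024 characters can now go through collective ops;
    # the padded size is negotiated across processes instead of being fixed.
    long_str = "tests/ignite/distributed/utils/test_native.py" * 2000
    gathered = idist.all_gather(long_str)  # list of full-length strings

    # MIN, MAX and PRODUCT reductions now also work on the horovod backend.
    t = torch.tensor(idist.get_rank() * 2.0 + 1.0)
    smallest = idist.all_reduce(t, op="MIN")
    largest = idist.all_reduce(t, op="MAX")
    product = idist.all_reduce(t, op="PRODUCT")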

File tree: 10 files changed, +146 −58 lines


.github/workflows/hvd-tests.yml

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ jobs:
       - name: Run Tests
         shell: bash -l {0}
         run: |
-          SKIP_DISTRIB_TESTS=${{ matrix.skip-distrib-tests }} bash tests/run_cpu_tests.sh
+          bash tests/run_cpu_tests.sh

       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v1

ignite/distributed/comp_models/base.py

Lines changed: 29 additions & 29 deletions
@@ -87,23 +87,22 @@ def spawn(*args: Any, **kwargs: Any) -> None:
     _collective_op_dtype = None  # type: Any

     @staticmethod
-    def _encode_str(x: str, device: torch.device) -> torch.Tensor:
-        # use fix padded size
-        size = 1024
-        if len(x) > size:
-            warnings.warn(f"Input string size {len(x)} is larger than {size} and thus will be truncated")
-            x = x[:size]
-
+    def _encode_str(x: str, device: torch.device, size: int) -> torch.Tensor:
         name = torch.tensor(bytearray(x, "utf-8")).to(device)
         padded_x = torch.zeros(size + 1, device=device, dtype=torch.long)
         padded_x[: len(name)] = name
         padded_x[-1] = len(name)
-        # output is tensor of shape (1, 1025)
+        # output is tensor of shape (1, size + 1)
         return padded_x.unsqueeze(0)

+    def _get_max_length(self, x: str, device: torch.device) -> int:
+        size = torch.tensor([len(x),], device=device)
+        size = self._do_all_reduce(size, "MAX")
+        return cast(int, size.item())
+
     @staticmethod
     def _decode_str(xs: torch.Tensor) -> List[str]:
-        # xs.shape = (n, 1025), e.g. (world_size, 1025)
+        # xs.shape = (n, size + 1), e.g. (world_size, size + 1)
         out = [bytearray(x[: x[-1]].tolist()).decode("utf-8") for x in xs]
         return out

@@ -144,7 +143,8 @@ def _collective_op(
             tensor = torch.tensor(tensor, device=device, dtype=self._collective_op_dtype)
         elif isinstance(tensor, str):
             tensor_to_str = True
-            tensor = self._encode_str(tensor, device)
+            max_length = self._get_max_length(tensor, device)
+            tensor = self._encode_str(tensor, device, size=max_length)

         tensor = self._apply_op(tensor, device, fn, *args, **kwargs)

@@ -176,20 +176,20 @@ def broadcast(self, tensor: Union[torch.Tensor, float, str], src: int = 0) -> Un
         rank = self.get_rank()
         device = self.device()
         tensor_to_number = tensor_to_str = False
-        if rank != src:
-            if isinstance(tensor, Number):
-                tensor_to_number = True
-                tensor = torch.empty(1, device=self.device(), dtype=torch.float)
-            elif isinstance(tensor, str):
-                tensor_to_str = True
-                tensor = torch.empty(1, 1025, device=self.device(), dtype=torch.long)
-        else:
-            if isinstance(tensor, Number):
-                tensor_to_number = True
+
+        if isinstance(tensor, Number):
+            tensor_to_number = True
+            if rank != src:
+                tensor = torch.empty(1, device=device, dtype=torch.float)
+            else:
                 tensor = torch.tensor([tensor,], device=device, dtype=torch.float)
-            elif isinstance(tensor, str):
-                tensor_to_str = True
-                tensor = self._encode_str(tensor, device)
+        elif isinstance(tensor, str):
+            tensor_to_str = True
+            max_length = self._get_max_length(tensor, device)
+            if rank != src:
+                tensor = torch.empty(1, max_length + 1, device=device, dtype=torch.long)
+            else:
+                tensor = self._encode_str(tensor, device, size=max_length)

         tensor = self._apply_op(tensor, device, self._do_broadcast, src)

@@ -201,7 +201,7 @@ def broadcast(self, tensor: Union[torch.Tensor, float, str], src: int = 0) -> Un
         return tensor

     @abstractmethod
-    def _do_all_reduce(self, tensor: torch.Tensor, op: str = "sum") -> torch.Tensor:
+    def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM") -> torch.Tensor:
         pass

     @abstractmethod
@@ -271,7 +271,7 @@ def create_from_backend(backend: Optional[str] = None, **kwargs: Any) -> "_Seria
     def spawn(*args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("Serial computation model does not implement spawn method")

-    def all_reduce(self, tensor: Union[torch.Tensor, float], op: str = "sum") -> Union[torch.Tensor, float]:
+    def all_reduce(self, tensor: Union[torch.Tensor, float], op: str = "SUM") -> Union[torch.Tensor, float]:
         return tensor

     def all_gather(self, tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
@@ -282,14 +282,14 @@ def all_gather(self, tensor: Union[torch.Tensor, float, str]) -> Union[torch.Ten
     def broadcast(self, tensor: Union[torch.Tensor, float, str], src: int = 0) -> Union[torch.Tensor, float, str]:
         return tensor

-    def _do_all_reduce(self, tensor: torch.Tensor, op: str = "sum") -> torch.Tensor:
-        pass
+    def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM") -> torch.Tensor:
+        return tensor

     def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
-        pass
+        return tensor

     def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
-        pass
+        return tensor

     def barrier(self) -> None:
         pass
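The string codec above packs utf-8 bytes into a long tensor padded to a common size, with the true length stored in the last slot; _get_max_length negotiates that size via a MAX all-reduce over per-process lengths. A standalone sketch of the round trip (encode_str/decode_str here are illustrative copies, not the library API):

    import torch

    def encode_str(x: str, device: torch.device, size: int) -> torch.Tensor:
        # utf-8 bytes padded to `size`, true length stored in the last slot
        name = torch.tensor(bytearray(x, "utf-8")).to(device)
        padded_x = torch.zeros(size + 1, device=device, dtype=torch.long)
        padded_x[: len(name)] = name
        padded_x[-1] = len(name)
        return padded_x.unsqueeze(0)  # shape (1, size + 1)

    def decode_str(xs: torch.Tensor):
        # each row: bytes up to the stored length, decoded back to a str
        return [bytearray(x[: x[-1]].tolist()).decode("utf-8") for x in xs]

    device = torch.device("cpu")
    encoded = encode_str("hello", device, size=16)
    assert decode_str(encoded) == ["hello"]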

ignite/distributed/comp_models/horovod.py

Lines changed: 12 additions & 0 deletions
@@ -165,12 +165,24 @@ def spawn(  # type: ignore[override]
         "ADASUM": hvd.mpi_ops.Adasum,
     }

+    _manual_reduce_op_map = {"MIN": torch.min, "MAX": torch.max, "PRODUCT": torch.prod}
+
     def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM") -> torch.Tensor:
+        if op in self._manual_reduce_op_map:
+            op_fn = self._manual_reduce_op_map[op]
+            return self._do_manual_all_reduce(tensor, op_fn)
         if op not in self._reduce_op_map:
             raise ValueError(f"Unsupported reduction operation: '{op}'")
         op = self._reduce_op_map[op]
         return hvd.allreduce(tensor, op=op)

+    def _do_manual_all_reduce(self, tensor: torch.Tensor, op: Any) -> torch.Tensor:
+        res = self._do_all_gather(tensor)
+        reduced_res = op(res, dim=0)
+        if isinstance(reduced_res, torch.Tensor):
+            return reduced_res
+        return reduced_res[0]
+
     def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
         if tensor.ndimension() == 0:
             tensor = tensor.unsqueeze(0)
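Horovod's allreduce exposes SUM/AVERAGE/ADASUM but not MIN/MAX/PRODUCT, so the commit emulates those by all-gathering the per-process values and reducing locally. A self-contained sketch of that reduction step (the gathered tensor below stands in for the result of the all-gather):

    import torch

    def manual_all_reduce(gathered: torch.Tensor, op) -> torch.Tensor:
        reduced = op(gathered, dim=0)
        if isinstance(reduced, torch.Tensor):
            return reduced  # torch.prod returns a plain tensor
        # torch.min/torch.max with dim= return a (values, indices) pair
        return reduced[0]

    gathered = torch.tensor([[3.0], [1.0], [7.0]])  # one row per process
    assert manual_all_reduce(gathered, torch.min).item() == 1.0
    assert manual_all_reduce(gathered, torch.max).item() == 7.0
    assert manual_all_reduce(gathered, torch.prod).item() == 21.0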

ignite/distributed/utils.py

Lines changed: 1 addition & 1 deletion
@@ -331,7 +331,7 @@ def all_reduce(tensor: Union[torch.Tensor, float], op: str = "SUM") -> Union[tor
     Args:
         tensor: tensor or number to collect across participating processes.
         op: reduction operation, "SUM" by default. Possible values: "SUM", "PRODUCT", "MIN", "MAX", "AND", "OR".
-            Please, several values are not supported for the backend like "horovod".
+            Horovod backend supports only "SUM", "AVERAGE", "ADASUM", "MIN", "MAX", "PRODUCT".

     Returns:
         torch.Tensor or number
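For reference, a usage sketch matching the updated docstring; this assumes a distributed group is already initialized:

    import torch
    import ignite.distributed as idist

    loss = torch.tensor(0.5, device=idist.device())
    total = idist.all_reduce(loss)            # "SUM" is the default
    worst = idist.all_reduce(loss, op="MAX")  # now valid on horovod too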

tests/ignite/distributed/comp_models/test_base.py

Lines changed: 4 additions & 4 deletions
@@ -27,17 +27,17 @@ def test_serial_model():
     model.all_reduce(1)
     model.all_gather(1)
     model.broadcast(1)
-    model._do_all_reduce(torch.tensor(1))
-    model._do_all_gather(torch.tensor(1))
-    model._do_broadcast(torch.tensor(1), src=0)
+    assert model._do_all_reduce(torch.tensor(1)) == torch.tensor(1)
+    assert model._do_all_gather(torch.tensor(1)) == torch.tensor(1)
+    assert model._do_broadcast(torch.tensor(1), src=0) == torch.tensor(1)
     model.barrier()


 def test__encode_str__decode_str():
     device = torch.device("cpu")
     s = "test-abcedfg"

-    encoded_s = ComputationModel._encode_str(s, device)
+    encoded_s = ComputationModel._encode_str(s, device, 1024)
     assert isinstance(encoded_s, torch.Tensor) and encoded_s.shape == (1, 1025)

     decoded_s = ComputationModel._decode_str(encoded_s)

tests/ignite/distributed/test_auto.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def test_auto_methods_no_dist():
     _test_auto_dataloader(1, 1, batch_size=10, num_workers=2)
     _test_auto_dataloader(1, 1, batch_size=10, sampler_name="WeightedRandomSampler")

-    _test_auto_model_optimizer(1, "cpu")
+    _test_auto_model_optimizer(1, "cuda" if torch.cuda.is_available() else "cpu")


 @pytest.mark.distributed

tests/ignite/distributed/utils/__init__.py

Lines changed: 45 additions & 16 deletions
@@ -56,6 +56,16 @@ def _test_sync(cls):
     assert isinstance(_model, cls), f"{type(_model)} vs {cls}"


+def _test_distrib__get_max_length(device):
+    ws = idist.get_world_size()
+    x = "_test_distrib__get_max_length" * (idist.get_rank() + 2)
+
+    from ignite.distributed.utils import _model
+
+    res = _model._get_max_length(x, device)
+    assert res == len("_test_distrib__get_max_length" * (ws + 1))
+
+
 def _test_distrib_all_reduce(device):

     res = idist.all_reduce(10)
@@ -65,9 +75,27 @@ def _test_distrib_all_reduce(device):
     res = idist.all_reduce(t)
     assert res.item() == 10 * idist.get_world_size()

-    t = torch.tensor(idist.get_rank(), device=device)
+    rank = idist.get_rank()
+    t = torch.tensor(rank * 2.0 + 1.0, device=device)
     res = idist.all_reduce(t)
-    assert res.item() == sum([i for i in range(idist.get_world_size())])
+    assert res.item() == sum([i * 2.0 + 1.0 for i in range(idist.get_world_size())])
+
+    t = torch.tensor(rank * 2.0 + 1.0, device=device)
+    res = idist.all_reduce(t, "MIN").item()
+    true_val = min([i * 2 + 1 for i in range(idist.get_world_size())])
+    assert res == true_val, f"{res} vs {true_val}"
+
+    t = torch.tensor(rank * 2.0 + 1.0, device=device)
+    res = idist.all_reduce(t, "MAX").item()
+    true_val = max([i * 2.0 + 1.0 for i in range(idist.get_world_size())])
+    assert res == true_val, f"{res} vs {true_val}"
+
+    t = torch.tensor(rank * 2.0 + 1.0, device=device)
+    res = idist.all_reduce(t, "PRODUCT").item()
+    true_val = 1
+    for v in [i * 2.0 + 1.0 for i in range(idist.get_world_size())]:
+        true_val *= v
+    assert res == true_val, f"{res} vs {true_val}"

     if idist.get_world_size() > 1:
         with pytest.raises(TypeError, match=r"Unhandled input type"):
@@ -99,17 +127,13 @@ def _test_distrib_all_gather(device):
     true_res = ["abc",] + ["test-test"] * (idist.get_world_size() - 1)
     assert res == true_res

-    base_x = "x" * 1026
+    base_x = "tests/ignite/distributed/utils/test_native.py" * 2000
     x = base_x
     if idist.get_rank() == 0:
         x = "abc"

-    if idist.get_rank() > 0:
-        with pytest.warns(UserWarning, match=r"is larger than 1024 and thus will be truncated"):
-            res = idist.all_gather(x)
-    else:
-        res = idist.all_gather(x)
-    true_res = ["abc",] + [base_x[:1024]] * (idist.get_world_size() - 1)
+    res = idist.all_gather(x)
+    true_res = ["abc",] + [base_x] * (idist.get_world_size() - 1)
     assert res == true_res

     t = torch.arange(100, device=device).reshape(4, 25) * (idist.get_rank() + 1)
@@ -147,14 +171,19 @@ def _test_distrib_broadcast(device):
     true_res = torch.tensor([1.2345, 2.3456], dtype=torch.float, device=device)
     assert (res == true_res).all(), f"{res} vs {true_res}"

-    if rank == src:
-        t = "test-abcdefg"
-    else:
-        t = ""
+    def _test(text):

-    res = idist.broadcast(t, src=src)
-    true_res = "test-abcdefg"
-    assert res == true_res
+        if rank == src:
+            t = text
+        else:
+            t = ""
+
+        res = idist.broadcast(t, src=src)
+        true_res = text
+        assert res == true_res
+
+    _test("test-abcdefg")
+    _test("tests/ignite/distributed/utils/test_horovod.py::test_idist_broadcast_hvd" * 200)

     if rank == src:
         t = torch.arange(100, device=device).reshape(4, 25) * (src + 1)
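The expected value in _test_distrib__get_max_length follows from the rank with the longest string: rank ws - 1 repeats the marker ws + 1 times, so the negotiated MAX length matches it. A quick check of that arithmetic:

    marker = "_test_distrib__get_max_length"
    ws = 4  # example world size
    lengths = [len(marker * (rank + 2)) for rank in range(ws)]
    assert max(lengths) == len(marker * (ws + 1))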

tests/ignite/distributed/utils/test_horovod.py

Lines changed: 11 additions & 0 deletions
@@ -6,6 +6,7 @@
 import ignite.distributed as idist
 from ignite.distributed.utils import has_hvd_support
 from tests.ignite.distributed.utils import (
+    _test_distrib__get_max_length,
     _test_distrib_all_gather,
     _test_distrib_all_reduce,
     _test_distrib_barrier,
@@ -145,6 +146,16 @@ def test_idist_all_reduce_hvd(gloo_hvd_executor):
     gloo_hvd_executor(_test_distrib_all_reduce, (device,), np=np, do_init=True)


+@pytest.mark.distributed
+@pytest.mark.skipif(not has_hvd_support, reason="Skip if no Horovod dist support")
+@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
+def test_idist__model_methods_hvd(gloo_hvd_executor):
+
+    device = "cpu" if not torch.cuda.is_available() else "cuda"
+    np = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
+    gloo_hvd_executor(_test_distrib__get_max_length, (device,), np=np, do_init=True)
+
+
 @pytest.mark.distributed
 @pytest.mark.skipif(not has_hvd_support, reason="Skip if no Horovod dist support")
 @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")

tests/ignite/distributed/utils/test_native.py

Lines changed: 18 additions & 0 deletions
@@ -7,6 +7,7 @@
 import ignite.distributed as idist
 from ignite.distributed.utils import has_native_dist_support
 from tests.ignite.distributed.utils import (
+    _test_distrib__get_max_length,
     _test_distrib_all_gather,
     _test_distrib_all_reduce,
     _test_distrib_barrier,
@@ -152,6 +153,23 @@ def test_idist_methods_in_native_nccl_context_set_local_rank(distributed_context
     _test_idist_methods_in_native_context_set_local_rank("nccl", "cuda", local_rank)


+@pytest.mark.distributed
+@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+def test_idist__model_methods_nccl(distributed_context_single_node_nccl):
+
+    device = f"cuda:{distributed_context_single_node_nccl['local_rank']}"
+    _test_distrib__get_max_length(device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
+def test_idist__model_methods_gloo(distributed_context_single_node_gloo):
+
+    device = "cpu"
+    _test_distrib__get_max_length(device)
+
+
 @pytest.mark.distributed
 @pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")

tests/ignite/distributed/utils/test_serial.py

Lines changed: 24 additions & 6 deletions
@@ -1,7 +1,15 @@
 import torch

 import ignite.distributed as idist
-from tests.ignite.distributed.utils import _sanity_check, _test_sync
+from tests.ignite.distributed.utils import (
+    _sanity_check,
+    _test_distrib__get_max_length,
+    _test_distrib_all_gather,
+    _test_distrib_all_reduce,
+    _test_distrib_barrier,
+    _test_distrib_broadcast,
+    _test_sync,
+)


 def test_no_distrib(capsys):
@@ -48,10 +56,20 @@ def test_idist_methods_no_dist():
     assert idist.backend() is None, f"{idist.backend()}"


-def test_idist_all_reduce_no_dist():
-    assert idist.all_reduce(10) == 10
+def test_idist__model_methods_no_dist():
+    _test_distrib__get_max_length("cpu")
+    if torch.cuda.device_count() > 1:
+        _test_distrib__get_max_length("cuda")


-def test_idist_all_gather_no_dist():
-    assert idist.all_gather(10) == [10]
-    assert (idist.all_gather(torch.tensor(10)) == torch.tensor(10)).all()
+def test_idist_collective_ops_no_dist():
+    _test_distrib_all_reduce("cpu")
+    _test_distrib_all_gather("cpu")
+    _test_distrib_barrier("cpu")
+    _test_distrib_broadcast("cpu")
+
+    if torch.cuda.device_count() > 1:
+        _test_distrib_all_reduce("cuda")
+        _test_distrib_all_gather("cuda")
+        _test_distrib_barrier("cuda")
+        _test_distrib_broadcast("cuda")
