diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py
index 3e36af0652..a0717169aa 100644
--- a/monai/handlers/utils.py
+++ b/monai/handlers/utils.py
@@ -85,27 +85,39 @@ def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor:
     return torch.cat([data[i * max_len : i * max_len + l, ...] for i, l in enumerate(all_lens)], dim=0)
 
 
-def string_list_all_gather(strings: List[str], delimiter: str = "\t") -> List[str]:
+def string_list_all_gather(strings: List[str]) -> List[str]:
     """
     Utility function for distributed data parallel to all gather a list of strings.
+    Note that if an item in `strings` is longer than 1024 chars, it will be truncated to 1024 chars:
+    https://github.com/pytorch/ignite/blob/master/ignite/distributed/comp_models/base.py#L92
 
     Args:
         strings: a list of strings to all gather.
-        delimiter: use the delimiter to join the string list to be a long string,
-            then all gather across ranks and split to a list. default to "\t".
 
     """
-    if idist.get_world_size() <= 1:
+    world_size = idist.get_world_size()
+    if world_size <= 1:
         return strings
 
-    _joined = delimiter.join(strings)
+    result: List[List[str]] = [[] for _ in range(world_size)]
+
+    # gather the number of strings held by every rank
+    length = len(strings)
+    all_lens = idist.all_gather(length)
+    max_len = max(all_lens).item()
+    # pad shorter lists with "" so that every rank gathers the same number of items
+    if length < max_len:
+        strings = strings + ["" for _ in range(max_len - length)]
+
     if get_torch_version_tuple() > (1, 6, 0):
-        # all gather across all ranks
-        _joined = delimiter.join(idist.all_gather(_joined))
+        for s in strings:
+            gathered = idist.all_gather(s)
+            for i, g in enumerate(gathered):
+                if len(g) > 0:
+                    result[i].append(g)
     else:
         raise RuntimeError("string all_gather is not supported in PyTorch < 1.7.0.")
 
-    return _joined.split(delimiter)
+    return [i for k in result for i in k]
 
 
 def write_metrics_reports(
diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py
index 1b17d0adb4..dfdaa16526 100644
--- a/tests/test_handler_metrics_saver_dist.py
+++ b/tests/test_handler_metrics_saver_dist.py
@@ -12,6 +12,7 @@
 
 import csv
 import os
+import random
 import tempfile
 import unittest
 
@@ -44,8 +45,13 @@ def _val_func(engine, batch):
 
         engine = Engine(_val_func)
 
+        # test the case that all_gather handles strings longer than 1024 chars
+        filename_postfix = "abcdefghijklmnopqrstuvwxyz"
+        for _ in range(1100):
+            filename_postfix += filename_postfix[random.randint(0, 25)]
+
         if dist.get_rank() == 0:
-            data = [{"image_meta_dict": {"filename_or_obj": ["filepath1"]}}]
+            data = [{"image_meta_dict": {"filename_or_obj": [f"1{filename_postfix}"]}}]
 
             @engine.on(Events.EPOCH_COMPLETED)
             def _save_metrics0(engine):
@@ -58,8 +64,8 @@ def _save_metrics0(engine):
         if dist.get_rank() == 1:
             # different ranks have different data length
             data = [
-                {"image_meta_dict": {"filename_or_obj": ["filepath2"]}},
-                {"image_meta_dict": {"filename_or_obj": ["filepath3"]}},
+                {"image_meta_dict": {"filename_or_obj": [f"2{filename_postfix}"]}},
+                {"image_meta_dict": {"filename_or_obj": [f"3{filename_postfix}"]}},
             ]
 
             @engine.on(Events.EPOCH_COMPLETED)
@@ -86,7 +92,8 @@ def _save_metrics1(engine):
                 f_csv = csv.reader(f)
                 for i, row in enumerate(f_csv):
                     if i > 0:
-                        self.assertEqual(row, [f"filepath{i}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"])
+                        expected = [f"{i}{filename_postfix[0: 1023]}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"]
+                        self.assertEqual(row, expected)
             self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_summary.csv")))
             # check the metric_summary.csv and content
             with open(os.path.join(tempdir, "metric3_summary.csv")) as f:
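
For reference, a minimal single-process sketch of the gather-and-flatten scheme the new `string_list_all_gather` uses: pad every rank's list to the global maximum length, all-gather item by item, drop the "" padding, and flatten rank by rank. The `fake_all_gather` helper and the two-rank `ranks` input below are hypothetical stand-ins for `ignite.distributed.all_gather` and real process ranks; they are not part of this PR.

from typing import List

# Hypothetical stand-in for ignite's idist.all_gather on a string:
# given the value each rank holds for one slot, return all of them.
def fake_all_gather(per_rank_values: List[str]) -> List[str]:
    return list(per_rank_values)

def gather_string_lists(ranks: List[List[str]]) -> List[str]:
    world_size = len(ranks)
    max_len = max(len(r) for r in ranks)
    # pad shorter lists with "" so every rank gathers the same number of items
    padded = [r + [""] * (max_len - len(r)) for r in ranks]
    result: List[List[str]] = [[] for _ in range(world_size)]
    for slot in range(max_len):
        gathered = fake_all_gather([padded[i][slot] for i in range(world_size)])
        for i, g in enumerate(gathered):
            if len(g) > 0:  # drop the padding entries again
                result[i].append(g)
    # flatten rank by rank, so the output ordering is deterministic
    return [s for per_rank in result for s in per_rank]

print(gather_string_lists([["filepath1"], ["filepath2", "filepath3"]]))
# -> ['filepath1', 'filepath2', 'filepath3']

Unlike the old delimiter-join approach, this scheme never concatenates all strings into one payload, so only individual items longer than 1024 chars hit ignite's truncation limit, which is what the updated test exercises.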