28 changes: 20 additions & 8 deletions monai/handlers/utils.py
@@ -85,27 +85,39 @@ def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor:
     return torch.cat([data[i * max_len : i * max_len + l, ...] for i, l in enumerate(all_lens)], dim=0)
 
 
-def string_list_all_gather(strings: List[str], delimiter: str = "\t") -> List[str]:
+def string_list_all_gather(strings: List[str]) -> List[str]:
     """
     Utility function for distributed data parallel to all gather a list of strings.
+    Note that if the item in `strings` is longer than 1024 chars, it will be truncated to 1024:
+    https://github.com/pytorch/ignite/blob/master/ignite/distributed/comp_models/base.py#L92
 
     Args:
         strings: a list of strings to all gather.
-        delimiter: use the delimiter to join the string list to be a long string,
-            then all gather across ranks and split to a list. default to "\t".
 
     """
-    if idist.get_world_size() <= 1:
+    world_size = idist.get_world_size()
+    if world_size <= 1:
         return strings
 
-    _joined = delimiter.join(strings)
+    result: List[List[str]] = [[] for _ in range(world_size)]
+    # get length of strings
+    length = len(strings)
+    all_lens = idist.all_gather(length)
+    max_len = max(all_lens).item()
+    # pad the item to make sure the same length
+    if length < max_len:
+        strings = strings + ["" for _ in range(max_len - length)]
 
     if get_torch_version_tuple() > (1, 6, 0):
-        # all gather across all ranks
-        _joined = delimiter.join(idist.all_gather(_joined))
+        for s in strings:
+            gathered = idist.all_gather(s)
+            for i, g in enumerate(gathered):
+                if len(g) > 0:
+                    result[i].append(g)
     else:
         raise RuntimeError("string all_gather can not be supported in PyTorch < 1.7.0.")
 
-    return _joined.split(delimiter)
+    return [i for k in result for i in k]
 
 
 def write_metrics_reports(
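For reference, a minimal usage sketch of the reworked helper (not part of this diff). It assumes PyTorch >= 1.7 with the gloo backend and ignite's `idist.Parallel` launcher; the filenames are hypothetical. Because each rank's list is padded to the longest length and empty pads are dropped after gathering, ranks may contribute lists of different lengths:

```python
# Hypothetical two-rank sketch of string_list_all_gather (not from this PR).
import ignite.distributed as idist

from monai.handlers.utils import string_list_all_gather


def _run(local_rank):
    # rank 0 holds one filename, rank 1 holds two
    local = ["a.nii"] if idist.get_rank() == 0 else ["b.nii", "c.nii"]
    gathered = string_list_all_gather(local)
    # every rank receives the rank-ordered concatenation:
    # ["a.nii", "b.nii", "c.nii"]
    print(idist.get_rank(), gathered)


if __name__ == "__main__":
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(_run)
```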
15 changes: 11 additions & 4 deletions tests/test_handler_metrics_saver_dist.py
@@ -12,6 +12,7 @@
 
 import csv
 import os
+import random
 import tempfile
 import unittest
@@ -44,8 +45,13 @@ def _val_func(engine, batch):
 
         engine = Engine(_val_func)
 
+        # test all_gather with a string longer than 1024 chars
+        filename_postfix = "abcdefghijklmnopqrstuvwxyz"
+        for _ in range(1100):
+            filename_postfix += filename_postfix[random.randint(0, 25)]
+
         if dist.get_rank() == 0:
-            data = [{"image_meta_dict": {"filename_or_obj": ["filepath1"]}}]
+            data = [{"image_meta_dict": {"filename_or_obj": [f"1{filename_postfix}"]}}]
 
             @engine.on(Events.EPOCH_COMPLETED)
             def _save_metrics0(engine):
@@ -58,8 +64,8 @@ def _save_metrics0(engine):
 
         if dist.get_rank() == 1:
             # different ranks have different data length
             data = [
-                {"image_meta_dict": {"filename_or_obj": ["filepath2"]}},
-                {"image_meta_dict": {"filename_or_obj": ["filepath3"]}},
+                {"image_meta_dict": {"filename_or_obj": [f"2{filename_postfix}"]}},
+                {"image_meta_dict": {"filename_or_obj": [f"3{filename_postfix}"]}},
             ]
 
             @engine.on(Events.EPOCH_COMPLETED)
@@ -86,7 +92,8 @@ def _save_metrics1(engine):
 
                 f_csv = csv.reader(f)
                 for i, row in enumerate(f_csv):
                     if i > 0:
-                        self.assertEqual(row, [f"filepath{i}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"])
+                        expected = [f"{i}{filename_postfix[0: 1023]}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"]
+                        self.assertEqual(row, expected)
             self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_summary.csv")))
             # check the metric_summary.csv and content
             with open(os.path.join(tempdir, "metric3_summary.csv")) as f:
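The expected row encodes ignite's 1024-character cap on gathered strings: each stored filename is the rank digit plus the roughly 1126-char postfix, and all_gather keeps only the first 1024 characters, i.e. the digit plus `filename_postfix[0:1023]`. A standalone sketch of that arithmetic (assumed truncation behavior, no distributed setup needed):

```python
# Sketch of the truncation the assertion encodes (assumes ignite's
# all_gather caps each gathered string at 1024 characters).
postfix = "abcdefghijklmnopqrstuvwxyz" * 44   # 1144 chars, like the test postfix
full_name = f"1{postfix}"                     # filename stored by rank 0
gathered = full_name[:1024]                   # what comes back after all_gather
assert gathered == "1" + postfix[0:1023]      # matches the expected CSV row prefix
```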