1533 Fix distributed data parallel issue in ClassificationSaver #1535

Merged 10 commits on Feb 3, 2021
8 changes: 7 additions & 1 deletion monai/handlers/__init__.py
@@ -25,5 +25,11 @@
from .stats_handler import StatsHandler
from .surface_distance import SurfaceDistance
from .tensorboard_handlers import TensorBoardImageHandler, TensorBoardStatsHandler
from .utils import evenly_divisible_all_gather, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports
from .utils import (
evenly_divisible_all_gather,
stopping_fn_from_loss,
stopping_fn_from_metric,
string_list_all_gather,
write_metrics_reports,
)
from .validation_handler import ValidationHandler
22 changes: 18 additions & 4 deletions monai/handlers/classification_saver.py
@@ -13,8 +13,10 @@
from typing import TYPE_CHECKING, Callable, Optional

from monai.data import CSVSaver
from monai.handlers.utils import evenly_divisible_all_gather, string_list_all_gather
from monai.utils import exact_version, optional_import

idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed")
Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events")
if TYPE_CHECKING:
from ignite.engine import Engine
@@ -25,6 +27,8 @@
class ClassificationSaver:
"""
Event handler triggered on completing every iteration to save the classification predictions as a CSV file.
If running in distributed data parallel, only saves the CSV file on the specified rank.

"""

def __init__(
@@ -35,6 +39,7 @@ def __init__(
batch_transform: Callable = lambda x: x,
output_transform: Callable = lambda x: x,
name: Optional[str] = None,
save_rank: int = 0,
) -> None:
"""
Args:
@@ -49,8 +54,11 @@
The first dimension of this transform's output will be treated as the
batch dimension. Each item in the batch will be saved individually.
name: identifier of logging.logger to use, defaulting to `engine.logger`.
save_rank: only the handler on the specified rank will save to the CSV file
in multi-GPU validation, defaults to 0.

"""
self._expected_rank: bool = idist.get_rank() == save_rank
self.saver = CSVSaver(output_dir, filename, overwrite)
self.batch_transform = batch_transform
self.output_transform = output_transform
@@ -67,7 +75,7 @@ def attach(self, engine: Engine) -> None:
self.logger = engine.logger
if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
engine.add_event_handler(Events.ITERATION_COMPLETED, self)
if not engine.has_event_handler(self.saver.finalize, Events.COMPLETED):
if self._expected_rank and not engine.has_event_handler(self.saver.finalize, Events.COMPLETED):
engine.add_event_handler(Events.COMPLETED, lambda engine: self.saver.finalize())

def __call__(self, engine: Engine) -> None:
@@ -77,6 +85,12 @@ def __call__(self, engine: Engine) -> None:
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
meta_data = self.batch_transform(engine.state.batch)
engine_output = self.output_transform(engine.state.output)
self.saver.save_batch(engine_output, meta_data)
_meta_data = self.batch_transform(engine.state.batch)
if "filename_or_obj" in _meta_data:
# all gather filenames across ranks
_meta_data["filename_or_obj"] = string_list_all_gather(_meta_data["filename_or_obj"])
# all gather predictions across ranks
_engine_output = evenly_divisible_all_gather(self.output_transform(engine.state.output))

if self._expected_rank:
self.saver.save_batch(_engine_output, _meta_data)
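
For context, a minimal usage sketch of the updated handler (the output directory, file names, and the dummy inference step are illustrative, not part of this PR). Every rank gathers file names and predictions, but only the handler on save_rank writes the CSV file; in a single process this degenerates to the previous non-distributed behavior:

import torch
from ignite.engine import Engine

from monai.handlers import ClassificationSaver

# dummy inference step: one scalar prediction per input file
def _infer_step(engine, batch):
    return torch.zeros(len(batch["filename_or_obj"]))

evaluator = Engine(_infer_step)

# all ranks gather, but only rank 0 writes predictions.csv
saver = ClassificationSaver(output_dir="./output", filename="predictions.csv", save_rank=0)
saver.attach(evaluator)

data = [{"filename_or_obj": ["image_0.nii", "image_1.nii"]}]
evaluator.run(data, max_epochs=1)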
14 changes: 3 additions & 11 deletions monai/handlers/metrics_saver.py
@@ -11,9 +11,8 @@

from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Union

from monai.handlers.utils import write_metrics_reports
from monai.handlers.utils import string_list_all_gather, write_metrics_reports
from monai.utils import ensure_tuple, exact_version, optional_import
from monai.utils.module import get_torch_version_tuple

Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events")
idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed")
@@ -105,15 +104,8 @@ def __call__(self, engine: Engine) -> None:
if self.save_rank >= ws:
raise ValueError("target rank is greater than the distributed group size.")

_images = self._filenames
if ws > 1:
_filenames = self.deli.join(_images)
if get_torch_version_tuple() > (1, 6, 0):
# all gather across all processes
_filenames = self.deli.join(idist.all_gather(_filenames))
else:
raise RuntimeError("MetricsSaver can not save metric details in distributed mode with PyTorch < 1.7.0.")
_images = _filenames.split(self.deli)
# all gather file names across ranks
_images = string_list_all_gather(strings=self._filenames) if ws > 1 else self._filenames

# only save metrics to file in specified rank
if idist.get_rank() == self.save_rank:
28 changes: 26 additions & 2 deletions monai/handlers/utils.py
@@ -11,12 +11,12 @@

import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union

import numpy as np
import torch

from monai.utils import ensure_tuple, exact_version, optional_import
from monai.utils import ensure_tuple, exact_version, get_torch_version_tuple, optional_import

idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed")
if TYPE_CHECKING:
@@ -28,6 +28,7 @@
"stopping_fn_from_metric",
"stopping_fn_from_loss",
"evenly_divisible_all_gather",
"string_list_all_gather",
"write_metrics_reports",
]

@@ -81,6 +82,29 @@ def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor:
return torch.cat([data[i * max_len : i * max_len + l, ...] for i, l in enumerate(all_lens)], dim=0)


def string_list_all_gather(strings: List[str], delimiter: str = "\t") -> List[str]:
"""
Utility function for distributed data parallel to all gather a list of strings.

Args:
strings: a list of strings to all gather.
delimiter: delimiter used to join the string list into one long string,
which is all gathered across ranks and then split back into a list. Defaults to "\t".

"""
if idist.get_world_size() <= 1:
return strings

_joined = delimiter.join(strings)
if get_torch_version_tuple() > (1, 6, 0):
# all gather across all ranks
_joined = delimiter.join(idist.all_gather(_joined))
else:
raise RuntimeError("MetricsSaver can not save metric details in distributed mode with PyTorch < 1.7.0.")

return _joined.split(delimiter)


def write_metrics_reports(
save_dir: str,
images: Optional[Sequence[str]],
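
As a reference for the new helpers, a hedged sketch of their behavior across two ranks (this assumes a two-process group is already initialized, e.g. via torch.distributed.launch, and the file names are made up):

import torch
import torch.distributed as dist

from monai.handlers import evenly_divisible_all_gather, string_list_all_gather

# rank 0 contributes two file names, rank 1 contributes one
strings = ["a.nii", "b.nii"] if dist.get_rank() == 0 else ["c.nii"]
gathered = string_list_all_gather(strings)
# every rank now holds the lists concatenated in rank order: ["a.nii", "b.nii", "c.nii"]

# evenly_divisible_all_gather is the tensor counterpart: it pads uneven
# first dimensions before the gather and strips the padding afterwards
preds = torch.zeros(2) if dist.get_rank() == 0 else torch.zeros(1)
all_preds = evenly_divisible_all_gather(preds)  # shape [3] on every rank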
1 change: 1 addition & 0 deletions tests/min_tests.py
@@ -103,6 +103,7 @@ def run_testsuit():
"test_handler_metrics_saver",
"test_handler_metrics_saver_dist",
"test_evenly_divisible_all_gather_dist",
"test_handler_classification_saver_dist",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"

60 changes: 60 additions & 0 deletions tests/test_handler_classification_saver_dist.py
@@ -0,0 +1,60 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import os
import tempfile
import unittest

import numpy as np
import torch
import torch.distributed as dist
from ignite.engine import Engine

from monai.handlers import ClassificationSaver
from tests.utils import DistCall, DistTestCase, SkipIfBeforePyTorchVersion


@SkipIfBeforePyTorchVersion((1, 7))
class DistributedHandlerClassificationSaver(DistTestCase):
@DistCall(nnodes=1, nproc_per_node=2)
def test_saved_content(self):
with tempfile.TemporaryDirectory() as tempdir:
rank = dist.get_rank()

# set up engine
def _train_func(engine, batch):
return torch.zeros(8 + rank * 2)

engine = Engine(_train_func)

# set up testing handler
saver = ClassificationSaver(output_dir=tempdir, filename="predictions.csv", save_rank=1)
saver.attach(engine)

# rank 0 has 8 images, rank 1 has 10 images
data = [{"filename_or_obj": ["testfile" + str(i) for i in range(8 * rank, (8 + rank) * (rank + 1))]}]
engine.run(data, max_epochs=1)
filepath = os.path.join(tempdir, "predictions.csv")
if rank == 1:
self.assertTrue(os.path.exists(filepath))
with open(filepath, "r") as f:
reader = csv.reader(f)
i = 0
for row in reader:
self.assertEqual(row[0], "testfile" + str(i))
self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0)
i += 1
self.assertEqual(i, 18)


if __name__ == "__main__":
unittest.main()