From 0765059e2c9cd2cfe7b99f6bf6fb20002a46ebb0 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 21 Aug 2023 14:22:30 +0100 Subject: [PATCH 1/3] update rank filter Signed-off-by: Wenqi Li --- monai/utils/dist.py | 12 +++++++----- tests/test_rankfilter_dist.py | 16 +++++++++++++++- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/monai/utils/dist.py b/monai/utils/dist.py index 47e6de4a98..104fe081ac 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -197,11 +197,13 @@ def __init__(self, rank: int | None = None, filter_fn: Callable = lambda rank: r if dist.is_available() and dist.is_initialized(): self.rank: int = rank if rank is not None else dist.get_rank() else: - warnings.warn( - "The torch.distributed is either unavailable and uninitiated when RankFilter is instantiated. " - "If torch.distributed is used, please ensure that the RankFilter() is called " - "after torch.distributed.init_process_group() in the script." - ) + if torch.cuda.is_available() and torch.cuda.device_count() > 1: + warnings.warn( + "The torch.distributed is either unavailable and uninitiated when RankFilter is instantiated.\n" + "If torch.distributed is used, please ensure that the RankFilter() is called\n" + "after torch.distributed.init_process_group() in the script.\n", + ) + self.rank = 0 def filter(self, *_args): return self.filter_fn(self.rank) diff --git a/tests/test_rankfilter_dist.py b/tests/test_rankfilter_dist.py index 4dcd637c56..96a8b0ef26 100644 --- a/tests/test_rankfilter_dist.py +++ b/tests/test_rankfilter_dist.py @@ -43,7 +43,21 @@ def test_rankfilter(self): with open(log_filename) as file: lines = [line.rstrip() for line in file] log_message = " ".join(lines) - assert log_message.count("test_warnings") == 1 + self.assertEqual(log_message.count("test_warnings"), 1) + + def test_rankfilter_single_proc(self): + logger = logging.getLogger(__name__) + log_filename = os.path.join(self.log_dir.name, "records.log") + h1 = logging.FileHandler(filename=log_filename) + h1.setLevel(logging.WARNING) + logger.addHandler(h1) + logger.addFilter(RankFilter()) + logger.warning("test_warnings") + + with open(log_filename) as file: + lines = [line.rstrip() for line in file] + log_message = " ".join(lines) + self.assertEqual(log_message.count("test_warnings"), 1) def tearDown(self) -> None: self.log_dir.cleanup() From a99f61ba925f234326d59aee2921f01c2ef0f874 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 21 Aug 2023 14:32:27 +0100 Subject: [PATCH 2/3] autofix Signed-off-by: Wenqi Li --- monai/utils/dist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/utils/dist.py b/monai/utils/dist.py index 104fe081ac..20f09628ac 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -201,7 +201,7 @@ def __init__(self, rank: int | None = None, filter_fn: Callable = lambda rank: r warnings.warn( "The torch.distributed is either unavailable and uninitiated when RankFilter is instantiated.\n" "If torch.distributed is used, please ensure that the RankFilter() is called\n" - "after torch.distributed.init_process_group() in the script.\n", + "after torch.distributed.init_process_group() in the script.\n" ) self.rank = 0 From 71a001e168b06c82520c4a5e392bf39097d54ac0 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 21 Aug 2023 15:43:05 +0100 Subject: [PATCH 3/3] fixes windows test Signed-off-by: Wenqi Li --- tests/test_rankfilter_dist.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/test_rankfilter_dist.py b/tests/test_rankfilter_dist.py index 96a8b0ef26..40cd36f31d 100644 --- a/tests/test_rankfilter_dist.py +++ b/tests/test_rankfilter_dist.py @@ -45,9 +45,20 @@ def test_rankfilter(self): log_message = " ".join(lines) self.assertEqual(log_message.count("test_warnings"), 1) + def tearDown(self) -> None: + self.log_dir.cleanup() + + +class SingleRankFilterTest(unittest.TestCase): + def tearDown(self) -> None: + self.log_dir.cleanup() + + def setUp(self): + self.log_dir = tempfile.TemporaryDirectory() + def test_rankfilter_single_proc(self): logger = logging.getLogger(__name__) - log_filename = os.path.join(self.log_dir.name, "records.log") + log_filename = os.path.join(self.log_dir.name, "records_sp.log") h1 = logging.FileHandler(filename=log_filename) h1.setLevel(logging.WARNING) logger.addHandler(h1) @@ -56,12 +67,11 @@ def test_rankfilter_single_proc(self): with open(log_filename) as file: lines = [line.rstrip() for line in file] + logger.removeHandler(h1) + h1.close() log_message = " ".join(lines) self.assertEqual(log_message.count("test_warnings"), 1) - def tearDown(self) -> None: - self.log_dir.cleanup() - if __name__ == "__main__": unittest.main()