From b78641331463e350b794b3467abefe89b5eca9c0 Mon Sep 17 00:00:00 2001 From: Ryan Date: Mon, 15 May 2023 20:25:34 -0400 Subject: [PATCH 1/8] only start a single set of workers for sequential --- src/lightning/pytorch/utilities/combined_loader.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py index 96299126b4bd1..7ff0c8b1192af 100644 --- a/src/lightning/pytorch/utilities/combined_loader.py +++ b/src/lightning/pytorch/utilities/combined_loader.py @@ -108,7 +108,7 @@ def limits(self, limits: Optional[List[Union[int, float]]]) -> None: self._limits = limits def __next__(self) -> Tuple[Any, int, int]: - n = len(self.iterators) + n = len(self.iterables) if n == 0 or self._iterator_idx >= n: raise StopIteration @@ -120,7 +120,7 @@ def __next__(self) -> Tuple[Any, int, int]: raise StopIteration try: - out = next(self.iterators[self._iterator_idx]) + out = next(self.iterators[0]) index = self._idx self._idx += 1 # batch, batch_idx, dataloader_idx @@ -131,9 +131,9 @@ def __next__(self) -> Tuple[Any, int, int]: return self.__next__() def __iter__(self) -> Self: - super().__iter__() self._iterator_idx = 0 self._idx = 0 + self._load_current_iterator() return self def reset(self) -> None: @@ -141,9 +141,14 @@ def reset(self) -> None: self._iterator_idx = 0 self._idx = 0 + def _load_current_iterator(self) -> None: + # Load a single DataLoader, prevents multiple sets of workers from starting unnecessarily + self.iterators = [iter(iterable) for idx, iterable in enumerate(self.iterables) if idx == self._iterator_idx] + def _use_next_iterator(self) -> None: self._iterator_idx += 1 self._idx = 0 + self._load_current_iterator() class _MaxSize(_ModeIterator[List]): From 5f3f3c70cb4b7cbd08b0516dc99d34e52344f2b4 Mon Sep 17 00:00:00 2001 From: Ryan Mukherjee Date: Mon, 29 May 2023 13:53:12 -0400 Subject: [PATCH 2/8] simplify initializing self.iterators MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- src/lightning/pytorch/utilities/combined_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py index 7ff0c8b1192af..e358369adee00 100644 --- a/src/lightning/pytorch/utilities/combined_loader.py +++ b/src/lightning/pytorch/utilities/combined_loader.py @@ -143,7 +143,7 @@ def reset(self) -> None: def _load_current_iterator(self) -> None: # Load a single DataLoader, prevents multiple sets of workers from starting unnecessarily - self.iterators = [iter(iterable) for idx, iterable in enumerate(self.iterables) if idx == self._iterator_idx] + self.iterators = [iter(self.iterables[self._iterator_idx])] def _use_next_iterator(self) -> None: self._iterator_idx += 1 From f917652dc913a269ce92f720deb5004707c2d9e2 Mon Sep 17 00:00:00 2001 From: Ryan Mukherjee Date: Mon, 29 May 2023 14:15:10 -0400 Subject: [PATCH 3/8] handle idx > len(iterables) --- src/lightning/pytorch/utilities/combined_loader.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py index e358369adee00..0e012dbae145b 100644 --- a/src/lightning/pytorch/utilities/combined_loader.py +++ b/src/lightning/pytorch/utilities/combined_loader.py @@ -143,7 +143,11 @@ def reset(self) -> None: def _load_current_iterator(self) -> None: # Load a single DataLoader, prevents multiple sets of workers from starting unnecessarily - self.iterators = [iter(self.iterables[self._iterator_idx])] + if self._iterator_idx < len(self.iterables): + self.iterators = [iter(self.iterables[self._iterator_idx])] + else: + # No more iterables to step through, return an empty list + self.iterators = [] def _use_next_iterator(self) -> None: self._iterator_idx += 1 From c5ed06fabca814dc5ca23445a3cdd4c08b3e946b Mon Sep 17 00:00:00 2001 From: Ryan Mukherjee Date: Mon, 29 May 2023 14:44:23 -0400 Subject: [PATCH 4/8] shutdown occurs before teardown --- tests/tests_pytorch/loops/test_loops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 15efb392a2ba8..ef4a0f6d9c506 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -844,9 +844,9 @@ def _get_iterator(self): # iterable check 0, # epoch ends + 0, 1, # teardown - 1, ] else: expected = [ @@ -855,9 +855,9 @@ def _get_iterator(self): # iterable check 0, # epoch ends + 0, 1, 2, # teardown - 3, ] assert val_dataloader.shutdown_workers_epochs == expected From ce2cfe5681b08dad076773c9ef74632d96a181bf Mon Sep 17 00:00:00 2001 From: Ryan Mukherjee Date: Mon, 29 May 2023 15:29:36 -0400 Subject: [PATCH 5/8] add test for sequential workers --- .../utilities/test_combined_loader.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py index fa6d33120abfc..0452f7148d714 100644 --- a/tests/tests_pytorch/utilities/test_combined_loader.py +++ b/tests/tests_pytorch/utilities/test_combined_loader.py @@ -305,6 +305,45 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader assert idx == expected - 1 +@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle", "max_size", "sequential"]) +def test_combined_loader_simultaneous_workers(mode): + """Test `CombinedLoader` to check how it initializes dataloader workers.""" + + class TestDataLoader(DataLoader): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.workers_active = False + + def _get_iterator(self): + self.workers_active = True + return super()._get_iterator() + + def _shutdown_workers(self): + self.workers_active = False + super()._shutdown_workers() + + loaders = [ + TestDataLoader(range(10), batch_size=2, num_workers=0), + TestDataLoader(range(20), batch_size=2, num_workers=0), + ] + combined_loader = CombinedLoader(loaders, mode) + + for idx, item in enumerate(combined_loader): + break + + workers_active = [] + for loader in loaders: + workers_active.append(loader.workers_active) + + if mode == "sequential": + # Only starts the first dataloader + expected = [True, False] + else: + # Starts all dataloaders in order to iterate through one at a time + expected = [True, True] + assert workers_active == expected + + @pytest.mark.parametrize( ("limits", "expected"), [ From 90183c9971d253767ef8f6602a0a0d9a42f12d1a Mon Sep 17 00:00:00 2001 From: Ryan Mukherjee Date: Mon, 29 May 2023 15:34:58 -0400 Subject: [PATCH 6/8] add combinedloader fix --- src/lightning/pytorch/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 1e3c905b7e1e1..67badc0d4b75b 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -108,6 +108,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- `CombinedLoader` only starts DataLoader workers when necessary when operating in sequential mode ([#17639](https://github.com/Lightning-AI/lightning/pull/17639)) + + - Fixed issue where `Model.load_from_checkpoint("checkpoint.ckpt", map_location=map_location)` would always return model on CPU ([#17308](https://github.com/Lightning-AI/lightning/pull/17308)) From ef6d762b2d44347f9af3f02053a456ad26169547 Mon Sep 17 00:00:00 2001 From: Ryan Mukherjee Date: Mon, 29 May 2023 15:39:32 -0400 Subject: [PATCH 7/8] update for ruff --- tests/tests_pytorch/utilities/test_combined_loader.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py index 0452f7148d714..77391d374990f 100644 --- a/tests/tests_pytorch/utilities/test_combined_loader.py +++ b/tests/tests_pytorch/utilities/test_combined_loader.py @@ -335,12 +335,8 @@ def _shutdown_workers(self): for loader in loaders: workers_active.append(loader.workers_active) - if mode == "sequential": - # Only starts the first dataloader - expected = [True, False] - else: - # Starts all dataloaders in order to iterate through one at a time - expected = [True, True] + # Sequential only starts the first dataloader, other modes start both + expected = [True, False] if mode == "sequential" else [True, True] assert workers_active == expected From 01a52fda5d41bc685dd08ee73e601b5770e65d49 Mon Sep 17 00:00:00 2001 From: Ryan Date: Mon, 29 May 2023 21:38:24 -0400 Subject: [PATCH 8/8] cleanup --- tests/tests_pytorch/loops/test_loops.py | 2 -- tests/tests_pytorch/utilities/test_combined_loader.py | 5 ++--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index ef4a0f6d9c506..e38d4459e2ed3 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -846,7 +846,6 @@ def _get_iterator(self): # epoch ends 0, 1, - # teardown ] else: expected = [ @@ -858,6 +857,5 @@ def _get_iterator(self): 0, 1, 2, - # teardown ] assert val_dataloader.shutdown_workers_epochs == expected diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py index 77391d374990f..7109523b378a9 100644 --- a/tests/tests_pytorch/utilities/test_combined_loader.py +++ b/tests/tests_pytorch/utilities/test_combined_loader.py @@ -327,9 +327,8 @@ def _shutdown_workers(self): TestDataLoader(range(20), batch_size=2, num_workers=0), ] combined_loader = CombinedLoader(loaders, mode) - - for idx, item in enumerate(combined_loader): - break + # Start the dataloader + _ = iter(combined_loader) workers_active = [] for loader in loaders: