Disable strict loading in multiprocessing launcher (#16365)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
3 people committed Feb 11, 2023
1 parent e00b23d commit 73cd956
Showing 4 changed files with 44 additions and 30 deletions.
3 changes: 1 addition & 2 deletions src/pytorch_lightning/CHANGELOG.md
@@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

-
- Disabled strict loading in multiprocessing launcher ("ddp_spawn", etc.) when loading weights back into the main process ([#16365](https://github.com/Lightning-AI/lightning/pull/16365))


### Deprecated
@@ -67,7 +67,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Overriding the `on_train_batch_{start,end}` hooks in conjunction with taking a `dataloader_iter` in the `training_step` no longer errors out and instead shows a warning ([#16062](https://github.com/Lightning-AI/lightning/pull/16062))
- Move `tensorboardX` to extra dependencies. Use the `CSVLogger` by default ([#16349](https://github.com/Lightning-AI/lightning/pull/16349))


### Deprecated

- Deprecated `description`, `env_prefix` and `env_parse` parameters in `LightningCLI.__init__` in favour of giving them through `parser_kwargs` ([#15651](https://github.com/Lightning-AI/lightning/pull/15651))
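To illustrate the new changelog entry above: strict loading raises as soon as the checkpoint's keys do not match the model's, which is exactly what can happen when the model's composition changes in the spawned worker (layers added or removed). A minimal sketch with plain PyTorch modules, independent of this commit:

import torch.nn as nn

main_model = nn.Sequential(nn.Linear(4, 4))                      # model as it exists in the main process
worker_model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))   # worker added a layer before saving

ckpt = worker_model.state_dict()
# main_model.load_state_dict(ckpt)  # the default strict=True would raise on the unexpected keys "1.weight"/"1.bias"
result = main_model.load_state_dict(ckpt, strict=False)  # mismatches are reported instead of raised
print(result.unexpected_keys)  # ['1.weight', '1.bias']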
@@ -153,7 +153,9 @@ def _recover_results_in_main_process(self, worker_output: "_WorkerOutput", train
        # load last weights
        if worker_output.weights_path is not None:
            ckpt = self._strategy.checkpoint_io.load_checkpoint(worker_output.weights_path)
            trainer.lightning_module.load_state_dict(ckpt)
            # choose non-strict loading of parameters on the main process, because the model's composition
            # could have changed in the worker process (layers added or removed)
            trainer.lightning_module.load_state_dict(ckpt, strict=False)
            self._strategy.checkpoint_io.remove_checkpoint(worker_output.weights_path)

        trainer.state = worker_output.trainer_state
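The hunk above is the core of the change: the worker saves its weights to a temporary checkpoint, and the main process now loads them back with strict=False. A rough sketch of that round trip using plain torch APIs instead of Lightning's checkpoint_io (illustrative only; the ".temp.ckpt" suffix mirrors the test assertion further down):

import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 4)

# worker side: rank 0 dumps its weights to a temporary checkpoint before the process exits
weights_path = os.path.join(tempfile.mkdtemp(), ".temp.ckpt")
torch.save(model.state_dict(), weights_path)

# main-process side: load the weights back non-strictly, then remove the temporary file
model.load_state_dict(torch.load(weights_path), strict=False)
os.remove(weights_path)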
41 changes: 40 additions & 1 deletion tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
@@ -103,7 +103,7 @@ def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank,
    strategy._local_rank = fake_local_rank

    launcher = _MultiProcessingLauncher(strategy=strategy)
    trainer = Trainer(default_root_dir=tmpdir, strategy=strategy)
    trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)

    assert strategy.node_rank == fake_node_rank
    assert strategy.local_rank == fake_local_rank
@@ -124,3 +124,42 @@ def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank,
    else:
        # all other ranks don't have outputs (rank 0 needs to handle the output)
        assert spawn_output is None


@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"])
def test_transfer_weights(tmpdir, trainer_fn):
    """Tests that the multiprocessing launcher transfers the new weights to the main process and deletes the
    temporary file."""
    model = Mock(wraps=BoringModel(), spec=BoringModel)
    strategy = DDPSpawnStrategy()
    trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)
    trainer.strategy.connect(model)
    trainer.state.fn = trainer_fn  # pretend we are in a particular trainer state

    spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})

    model.state_dict.assert_called_once()
    if trainer_fn == TrainerFn.FITTING:
        assert spawn_output.weights_path.endswith(".temp.ckpt")
        assert os.path.isfile(spawn_output.weights_path)
    else:
        assert spawn_output.weights_path is None

    # <-- here would normally be the multiprocessing boundary
    strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
    assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None)


def test_non_strict_loading(tmpdir):
    """Tests that the multiprocessing launcher loads the weights back into the main process but with strict loading
    disabled, not erroring for missing keys."""
    model = Mock(wraps=BoringModel(), spec=BoringModel)
    strategy = DDPSpawnStrategy()
    trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)
    trainer.strategy.connect(model)
    trainer.state.fn = TrainerFn.FITTING  # state dict loading only relevant for the FITTING case

    spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})
    # <-- here would normally be the multiprocessing boundary
    strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
    model.load_state_dict.assert_called_once_with(ANY, strict=False)
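A side note on the testing technique used above (a standalone sketch, not taken from the test suite): Mock(wraps=..., spec=...) builds a spy that records calls while still forwarding them to the real object, which is what lets assert_called_once_with(ANY, strict=False) check the keyword argument without stubbing out the real loading behavior.

from unittest.mock import ANY, Mock

import torch.nn as nn

real = nn.Linear(2, 2)
spy = Mock(wraps=real, spec=nn.Linear)                  # records calls and forwards them to `real`
spy.load_state_dict(real.state_dict(), strict=False)    # the real load_state_dict still runs
spy.load_state_dict.assert_called_once_with(ANY, strict=False)  # ANY matches the positional state dict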
26 changes: 0 additions & 26 deletions tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
@@ -11,10 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from datetime import timedelta
from unittest import mock
from unittest.mock import Mock

import pytest
import torch
@@ -126,30 +124,6 @@ def test_ddp_spawn_configure_ddp(tmpdir):
    trainer.predict(model, dataloaders=model.predict_dataloader())


@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"])
def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn):
    """Tests that the spawn strategy transfers the new weights to the main process and deletes the temporary
    file."""
    model = Mock(wraps=BoringModel(), spec=BoringModel)
    strategy = DDPSpawnStrategy()
    trainer = Trainer(default_root_dir=tmpdir, strategy=strategy)
    trainer.strategy.connect(model)
    trainer.state.fn = trainer_fn  # pretend we are in a particular trainer state

    spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})

    model.state_dict.assert_called_once()
    if trainer_fn == TrainerFn.FITTING:
        assert spawn_output.weights_path.endswith(".temp.ckpt")
        assert os.path.isfile(spawn_output.weights_path)
    else:
        assert spawn_output.weights_path is None

    # <-- here would normally be the multiprocessing boundary
    strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
    assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None)


@mock.patch("torch.distributed.init_process_group")
def test_ddp_spawn_strategy_set_timeout(mock_init_process_group):
"""Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
