Disable strict loading in multiprocessing launcher #16365

Merged · 9 commits · Jan 18, 2023
2 changes: 1 addition & 1 deletion src/pytorch_lightning/CHANGELOG.md
@@ -63,7 +63,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `MLFlowLogger` now logs hyperparameters and metrics in batched API calls ([#15915](https://github.com/Lightning-AI/lightning/pull/15915))
- Overriding the `on_train_batch_{start,end}` hooks in conjunction with taking a `dataloader_iter` in the `training_step` no longer errors out and instead shows a warning ([#16062](https://github.com/Lightning-AI/lightning/pull/16062))
- Move `tensorboardX` to extra dependencies. Use the `CSVLogger` by default ([#16349](https://github.com/Lightning-AI/lightning/pull/16349))

- Disabled strict loading in multiprocessing launcher ("ddp_spawn", etc.) when loading weights back into the main process ([#16365](https://github.com/Lightning-AI/lightning/pull/16365))

### Deprecated

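The changelog entry above concerns the hand-off of trained weights from the spawned worker back to the main process. As a rough orientation, here is a self-contained sketch of that flow, with plain `torch.save`/`torch.load` standing in for the strategy's `checkpoint_io` plugin and the helper names invented for illustration: the rank-0 worker writes its weights to a temporary `.temp.ckpt` file, and after the workers exit, the main process loads them back (non-strictly, after this PR) and deletes the file.

```python
import os
import tempfile

import torch
from torch import nn


def collect_rank_zero_results(model: nn.Module) -> str:
    # worker side (rank 0): dump the trained weights to a temporary checkpoint file
    weights_path = os.path.join(tempfile.mkdtemp(), ".temp.ckpt")
    torch.save(model.state_dict(), weights_path)
    return weights_path


def recover_results_in_main_process(model: nn.Module, weights_path: str) -> None:
    # main-process side: load the worker's weights without strict key checking, then clean up
    ckpt = torch.load(weights_path)
    model.load_state_dict(ckpt, strict=False)
    os.remove(weights_path)


trained_in_worker = nn.Linear(4, 2)                   # stands in for the model after fit()
path = collect_rank_zero_results(trained_in_worker)   # <-- multiprocessing boundary would be here
main_copy = nn.Linear(4, 2)
recover_results_in_main_process(main_copy, path)
```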
@@ -153,7 +153,9 @@ def _recover_results_in_main_process(self, worker_output: "_WorkerOutput", train
         # load last weights
         if worker_output.weights_path is not None:
             ckpt = self._strategy.checkpoint_io.load_checkpoint(worker_output.weights_path)
-            trainer.lightning_module.load_state_dict(ckpt)
+            # choose non-strict loading of parameters on the main process, because the model's composition
+            # could have changed in the worker process (layers added or removed)
+            trainer.lightning_module.load_state_dict(ckpt, strict=False)
             self._strategy.checkpoint_io.remove_checkpoint(worker_output.weights_path)

         trainer.state = worker_output.trainer_state
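The comment added in the hunk above states the reason for non-strict loading: the worker process may have changed the model's composition (layers added or removed), so its state dict and the main-process module no longer share exactly the same keys. A minimal illustration in plain PyTorch, with hypothetical module names, of what `strict=False` tolerates:

```python
from torch import nn


class MainProcessModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)


class WorkerModel(MainProcessModel):
    def __init__(self):
        super().__init__()
        # extra layer created only inside the worker process, e.g. during `setup()`
        self.head = nn.Linear(4, 2)


worker_state = WorkerModel().state_dict()  # has "backbone.*" and "head.*" keys
main_model = MainProcessModel()            # only has "backbone.*"

# strict=True would raise a RuntimeError about the unexpected "head.*" keys;
# strict=False loads the matching keys and reports the mismatch instead of failing
result = main_model.load_state_dict(worker_state, strict=False)
print(result.unexpected_keys)  # ['head.weight', 'head.bias']
```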
@@ -103,7 +103,7 @@ def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank,
     strategy._local_rank = fake_local_rank

     launcher = _MultiProcessingLauncher(strategy=strategy)
-    trainer = Trainer(default_root_dir=tmpdir, strategy=strategy)
+    trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)

     assert strategy.node_rank == fake_node_rank
     assert strategy.local_rank == fake_local_rank
@@ -124,3 +124,42 @@ def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank,
     else:
         # all other ranks don't have outputs (rank 0 needs to handle the output)
         assert spawn_output is None
+
+
+@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"])
+def test_transfer_weights(tmpdir, trainer_fn):
+    """Tests that the multiprocessing launcher transfers the new weights to the main process and deletes the
+    temporary file."""
+    model = Mock(wraps=BoringModel(), spec=BoringModel)
+    strategy = DDPSpawnStrategy()
+    trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)
+    trainer.strategy.connect(model)
+    trainer.state.fn = trainer_fn  # pretend we are in a particular trainer state
+
+    spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})
+
+    model.state_dict.assert_called_once()
+    if trainer_fn == TrainerFn.FITTING:
+        assert spawn_output.weights_path.endswith(".temp.ckpt")
+        assert os.path.isfile(spawn_output.weights_path)
+    else:
+        assert spawn_output.weights_path is None
+
+    # <-- here would normally be the multiprocessing boundary
+    strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
+    assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None)
+
+
+def test_non_strict_loading(tmpdir):
+    """Tests that the multiprocessing launcher loads the weights back into the main process but with strict loading
+    disabled, not erroring for missing keys."""
+    model = Mock(wraps=BoringModel(), spec=BoringModel)
+    strategy = DDPSpawnStrategy()
+    trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)
+    trainer.strategy.connect(model)
+    trainer.state.fn = TrainerFn.FITTING  # state dict loading only relevant for the FITTING case
+
+    spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})
+    # <-- here would normally be the multiprocessing boundary
+    strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
+    model.load_state_dict.assert_called_once_with(ANY, strict=False)
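The assertions in the new tests lean on the `Mock(wraps=..., spec=...)` pattern: calls are forwarded to a real module but still recorded by the mock, so the test can verify that `load_state_dict` received `strict=False`. A standalone sketch of that pattern, with a hypothetical `TinyModel` standing in for `BoringModel`:

```python
from unittest.mock import ANY, Mock

from torch import nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 2)


model = Mock(wraps=TinyModel(), spec=TinyModel)

# the call passes through to the real module, but the mock records it
model.load_state_dict(TinyModel().state_dict(), strict=False)

model.load_state_dict.assert_called_once_with(ANY, strict=False)  # passes
```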
26 changes: 0 additions & 26 deletions tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
@@ -11,10 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from datetime import timedelta
 from unittest import mock
-from unittest.mock import Mock

 import pytest
 import torch
@@ -126,30 +124,6 @@ def test_ddp_spawn_configure_ddp(tmpdir):
     trainer.predict(model, dataloaders=model.predict_dataloader())


-@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"])
-def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn):
-    """Tests that the spawn strategy transfers the new weights to the main process and deletes the temporary
-    file."""
-    model = Mock(wraps=BoringModel(), spec=BoringModel)
-    strategy = DDPSpawnStrategy()
-    trainer = Trainer(default_root_dir=tmpdir, strategy=strategy)
-    trainer.strategy.connect(model)
-    trainer.state.fn = trainer_fn  # pretend we are in a particular trainer state
-
-    spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})
-
-    model.state_dict.assert_called_once()
-    if trainer_fn == TrainerFn.FITTING:
-        assert spawn_output.weights_path.endswith(".temp.ckpt")
-        assert os.path.isfile(spawn_output.weights_path)
-    else:
-        assert spawn_output.weights_path is None
-
-    # <-- here would normally be the multiprocessing boundary
-    strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
-    assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None)
-
-
 @mock.patch("torch.distributed.init_process_group")
 def test_ddp_spawn_strategy_set_timeout(mock_init_process_group):
     """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
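For context, a minimal, hypothetical end-to-end script for the user-facing scenario this PR touches: with `strategy="ddp_spawn"` the fit loop runs in spawned worker processes, and the rank-0 weights are afterwards loaded back into the main process (non-strictly, after this change). The module and data below are invented for illustration, and the `Trainer` arguments follow the 1.9-era API.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    # hypothetical module; any LightningModule would do
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


if __name__ == "__main__":  # guard required for spawn-based multiprocessing
    data = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16)
    trainer = pl.Trainer(accelerator="cpu", devices=2, strategy="ddp_spawn", max_epochs=1)
    # fit() spawns the workers; once they finish, the trained rank-0 weights are
    # loaded back into this process's copy of the model
    trainer.fit(TinyModule(), data)
```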