Disable strict loading in multiprocessing launcher #16365

Merged: 9 commits, Jan 18, 2023
2 changes: 1 addition & 1 deletion src/pytorch_lightning/CHANGELOG.md
@@ -63,7 +63,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - `MLFlowLogger` now logs hyperparameters and metrics in batched API calls ([#15915](https://github.com/Lightning-AI/lightning/pull/15915))
 - Overriding the `on_train_batch_{start,end}` hooks in conjunction with taking a `dataloader_iter` in the `training_step` no longer errors out and instead shows a warning ([#16062](https://github.com/Lightning-AI/lightning/pull/16062))
 - Move `tensorboardX` to extra dependencies. Use the `CSVLogger` by default ([#16349](https://github.com/Lightning-AI/lightning/pull/16349))
-
+- Disabled strict loading in multiprocessing launcher ("ddp_spawn", etc.) when loading weights back into the main process ([#16365](https://github.com/Lightning-AI/lightning/pull/16365))
 
 ### Deprecated
 
Expand Down
4 changes: 3 additions & 1 deletion src/pytorch_lightning/strategies/launchers/multiprocessing.py
@@ -153,7 +153,9 @@ def _recover_results_in_main_process(self, worker_output: "_WorkerOutput", trainer
         # load last weights
         if worker_output.weights_path is not None:
             ckpt = self._strategy.checkpoint_io.load_checkpoint(worker_output.weights_path)
-            trainer.lightning_module.load_state_dict(ckpt)
+            # choose non-strict loading of parameters on the main process, because the model's composition
+            # could have changed in the worker process (layers added or removed)
+            trainer.lightning_module.load_state_dict(ckpt, strict=False)
             self._strategy.checkpoint_io.remove_checkpoint(worker_output.weights_path)
 
         trainer.state = worker_output.trainer_state
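The substance of the fix is the `strict=False` argument above. For illustration, here is a minimal sketch (plain PyTorch, not part of this PR) of how non-strict loading behaves when the checkpoint and the module disagree on keys:

```python
import torch
from torch import nn

# Hypothetical mismatch: the checkpoint was produced by a model that grew an
# extra "head" layer in the worker process, which the main-process model lacks.
model = nn.Sequential(nn.Linear(4, 8))
ckpt = {
    "0.weight": torch.zeros(8, 4),
    "0.bias": torch.zeros(8),
    "head.weight": torch.zeros(2, 8),  # key the main-process model does not have
}

# strict=True (the default) would raise a RuntimeError for "head.weight";
# strict=False loads the overlapping keys and reports the mismatches instead.
result = model.load_state_dict(ckpt, strict=False)
print(result.missing_keys)     # []
print(result.unexpected_keys)  # ['head.weight']
```

This is exactly the failure mode the new comment in the diff describes: the worker may add or remove layers, so the main process must tolerate key mismatches.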
41 changes: 40 additions & 1 deletion tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
@@ -103,7 +103,7 @@ def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank,
     strategy._local_rank = fake_local_rank
 
     launcher = _MultiProcessingLauncher(strategy=strategy)
-    trainer = Trainer(default_root_dir=tmpdir, strategy=strategy)
+    trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)
 
     assert strategy.node_rank == fake_node_rank
     assert strategy.local_rank == fake_local_rank
@@ -124,3 +124,42 @@ def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank,
     else:
         # all other ranks don't have outputs (rank 0 needs to handle the output)
         assert spawn_output is None
+
+
+@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"])
+def test_transfer_weights(tmpdir, trainer_fn):
+    """Tests that the multiprocessing launcher transfers the new weights to the main process and deletes the
+    temporary file."""
+    model = Mock(wraps=BoringModel(), spec=BoringModel)
+    strategy = DDPSpawnStrategy()
+    trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)
+    trainer.strategy.connect(model)
+    trainer.state.fn = trainer_fn  # pretend we are in a particular trainer state
+
+    spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})
+
+    model.state_dict.assert_called_once()
+    if trainer_fn == TrainerFn.FITTING:
+        assert spawn_output.weights_path.endswith(".temp.ckpt")
+        assert os.path.isfile(spawn_output.weights_path)
+    else:
+        assert spawn_output.weights_path is None
+
+    # <-- here would normally be the multiprocessing boundary
+    strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
+    assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None)
+
+
+def test_non_strict_loading(tmpdir):
+    """Tests that the multiprocessing launcher loads the weights back into the main process but with strict loading
+    disabled, not erroring for missing keys."""
+    model = Mock(wraps=BoringModel(), spec=BoringModel)
+    strategy = DDPSpawnStrategy()
+    trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)
+    trainer.strategy.connect(model)
+    trainer.state.fn = TrainerFn.FITTING  # state dict loading only relevant for the FITTING case
+
+    spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})
+    # <-- here would normally be the multiprocessing boundary
+    strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
+    model.load_state_dict.assert_called_once_with(ANY, strict=False)
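A side note on the test pattern used above: `Mock(wraps=..., spec=...)` builds a proxy that still executes the real object's methods while recording every call, which is what lets the tests assert on `load_state_dict` arguments after a real round trip. A self-contained sketch of the idiom, with `nn.Linear` standing in for `BoringModel` (a hypothetical stand-in, not Lightning code):

```python
from unittest.mock import ANY, Mock

from torch import nn

# wraps= delegates calls to the real instance; spec= restricts the mock's
# surface to nn.Linear's attributes, so attribute typos fail loudly.
real = nn.Linear(2, 2)
proxy = Mock(wraps=real, spec=nn.Linear)

state = proxy.state_dict()                  # actually runs real.state_dict()
proxy.load_state_dict(state, strict=False)  # actually loads into `real`
proxy.load_state_dict.assert_called_once_with(ANY, strict=False)
```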
26 changes: 0 additions & 26 deletions tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
@@ -11,10 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from datetime import timedelta
 from unittest import mock
-from unittest.mock import Mock
 
 import pytest
 import torch
@@ -126,30 +124,6 @@ def test_ddp_spawn_configure_ddp(tmpdir):
     trainer.predict(model, dataloaders=model.predict_dataloader())
 
 
-@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"])
-def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn):
-    """Tests that the spawn strategy transfers the new weights to the main process and deletes the temporary
-    file."""
-    model = Mock(wraps=BoringModel(), spec=BoringModel)
-    strategy = DDPSpawnStrategy()
-    trainer = Trainer(default_root_dir=tmpdir, strategy=strategy)
-    trainer.strategy.connect(model)
-    trainer.state.fn = trainer_fn  # pretend we are in a particular trainer state
-
-    spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})
-
-    model.state_dict.assert_called_once()
-    if trainer_fn == TrainerFn.FITTING:
-        assert spawn_output.weights_path.endswith(".temp.ckpt")
-        assert os.path.isfile(spawn_output.weights_path)
-    else:
-        assert spawn_output.weights_path is None
-
-    # <-- here would normally be the multiprocessing boundary
-    strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
-    assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None)
-
-
 @mock.patch("torch.distributed.init_process_group")
 def test_ddp_spawn_strategy_set_timeout(mock_init_process_group):
     """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
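Taken together, the mechanism under test is a simple round trip: the rank-0 worker saves its weights to a temporary checkpoint before the process exits, and the main process reloads them non-strictly and removes the file. A simplified sketch under those assumptions (function names here are illustrative, not the launcher's actual API):

```python
import os
import tempfile

import torch
from torch import nn


def save_worker_weights(model: nn.Module) -> str:
    # worker side (rank 0): persist weights to a temp file before the process exits
    path = os.path.join(tempfile.mkdtemp(), ".temp.ckpt")
    torch.save(model.state_dict(), path)
    return path


def recover_in_main_process(model: nn.Module, weights_path: str) -> None:
    # main-process side: load non-strictly (tolerating composition drift),
    # then clean up the temporary checkpoint
    ckpt = torch.load(weights_path)
    model.load_state_dict(ckpt, strict=False)
    os.remove(weights_path)


model = nn.Linear(3, 3)
recover_in_main_process(model, save_worker_weights(model))
```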