
Disable strict loading in multiprocessing launcher #16365

Merged
awaelchli merged 9 commits into master from feature/fit-spawn-missing-keys on Jan 18, 2023

Conversation

awaelchli
Contributor

@awaelchli awaelchli commented Jan 15, 2023

What does this PR do?

Fixes #14534

Sets strict=False when loading the state dict of the model back into the main process. The model in the main process may have a different architecture than the one trained in the worker processes:

import torch
from pytorch_lightning import LightningModule, Trainer


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(32, 2)

    def setup(self, stage=None):
        self.layer2 = torch.nn.Linear(32, 2)  # this layer does not exist in the main process


model = BoringModel()  # layer2 does not exist yet
trainer = Trainer(strategy="ddp_spawn", ...)
trainer.fit(model)
# Here, at the end of fit, the trainer loads the model weights back into the main process,
# but it can only do so for layer1, because layer2 does not exist there.

This is an inherent limitation of training with the "spawn" launch method. Since we don't know what the user will do with the model after fit(), the best we can do is load the weights that match.
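For illustration, here is a minimal plain-PyTorch sketch of what non-strict loading does. The two models below are hypothetical stand-ins for the worker-process and main-process copies, not the launcher's actual internals:

from torch import nn

# Stand-in for the model in the worker process, where setup() has run.
trained = nn.ModuleDict({"layer1": nn.Linear(32, 2), "layer2": nn.Linear(32, 2)})
# Stand-in for the main-process model, where setup() never ran and layer2 is missing.
main = nn.ModuleDict({"layer1": nn.Linear(32, 2)})

# strict=False loads the keys that match (layer1) and reports the rest,
# instead of raising a RuntimeError over the extra layer2 weights.
result = main.load_state_dict(trained.state_dict(), strict=False)
print(result.unexpected_keys)  # ['layer2.weight', 'layer2.bias']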

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

I made sure I had fun coding 🙃

cc @Borda @justusschock @awaelchli

@awaelchli awaelchli added the feature Is an improvement or enhancement label Jan 15, 2023
@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Jan 15, 2023
@awaelchli awaelchli added strategy: ddp spawn fun Staff contributions outside working hours - to differentiate from the "community" label and removed pl Generic label for PyTorch Lightning package labels Jan 15, 2023
@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Jan 15, 2023
@awaelchli awaelchli marked this pull request as ready for review January 15, 2023 13:04
@github-actions
Contributor

github-actions bot commented Jan 15, 2023

⚡ Required checks status: All passing 🟢

Groups summary

🟢 pytorch_lightning: Tests workflow
Check ID Status
pl-cpu (macOS-11, pytorch, 3.8, 1.11) success
pl-cpu (macOS-11, pytorch, 3.9, 1.12) success
pl-cpu (macOS-11, pytorch, 3.10, 1.13) success
pl-cpu (macOS-11, pytorch, 3.8, 1.10, oldest) success
pl-cpu (ubuntu-20.04, pytorch, 3.8, 1.10) success
pl-cpu (ubuntu-20.04, pytorch, 3.9, 1.11) success
pl-cpu (ubuntu-20.04, pytorch, 3.10, 1.12) success
pl-cpu (ubuntu-20.04, pytorch, 3.10, 1.13) success
pl-cpu (ubuntu-20.04, pytorch, 3.7, 1.10, oldest) success
pl-cpu (windows-2022, pytorch, 3.9, 1.11) success
pl-cpu (windows-2022, pytorch, 3.10, 1.12) success
pl-cpu (windows-2022, pytorch, 3.10, 1.13) success
pl-cpu (windows-2022, pytorch, 3.7, 1.10, oldest) success
pl-cpu (slow, macOS-11, pytorch, 3.7, 1.11) success
pl-cpu (slow, ubuntu-20.04, pytorch, 3.7, 1.11) success
pl-cpu (slow, windows-2022, pytorch, 3.7, 1.11) success
pl-cpu (macOS-11, lightning, 3.8, 1.13) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 1.13) success
pl-cpu (windows-2022, lightning, 3.8, 1.13) success

These checks are required after the changes to src/pytorch_lightning/strategies/launchers/multiprocessing.py, tests/tests_pytorch/strategies/launchers/test_multiprocessing.py, tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py.

🟢 pytorch_lightning: Azure GPU
Check ID Status
pytorch-lightning (GPUs) success

These checks are required after the changes to src/pytorch_lightning/strategies/launchers/multiprocessing.py, tests/tests_pytorch/strategies/launchers/test_multiprocessing.py, tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py.

🟢 pytorch_lightning: Azure HPU
Check ID Status
pytorch-lightning (HPUs) success

These checks are required after the changes to src/pytorch_lightning/strategies/launchers/multiprocessing.py, tests/tests_pytorch/strategies/launchers/test_multiprocessing.py, tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py.

🟢 pytorch_lightning: Azure IPU
Check ID Status
pytorch-lightning (IPUs) success

These checks are required after the changes to src/pytorch_lightning/strategies/launchers/multiprocessing.py, tests/tests_pytorch/strategies/launchers/test_multiprocessing.py, tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py.

🟢 pytorch_lightning: Docs
Check ID Status
make-doctest (pytorch) success
make-html (pytorch) success

These checks are required after the changes to src/pytorch_lightning/strategies/launchers/multiprocessing.py.

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to src/pytorch_lightning/strategies/launchers/multiprocessing.py.

🟢 install
Check ID Status
install-pkg (ubuntu-22.04, app, 3.7) success
install-pkg (ubuntu-22.04, app, 3.10) success
install-pkg (ubuntu-22.04, fabric, 3.7) success
install-pkg (ubuntu-22.04, fabric, 3.10) success
install-pkg (ubuntu-22.04, pytorch, 3.7) success
install-pkg (ubuntu-22.04, pytorch, 3.10) success
install-pkg (ubuntu-22.04, lightning, 3.7) success
install-pkg (ubuntu-22.04, lightning, 3.10) success
install-pkg (ubuntu-22.04, notset, 3.7) success
install-pkg (ubuntu-22.04, notset, 3.10) success
install-pkg (macOS-12, app, 3.7) success
install-pkg (macOS-12, app, 3.10) success
install-pkg (macOS-12, fabric, 3.7) success
install-pkg (macOS-12, fabric, 3.10) success
install-pkg (macOS-12, pytorch, 3.7) success
install-pkg (macOS-12, pytorch, 3.10) success
install-pkg (macOS-12, lightning, 3.7) success
install-pkg (macOS-12, lightning, 3.10) success
install-pkg (macOS-12, notset, 3.7) success
install-pkg (macOS-12, notset, 3.10) success
install-pkg (windows-2022, app, 3.7) success
install-pkg (windows-2022, app, 3.10) success
install-pkg (windows-2022, fabric, 3.7) success
install-pkg (windows-2022, fabric, 3.10) success
install-pkg (windows-2022, pytorch, 3.7) success
install-pkg (windows-2022, pytorch, 3.10) success
install-pkg (windows-2022, lightning, 3.7) success
install-pkg (windows-2022, lightning, 3.10) success
install-pkg (windows-2022, notset, 3.7) success
install-pkg (windows-2022, notset, 3.10) success

These checks are required after the changes to src/pytorch_lightning/strategies/launchers/multiprocessing.py.


Thank you for your contribution! 💜

Note
This comment is automatically generated and updates every 180 seconds for 60 minutes. If you have any other questions, contact carmocca for help.

@mergify mergify bot removed the has conflicts label Jan 17, 2023
@awaelchli awaelchli added bug Something isn't working and removed feature Is an improvement or enhancement labels Jan 18, 2023
@awaelchli awaelchli added this to the v1.9.x milestone Jan 18, 2023
@mergify mergify bot added the ready PRs ready to be merged label Jan 18, 2023
@awaelchli awaelchli enabled auto-merge (squash) January 18, 2023 22:23
@awaelchli awaelchli merged commit 7d36db8 into master Jan 18, 2023
@awaelchli awaelchli deleted the feature/fit-spawn-missing-keys branch January 18, 2023 22:53
@vitkl

vitkl commented Jan 23, 2023

@awaelchli Thank you for quickly implementing this!

Is it possible to install pytorch-lightning with this fix (not lightning)? I don't see any way to do this listed here: https://pytorch-lightning.readthedocs.io/en/stable/starter/installation.html


@carmocca
Contributor

Try this:

PACKAGE_NAME=pytorch pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U

@vitkl

vitkl commented Jan 24, 2023

This works, thank you!

@vitkl

vitkl commented Jan 24, 2023

Although the fix itself doesn't work for my Pyro model: the module ends up with no parameters.

module  # nn.Module that contains the pyro model and guide as attributes module.model and module.guide
training_plan = TrainingPlan(pyro_module=module, **plan_kwargs)  # a pl.LightningModule
trainer = Trainer(
    max_epochs=max_epochs,
    accelerator=accelerator,
    devices=devices,
    strategy=strategy,
    **trainer_kwargs,
)
trainer.fit(training_plan, data_splitter)
module.state_dict().keys()
# no parameters listed

Is there any point downstream of this change where the contents of state_dict can be ignored or overwritten?

@awaelchli
Contributor Author

This is probably because you don't create your layers at instantiation time, only later. You will always run into this limitation with the "ddp_spawn" strategy; it is a consequence of its design. In this case, you should choose strategy="ddp" and you will never have this issue.
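For instance, a hedged sketch of that fix, following the BoringModel example from the PR description: create all layers in __init__ so that every process, including the main one, has the same state_dict keys.

import torch
from pytorch_lightning import LightningModule

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        # Creating every layer here (instead of in setup()) means the
        # main-process model exposes the same state_dict keys as the
        # trained worker copies, so no weights are skipped when loading.
        self.layer1 = torch.nn.Linear(32, 2)
        self.layer2 = torch.nn.Linear(32, 2)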

@vitkl

vitkl commented Jan 25, 2023

I see, thanks for explaining!

And is there no way to modify the strategy, e.g. to run the callback's setup before the parameters are loaded back in the main process?

@awaelchli
Contributor Author

If you want, you can always call the setup() methods yourself in the main process so that the layers exist before the weights get loaded back:

my_callback.setup(trainer, my_model, stage="fit")  # create the layers in the main process
my_model.setup(stage="fit")  # so the weights loaded back after fit() find matching keys
trainer.fit(my_model, ...)

But before falling back to this workaround, I suggest just using the regular ddp strategy.

@vitkl

vitkl commented Jan 29, 2023

Thanks @awaelchli! Both solutions, ddp and ddp_notebook + running the callback in the main process, seem to work (the code runs). However, I see that both lead to identical batches being loaded on both devices (the data-loaded tensors, such as the Pyro plate indices, are identical on both devices). I read #7186 and other related issues, but I don't understand why this should be happening. Is it possible to check whether the DistributedSampler (replace_sampler_ddp=True) was created successfully?

It appears that worker_init_fn=pl_worker_init_function leads to all workers being initialised in all processes using the same seed. Is this expected?


@awaelchli
Contributor Author

awaelchli commented Jan 29, 2023

Is it possible to check if DistributedSampler (replace_sampler_ddp=True) was created successfully?

You can check isinstance(trainer.train_dataloaders[0].sampler, DistributedSampler) for example.

It appears that worker_init_fn=pl_worker_init_function leads to all workers being initialised in all processes using the same seed. Is this expected?

Call seed_everything(1, workers=True) to seed the dataloader workers based on the global rank. This doesn't affect the training processes themselves; for those, you can set a per-rank seed, e.g. seed_everything(seed + trainer.global_rank).
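Putting both suggestions together, a sketch (the Trainer configuration here is hypothetical, and the fit() call is elided):

from pytorch_lightning import Trainer, seed_everything
from torch.utils.data.distributed import DistributedSampler

# workers=True derives each dataloader worker's seed from the process's
# global rank, so workers in different processes don't share a seed.
seed_everything(1, workers=True)

trainer = Trainer(strategy="ddp", devices=2)  # hypothetical configuration
# ... trainer.fit(model, datamodule) ...

# After fit(), verify that Lightning replaced the sampler:
assert isinstance(trainer.train_dataloaders[0].sampler, DistributedSampler)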

For further questions, please consider posting in the forum; if you find a bug, a new issue would be appreciated (since this thread is about strict loading of weights).

awaelchli added a commit that referenced this pull request Feb 11, 2023
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
lexierule pushed a commit that referenced this pull request Feb 15, 2023
* Add .git-blame-ignore-revs (#16709)

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>

* Fix strategy type validation in connectors (#16693)

* Disable strict loading in multiprocessing launcher (#16365)


Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>

* Fix min-epochs and early-stopping triggering too many validation runs (#16719)

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>

* Update hydra-core requirement from <1.3.0,>=1.0.5 to >=1.0.5,<1.4.0 in /requirements (#16736)

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* [App] Add support for private data (#16738)

Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>

* [App] Add rm one level below project level (#16740)

Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>

* ci: cleaning caches (#16752)

* CI: Update colossalai version (#16747)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update version and changelog for 1.9.2

---------

Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
@awaelchli awaelchli added strategy: ddp DistributedDataParallel and removed strategy: ddp spawn labels Nov 4, 2023
Labels

  • bug (Something isn't working)
  • fun (Staff contributions outside working hours - to differentiate from the "community" label)
  • pl (Generic label for PyTorch Lightning package)
  • ready (PRs ready to be merged)
  • strategy: ddp (DistributedDataParallel)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Setting up submodules in setup doesn't work correctly with DDP
5 participants