
Feat/pytorch lightning #702

Merged: 59 commits merged into master on Feb 15, 2022
Conversation

@dennisbader (Collaborator) commented on Dec 21, 2021:

Addresses #577.

Update of #697.

  • I think we could let users pass all PyTorch Lightning trainer parameters as a dedicated kwarg, e.g. pl_trainer_params. This would set up a trainer with some Darts specifics plus the user's customization (see the sketch after this list).
  • As discussed with @hrzn, it would be nice if users could pass their own PL trainer to fit().
  • One open issue is how we resume training (calling fit() more than once); the best solution I've found so far requires mandatory checkpoint saving.
  • The warnings require further investigation.
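
A minimal sketch of the first idea, assuming the hypothetical pl_trainer_params kwarg name from the list above (neither the name nor the exact RNNModel signature here is final):

import numpy as np
from darts import TimeSeries
from darts.models import RNNModel

series = TimeSeries.from_values(np.sin(np.linspace(0, 20, 200)))

# `pl_trainer_params` is the kwarg proposed in this PR description; the
# merged API may use a different name. The dict would be forwarded to
# pl.Trainer(), merged with Darts' own defaults.
model = RNNModel(
    input_chunk_length=24,
    model="LSTM",
    pl_trainer_params={"max_epochs": 5, "gradient_clip_val": 0.1},
)
model.fit(series)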

@dennisbader dennisbader marked this pull request as draft December 21, 2021 17:06
@hrzn (Contributor) left a comment:
Looks like a great start. So far my main comment concerns the PL trainer: I feel (but might be wrong) that it would be nice to change our models' signatures in order to put everything training-related in fit(). For instance, we could receive either the PL trainer kwargs (or the PL trainer itself?) as arguments to fit(). In __init__ we could also receive training-related arguments (or the PL trainer directly), but those would only be used as defaults if nothing is specified in fit() (as we currently do with epochs). A rough sketch follows below.
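
A rough skeleton of the signature pattern being proposed in this comment; the class and argument names are illustrative only, not the actual Darts API:

from typing import Optional

import pytorch_lightning as pl


class TorchForecastingModel:
    """Illustrative skeleton only; not the real Darts class."""

    def __init__(self, n_epochs: int = 100, trainer: Optional[pl.Trainer] = None):
        # Training-related arguments passed at construction act only as defaults.
        self._default_n_epochs = n_epochs
        self._default_trainer = trainer

    def fit(self, series, n_epochs: Optional[int] = None,
            trainer: Optional[pl.Trainer] = None):
        # Arguments passed to fit() take precedence over the __init__ defaults,
        # mirroring how `epochs` is already handled.
        n_epochs = n_epochs if n_epochs is not None else self._default_n_epochs
        trainer = trainer if trainer is not None else self._default_trainer
        ...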

(Several review threads on darts/models/forecasting/ptl_torch_forecasting_model.py and darts/models/forecasting/ptl_model.py; all outdated and resolved.)
def configure_optimizers(self):
    """Sets up optimizers."""
    # TODO: I think we can move this to pl.Trainer(), and it could probably be simplified.
+1
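
For reference, a minimal example of what a simplified configure_optimizers hook can look like in PyTorch Lightning (the optimizer class and learning rate are illustrative, not what this PR ships):

import torch

def configure_optimizers(self):
    """Sets up the optimizer for PyTorch Lightning."""
    # Lightning calls this hook itself; for the simple case it is
    # enough to return a single optimizer instance.
    return torch.optim.Adam(self.parameters(), lr=1e-3)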

(Another review thread on darts/models/forecasting/ptl_tft_model.py; outdated, resolved.)
@hrzn (Contributor) left a comment:
💯

@dennisbader dennisbader marked this pull request as ready for review February 13, 2022 17:34
@hrzn (Contributor) left a comment:
🚀

`PyTorch Lightning callbacks <https://pytorch-lightning.readthedocs.io/en/stable/extensions/callbacks.html>`_

.. highlight:: python
.. code-block:: python
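    # Sketch of the kind of callback example this docstring section
    # introduces; the exact example in the merged docs may differ, and
    # `pl_trainer_kwargs` is the dict-style trainer argument discussed
    # in this PR.
    from pytorch_lightning.callbacks import EarlyStopping

    # Stop training when the validation loss stops improving.
    my_stopper = EarlyStopping(
        monitor="val_loss",
        patience=5,
        min_delta=0.05,
        mode="min",
    )

    pl_trainer_kwargs = {"callbacks": [my_stopper]}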

👍 for the example!

@dennisbader dennisbader merged commit dda4d9c into master Feb 15, 2022
@dennisbader dennisbader deleted the feat/pytorch_lightning branch February 15, 2022 13:43