
Support Lightning Logging without Trainer #8509

Open
tchaton opened this issue Jul 21, 2021 · 10 comments
Labels
design Includes a design discussion discussion In a discussion stage feature Is an improvement or enhancement logging Related to the `LoggerConnector` and `log()` priority: 2 Low priority task
Milestone

Comments

@tchaton
Contributor

tchaton commented Jul 21, 2021

🚀 Feature

Motivation

To ease conversion from pure PyTorch to Lightning, users might start by creating their LightningModule.

However, their code would break if they tried to log, as the trainer isn't available.

Currently, we have two options:

  • make self.log a no-op in the absence of a Trainer
  • add support for logging without a Trainer.

Here is pseudo code explaining how we could support it.

The ResultCollection object is fairly self-contained and is used to store logged values.

class LightningModule:
    def __init__(self):
        self._lightning_results = ResultCollection()
        # wrap the (already bound) training_step so logging knows the current hook
        self.training_step = self._training_step_wrapper(self.training_step)

    @property
    def _results(self):
        # prefer the trainer's results when one is attached
        if getattr(self, "trainer", None) is not None:
            return self.trainer._results
        return self._lightning_results

    def _training_step_wrapper(self, training_step_fn):
        def wrapper(*args, **kwargs):
            self._current_fx = "training_step"
            output = training_step_fn(*args, **kwargs)
            self._current_fx = None
            return output
        return wrapper

    def training_step(self, batch, batch_idx):
        self.log(...)

class Model(LightningModule):
    ...

model = Model()

for _ in range(epochs):
    for batch in dataloader:
        loss = model.training_step(batch, batch_idx)
        ...
        logged_metrics = model.get_logged_metrics()

reduced_metrics = model.get_callback_metrics(epoch=True)

Drawback: every LightningModule hook used for logging would need to be wrapped to set the _current_fx attribute.
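To illustrate the drawback above, here is a minimal sketch of wrapping several hooks at once so each one sets `_current_fx` (the `DemoModule` class and `LOGGING_HOOKS` tuple are hypothetical, for illustration only, not Lightning internals):

```python
import functools

# Hypothetical subset of the hooks that would need wrapping.
LOGGING_HOOKS = ("training_step", "validation_step", "test_step")

class DemoModule:
    def __init__(self):
        self._current_fx = None
        # Wrap every defined logging hook so it records its own name.
        for hook in LOGGING_HOOKS:
            if hasattr(self, hook):
                setattr(self, hook, self._wrap_hook(hook, getattr(self, hook)))

    def _wrap_hook(self, name, fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            self._current_fx = name  # record which hook is currently logging
            try:
                return fn(*args, **kwargs)
            finally:
                self._current_fx = None  # always reset, even on error
        return wrapper

    def training_step(self, batch, batch_idx):
        # a real module would call self.log(...) here
        return self._current_fx

module = DemoModule()
assert module.training_step(None, 0) == "training_step"
assert module._current_fx is None
```

The `try/finally` keeps `_current_fx` consistent even when a hook raises, which a real implementation would likely want.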

Pitch

Alternatives

Additional context

If you enjoy PL, check out our other projects:

  • Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
  • Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, finetuning, and solving problems with deep learning.
  • Bolts: Pretrained SOTA deep learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.
  • Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers, leveraging PyTorch Lightning, Transformers, and Hydra.

cc @Borda @tchaton @justusschock @awaelchli @rohitgr7 @akihironitta @carmocca @edward-io @ananthsub @kamil-kaczmarek @Raalsky @Blaizzy

@tchaton tchaton added feature Is an improvement or enhancement help wanted Open to be worked on labels Jul 21, 2021
@carmocca
Contributor

I don't quite understand the motivation. Are you saying that new users would:

  1. Convert their nn.Module into a LightningModule
  2. Keep their training loop but use the LightningModule hooks manually
  3. Replace their training loop with the Trainer

So you want to be able to self.log() inside the LightningModule so step (2) works?

Another associated challenge is that the following log() arguments would not work:

        prog_bar: bool = False,  # callbacks are managed by the trainer
        logger: bool = True,  # loggers are managed by the trainer
        sync_dist: bool = False,  # would require the user to handle the training type setup
        sync_dist_group: Optional[Any] = None,
        add_dataloader_idx: bool = True,  # not necessary without a trainer that would add it automatically
        rank_zero_only: Optional[bool] = None,

Also, the user would need to update their training loop with all the logic we currently encode in the LoggerConnector. This is not trivial and I wouldn't expect anybody to want to.

@carmocca carmocca added design Includes a design discussion discussion In a discussion stage logging Related to the `LoggerConnector` and `log()` and removed help wanted Open to be worked on labels Jul 22, 2021
@ananthsub
Contributor

@carmocca - one idea is that the LightningModule directly owns the results as a buffer. We could offer an abstract interface for this. Some requirements:

  • This buffer must have a state_dict/load_state_dict style interface for checkpoint saving/loading. This would resolve the gap today where metric attributes live in the LightningModule while the result collection lives in the loop. We should consolidate all of this in the LightningModule and then populate the overall trainer checkpoint with it for saving & resuming.

For example, ResultCollection could be considered one implementation of this interface, and perhaps it's the default one the trainer uses in case the user doesn't provide their own.

Then calling self.log(...) in the LightningModule simply populates this buffer. log carries forward many of the assumptions of the prior Results struct, which was introduced in Lightning 0.7 or 0.8, before torchmetrics even existed. Since then, torchmetrics has come to encapsulate metric updates, syncing, and compute much better! And it'll get even better after Lightning-AI/torchmetrics#344

Given all of the nice utilities of torchmetrics, metric updates and syncing should ideally be handled by these metric classes and not as part of logging. For example, if people are logging arbitrary float values, they ought to wrap them in utilities like https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/average.py

What I think self.log in the LightningModule ought to be responsible for is determining which destinations the value should be made available to (e.g. loggers, progress bar, callbacks, etc.).
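A minimal sketch of such a module-owned buffer, where log() only records values and their destinations and exposes a state_dict/load_state_dict interface (the `LogBuffer` name and its methods are hypothetical, not Lightning API):

```python
from typing import Any, Dict, List, Tuple

class LogBuffer:
    """Hypothetical results buffer owned directly by the LightningModule."""

    def __init__(self) -> None:
        # each entry: (name, value, destination flags)
        self._entries: List[Tuple[str, Any, Dict[str, bool]]] = []

    def log(self, name: str, value: Any, prog_bar: bool = False, logger: bool = True) -> None:
        # only record the value and where it should go; no trainer needed
        self._entries.append((name, value, {"prog_bar": prog_bar, "logger": logger}))

    def flush(self) -> List[Tuple[str, Any, Dict[str, bool]]]:
        # the Trainer would drain this after each hook call and route
        # entries to the logger connector, progress bar, callbacks, etc.
        entries, self._entries = self._entries, []
        return entries

    def state_dict(self) -> Dict[str, Any]:
        # checkpoint-friendly snapshot of pending entries
        return {"entries": list(self._entries)}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self._entries = list(state["entries"])

buf = LogBuffer()
buf.log("train_loss", 0.25, prog_bar=True)
assert buf.state_dict()["entries"][0][0] == "train_loss"
assert len(buf.flush()) == 1 and buf.flush() == []
```

With this split, the buffer works the same with or without a Trainer: standalone users drain it themselves, while the Trainer drains it after every hook.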

Because the Trainer is the entity calling all the hooks in the LightningModule, it can inspect this buffer after every hook is called and make the log data available to its components. E.g. publish to logger connector

The logger connector can then inspect the destinations for the log entry, and populate the corresponding metrics entries (e.g. progress_bar_metrics, callback_metrics, logged_metrics). https://github.com/PyTorchLightning/pytorch-lightning/blob/e1442d247e0e4967dd2772bdcf5166226c974f89/pytorch_lightning/trainer/properties.py#L631-L646

(separately, since progress bar is a callback, do we really need both progress bar and callback metrics?)

@edward-io and I are interested in pursuing a POC for this. We'll write up a doc to discuss this proposal in more detail.

@stale

stale bot commented Sep 19, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, PyTorch Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Sep 19, 2021
@ananthsub ananthsub removed the won't fix This will not be worked on label Sep 23, 2021
@alanhdu
Contributor

alanhdu commented Sep 23, 2021

I have another use-case for this: unit-testing an existing LightningModule. In our code, we have some LightningModules that effectively define our task. Previously, in Lightning 1.3, we were able to test this by calling

def test_my_module(...):
    module = MyModule(...)
    loss = module.training_step(...)
    assert_something_about_the_loss(loss)

In Lightning 1.4, this fails because the module isn't attached to a trainer.

@lucmos
Contributor

lucmos commented Sep 28, 2021

Hello, replying to #9716 (comment)

That was just an MWE. My use case is writing unit tests for the different methods of the model (forward, but also the *step methods and others).
With Lightning 1.4 it is not possible to test the model without also instantiating the Trainer (which arguably wouldn't be a unit test).

Another use case I am having problems with is benchmarking the training_step speed (without the overhead introduced by other components). Before 1.4, I was monitoring the speed of:

    def module_step_approximation(_):
        module.zero_grad()
        loss = module.training_step(batch, 0)
        loss.backward()

@tchaton
Contributor Author

tchaton commented Oct 6, 2021

Dear @lucmos,

One could patch the LightningModule's log function so the logging computation isn't included in the speed benchmark.

Best,
T.C
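A minimal sketch of that patching approach using the standard library's unittest.mock (the `DummyModule` class here is a hypothetical stand-in for a LightningModule, not Lightning code):

```python
from unittest import mock

class DummyModule:
    """Hypothetical stand-in for a LightningModule."""

    def log(self, name, value):
        # mimics self.log failing when no Trainer is attached
        raise RuntimeError("self.log requires a Trainer")

    def training_step(self, batch, batch_idx):
        loss = sum(batch)  # placeholder computation
        self.log("train_loss", loss)
        return loss

module = DummyModule()

# Replace self.log with a no-op for the duration of the benchmark/test,
# so training_step runs without a Trainer and without logging overhead.
with mock.patch.object(module, "log", lambda *args, **kwargs: None):
    loss = module.training_step([1.0, 2.0], 0)
assert loss == 3.0
```

mock.patch.object restores the original log method on exit, so the module behaves normally outside the benchmarked block.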

@tchaton tchaton added the priority: 2 Low priority task label Oct 6, 2021
@stale

stale bot commented Nov 6, 2021


@stale stale bot added the won't fix This will not be worked on label Nov 6, 2021
@daniellepintz daniellepintz removed the won't fix This will not be worked on label Nov 12, 2021
@stale

stale bot commented Dec 15, 2021


@stale stale bot added the won't fix This will not be worked on label Dec 15, 2021
@daniellepintz daniellepintz removed the won't fix This will not be worked on label Dec 22, 2021
@stale

stale bot commented Jan 22, 2022


@stale stale bot added the won't fix This will not be worked on label Jan 22, 2022
@awaelchli awaelchli removed the won't fix This will not be worked on label Jan 23, 2022
@awaelchli awaelchli added this to the 1.7 milestone Jan 23, 2022
@carmocca carmocca modified the milestones: 1.7, future Feb 3, 2022
Projects
None yet
Development

No branches or pull requests

7 participants