
Calculating gradient during validation #201

Closed
JonathanSchmidt1 opened this issue Sep 5, 2019 · 10 comments

Assignees: williamFalcon
Labels: feature (Is an improvement or enhancement), help wanted (Open to be worked on)

Comments

@JonathanSchmidt1

Hello,
I am using PyTorch Lightning in a physics-related project where I need to calculate gradients during validation. At the moment, gradient calculation is disabled in the trainer's validate function. Of course, commenting out that line solves the problem, but it took some time to figure out where things were going wrong.
Solution:
I would suggest adding another parameter to the trainer (e.g. enable_grad_during_validation) to allow enabling gradient calculation during validation. This parameter should default to False so that nothing changes for users who do not need the feature. The required changes would be to add the parameter and modify the line where gradients are disabled:

   # disable gradients to save memory
   torch.set_grad_enabled(enable_grad_during_validation)
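For illustration, usage of the proposed flag would look like this (enable_grad_during_validation is only the suggested name, not an existing Trainer argument):

   # hypothetical usage of the proposed flag; not an existing Trainer argument
   trainer = Trainer(enable_grad_during_validation=True)
   trainer.fit(model)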

This might be a niche issue, so if no one else needs this change, I would instead suggest adding a note to the documentation of the validation loop informing users that gradient calculation is disabled during validation.
P.S.: thank you for the great library!

JonathanSchmidt1 added the feature and help wanted labels Sep 5, 2019
@williamFalcon
Contributor

It's standard to freeze the model during validation, so this is an edge case. However, I think adding this to the docs would be helpful.

Want to take a stab at it?

@cemanil

cemanil commented May 13, 2020

Important for meta-learning and nested optimization research:

Having this feature could be quite useful for researchers working on meta-learning and nested optimization.

For example, without the option to enable gradients during validation, the recent inner-loop optimization library higher from Facebook AI is incompatible with PyTorch Lightning.
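For concreteness, here is a minimal sketch of the kind of inner loop higher builds (model, loss_fn, support_batches, x_query, and y_query are placeholders); each diffopt.step backpropagates through the update, so it fails if gradient calculation is disabled:

import higher
import torch

# sketch only: model, loss_fn, support_batches, x_query, y_query are placeholders
inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
with higher.innerloop_ctx(model, inner_opt) as (fmodel, diffopt):
    for x, y in support_batches:
        diffopt.step(loss_fn(fmodel(x), y))  # differentiable update; needs autograd
    val_loss = loss_fn(fmodel(x_query), y_query)  # evaluate the adapted model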

Are there any negative downstream effects of enabling gradients during validation that I might be missing? If there aren't any, then addressing this issue by just adding a new argument to the Trainer class seems reasonable to me. I'd be happy to take a stab at it if the maintainers are ok with adding this feature.

Thanks!

@chnsh

chnsh commented Aug 8, 2020

@cemanil Is there a reason why we can't do the meta-test in the training step itself, and have the validation remain true validation without backpropagating through it?

@cemanil

cemanil commented Aug 8, 2020

Thank you for your response.

I might have misunderstood your recommendation. If we'd like to compute the performance of our bilevel model on the validation or test set, how can/should we do so in the training step? Models like Structured Prediction Energy Networks require running backpropagation as part of inference, which in turn requires gradient computation to be enabled.

@benrhodes26

I agree with cemanil that there are important mainstream use cases for validation-time gradient computation: for instance, any inference task that uses MCMC sampling (e.g. for energy-based models or Bayesian inference).
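As a concrete sketch (energy_fn is a placeholder), even a single unadjusted Langevin sampling step needs autograd at evaluation time:

import torch

def langevin_step(x, energy_fn, step_size=0.01):
    # one unadjusted Langevin step; torch.autograd.grad fails if grad is disabled
    x = x.detach().requires_grad_(True)
    grad, = torch.autograd.grad(energy_fn(x).sum(), x)
    return (x - 0.5 * step_size * grad
            + step_size ** 0.5 * torch.randn_like(x)).detach()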

@williamFalcon
Contributor

williamFalcon commented Aug 24, 2020

Why can't you just enable it again in validation_step?

Lightning handles the major use cases, but this edge case (or not-so-edge case for your research, haha) can be handled like this:

def validation_step(self, batch, batch_idx):
    torch.set_grad_enabled(True)
    ...

But either way, we can revisit this in a week or so, since we're finishing some refactors.
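A slightly fuller sketch of that workaround, with a placeholder input-gradient computation of the kind the original issue describes:

def validation_step(self, batch, batch_idx):
    torch.set_grad_enabled(True)  # validation runs under no_grad by default
    x, y = batch
    x.requires_grad_(True)
    pred = self(x)
    grad_x = torch.autograd.grad(pred.sum(), x)[0]  # placeholder physics-style gradient
    return torch.nn.functional.mse_loss(pred, y)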

williamFalcon self-assigned this Aug 24, 2020
@cemanil

cemanil commented Aug 29, 2020

Thank you for your reply William!

This is indeed how I ended up enabling test-time gradient computation. I was just a bit hesitant to toggle flags like this manually, to avoid any unanticipated side effects. This one does seem pretty harmless, though.

Do you think just adding a sentence or two about this in the documentation should suffice, then? Or would it be cleaner to add an argument to the Trainer class?

@voorhs

voorhs commented Oct 17, 2023

Why can't you just enable it again in validation_step?

Lightning handles the major use cases, but this edge case (or not-so-edge case for your research, haha) can be handled like this:

def validation_step(self, batch, batch_idx):
    torch.set_grad_enabled(True)
    ...

But either way, we can revisit this in a week or so, since we're finishing some refactors.

It doesn't work for me. I even added this to the end of validation_step():

torch.set_grad_enabled(True)
x = torch.tensor([1.], requires_grad=True)
y = 2 * x
print(y.requires_grad)

and it prints False. Has there been an update that restricts enabling grads like this?

@Kirito-Ausna

It doesn't work for me. [...] and it prints False. Has there been an update that restricts enabling grads like this?

Hi, in this case you only set x.requires_grad=True, not y's. If you check x.requires_grad, it will be True.

@mhubii

mhubii commented Oct 15, 2024

Problem (Solution below)

I'm having the same issue: I implemented a Callback that does some optimization in on_test_end, and somehow the gradients are not tracked. Please find an example below:

import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from torch.utils.data import DataLoader


class DummyCallback(Callback):
    def _dummy_optimization(self):
        print(torch.is_grad_enabled())  # prints True in either case

        # some dummy optimization
        target = torch.tensor([1, 2, 3], dtype=torch.float32)
        input = torch.zeros_like(target)
        input.requires_grad = True

        optimizer = torch.optim.SGD([input], lr=0.1)

        for _ in range(100):
            optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(input, target)
            loss.backward()  # throws an error in on_test_end...
            print(f"Loss: {loss.item()}")
            optimizer.step()

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule):
        self._dummy_optimization()  # works well...

    def on_test_end(self, trainer: Trainer, pl_module: LightningModule):
        torch.set_grad_enabled(True)
        self._dummy_optimization()  # throws an error...


class DummyModule(LightningModule):
    def configure_optimizers(self):
        pass

    def training_step(self, batch: int, batch_idx: int):
        pass

    def test_step(self, batch: int, batch_idx: int):
        pass


class DummyDataModule(LightningDataModule):
    def train_dataloader(self):
        return DataLoader([0])

    def test_dataloader(self):
        return DataLoader([0])


if __name__ == "__main__":
    dummy_callback = DummyCallback()
    trainer = Trainer(callbacks=[dummy_callback], fast_dev_run=True)
    dummy_module = DummyModule()
    dummy_data_module = DummyDataModule()
    trainer.fit(dummy_module, dummy_data_module)
    trainer.test(
        dummy_module, dummy_data_module, verbose=True
    )  # gradient tracking fails in on_test_end of DummyCallback

Am I missing something here?

Solution

To answer my own question (hint found in #10287 (comment)):

If, for example, you decorate the hook in the example above like this:

@torch.inference_mode(False)
@torch.enable_grad()
def on_test_end(self, trainer: Trainer, pl_module: LightningModule):
    self._dummy_optimization()

everything works fine. The missing bit was torch.inference_mode(False). See also https://pytorch.org/docs/stable/generated/torch.autograd.grad_mode.inference_mode.html
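The decorator pair can equivalently be written as context managers; recent Lightning versions also expose an inference_mode flag on the Trainer (Trainer(inference_mode=False)) that falls back to plain no_grad during evaluation:

# equivalent context-manager form of the decorated hook above
def on_test_end(self, trainer: Trainer, pl_module: LightningModule):
    with torch.inference_mode(False), torch.enable_grad():
        self._dummy_optimization()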
