Calculating gradient during validation #201
Comments
It's standard to freeze the model during validation, so this is an edge case. However, I think adding a note to the docs would be helpful. Want to take a stab at it?
Important for meta-learning and nested optimization research: Having this feature could be quite useful for researchers working on meta-learning and nested optimization. For example, without the option to enable gradients during validation, the recent inner-loop optimization library higher from Facebook AI is incompatible with PyTorch Lightning. Are there any negative downstream effects of enabling gradients during validation that I might be missing? If there aren't any, then addressing this issue by just adding a new argument to the Trainer class seems reasonable to me. I'd be happy to take a stab at it if the maintainers are OK with adding this feature. Thanks!
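To make the meta-learning case concrete, here is a minimal plain-PyTorch sketch (not using the higher API; function and argument names are hypothetical, and it assumes PyTorch 2.x for `torch.func.functional_call`) of why validation needs autograd: the validation metric is itself the output of an inner gradient-based adaptation loop, so it cannot be computed under `torch.no_grad()`.

```python
import torch
import torch.nn.functional as F


def adapted_validation_loss(model, support_x, support_y, query_x, query_y,
                            inner_lr=0.1, inner_steps=5):
    """Meta-test style evaluation: adapt a copy of the parameters on the
    support set with a few gradient steps, then score on the query set.
    The adaptation calls torch.autograd.grad, so this fails whenever
    gradient computation is disabled during validation."""
    params = {name: p.clone() for name, p in model.named_parameters()}
    for _ in range(inner_steps):
        preds = torch.func.functional_call(model, params, (support_x,))
        inner_loss = F.mse_loss(preds, support_y)
        grads = torch.autograd.grad(inner_loss, list(params.values()))
        params = {name: p - inner_lr * g
                  for (name, p), g in zip(params.items(), grads)}
    query_preds = torch.func.functional_call(model, params, (query_x,))
    return F.mse_loss(query_preds, query_y)
```

With gradients enabled in the validation step, a loss like this could be logged as the validation metric.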
@cemanil Is there a reason why we can't do the meta-test in the training step itself and have the validation remain true validation, without backpropagating through it?
Thank you for your response. I might have misunderstood your recommendation. If we'd like to compute the performance of our bilevel model on the validation or test set, how can/should we do so in the training step? Models like Structured Prediction Energy Networks require running backpropagation as part of inference, which in turn requires having gradient computation enabled.
I agree with cemanil that there are important mainstream use cases for validation-time gradient computation, for instance any inference task using MCMC sampling (e.g. for energy-based models or Bayesian inference).
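As a hedged sketch of that MCMC/energy-based case: unadjusted Langevin dynamics needs the gradient of the energy with respect to the samples at every step, even when the sampling is only used for evaluation. Here `energy_fn` is a placeholder for any network returning a scalar energy per sample.

```python
import torch


def langevin_sample(energy_fn, x_init, steps=50, step_size=0.01):
    """Draw approximate samples from p(x) ~ exp(-E(x)) with unadjusted
    Langevin dynamics. Each update needs dE/dx, so running this inside a
    no-grad validation loop raises an error."""
    x = x_init.clone().requires_grad_(True)
    for _ in range(steps):
        energy = energy_fn(x).sum()
        (grad_x,) = torch.autograd.grad(energy, x)
        noise = torch.randn_like(x)
        x = x - 0.5 * step_size * grad_x + step_size ** 0.5 * noise
        x = x.detach().requires_grad_(True)
    return x.detach()
```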
Why can't you just enable it again in the validation_step? Lightning handles the major use cases, but this (edge case, or not so edge for your research haha) can just be handled like this:

```python
def validation_step(self, batch, batch_idx):
    torch.set_grad_enabled(True)
    ...
```

But either way, in a week or so we can revisit this, since we're finishing refactors.
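To flesh out that workaround, here is a sketch of a LightningModule whose validation_step re-enables autograd and computes an input gradient. The module, metric names, and logged values are illustrative only, and on newer Lightning versions the evaluation loop may additionally run under `torch.inference_mode` (see the solution further down this thread).

```python
import torch
from lightning import LightningModule  # or pytorch_lightning, depending on version


class GradInValidation(LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(8, 1)

    def forward(self, x):
        return self.net(x)

    def validation_step(self, batch, batch_idx):
        # Lightning wraps validation in no-grad mode, so turn autograd back on
        # for the part that needs a graph.
        torch.set_grad_enabled(True)
        x, y = batch
        x = x.clone().requires_grad_(True)
        loss = torch.nn.functional.mse_loss(self(x), y)
        # Example use of the graph: gradient of the loss w.r.t. the inputs.
        (input_grad,) = torch.autograd.grad(loss, x)
        self.log("val_loss", loss.detach())
        self.log("val_input_grad_norm", input_grad.norm())
```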
Thank you for your reply, William! This is indeed how I ended up enabling test-time gradient computation. I was just a bit hesitant to manually toggle flags like this, in order to avoid any unanticipated side effects. This one does seem pretty harmless, though. Do you think adding a sentence or two about this to the documentation would suffice, then? Or would it be cleaner to add an argument to the Trainer class?
It doesn't work for me. I even added this to the end of
and it prints |
Hi, in this case you only set x.requires_grad=True, not y. If you check x.requires_grad, it would be True.
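To illustrate that exchange (a generic sketch, not the original poster's code): the flag set on the input only propagates to outputs when a graph is actually recorded, so checking y.requires_grad under no-grad mode prints False even though x.requires_grad is True.

```python
import torch

model = torch.nn.Linear(4, 2)
x = torch.randn(3, 4, requires_grad=True)

with torch.no_grad():
    y = model(x)
print(x.requires_grad, y.requires_grad)  # True False: no graph was recorded

with torch.enable_grad():
    y = model(x)
print(x.requires_grad, y.requires_grad)  # True True: y is connected to x
y.sum().backward()
print(x.grad.shape)  # torch.Size([3, 4])
```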
Problem (solution below)

Having the same issue; I implemented a Callback to reproduce it:

```python
import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from torch.utils.data import DataLoader


class DummyCallback(Callback):
    def _dummy_optimization(self):
        print(torch.is_grad_enabled())  # prints True in either case
        # some dummy optimization
        target = torch.tensor([1, 2, 3], dtype=torch.float32)
        input = torch.zeros_like(target)
        input.requires_grad = True
        optimizer = torch.optim.SGD([input], lr=0.1)
        for _ in range(100):
            optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(input, target)
            loss.backward()  # throws an error in on_test_end...
            print(f"Loss: {loss.item()}")
            optimizer.step()

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule):
        self._dummy_optimization()  # works well...

    def on_test_end(self, trainer: Trainer, pl_module: LightningModule):
        torch.set_grad_enabled(True)
        self._dummy_optimization()  # throws an error...


class DummyModule(LightningModule):
    def configure_optimizers(self):
        pass

    def training_step(self, batch: int, batch_idx: int):
        pass

    def test_step(self, batch: int, batch_idx: int):
        pass


class DummyDataModule(LightningDataModule):
    def train_dataloader(self):
        return DataLoader([0])

    def test_dataloader(self):
        return DataLoader([0])


if __name__ == "__main__":
    dummy_callback = DummyCallback()
    trainer = Trainer(callbacks=[dummy_callback], fast_dev_run=True)
    dummy_module = DummyModule()
    dummy_data_module = DummyDataModule()
    trainer.fit(dummy_module, dummy_data_module)
    trainer.test(
        dummy_module, dummy_data_module, verbose=True
    )  # gradient tracking fails in on_test_end of DummyCallback
```

Am I missing something here?

Solution

To answer the question (hint found in #10287 (comment)): if, e.g., in the above you do

```python
@torch.inference_mode(False)
@torch.enable_grad()
def on_test_end(self, trainer: Trainer, pl_module: LightningModule):
    self._dummy_optimization()
```

everything works fine. The missing bit was turning off inference mode via torch.inference_mode(False).
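Building on that solution, the same two switches can also be used inside a test_step or validation_step when the gradient-dependent computation lives there rather than in a callback. The sketch below mirrors the dummy optimization above and is only an illustration; recent Lightning versions also expose an inference_mode option on the Trainer, which is worth checking for your version.

```python
import torch
from lightning import LightningModule


class GradAtTestTime(LightningModule):
    def test_step(self, batch, batch_idx):
        # Both switches matter: inference mode must be turned off *and*
        # grad mode turned on, otherwise backward() still fails.
        with torch.inference_mode(False), torch.enable_grad():
            target = torch.tensor([1.0, 2.0, 3.0])
            x = torch.zeros_like(target).requires_grad_(True)
            loss = torch.nn.functional.mse_loss(x, target)
            loss.backward()
            self.log("test_grad_norm", x.grad.norm())
```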
Hello,
I am using pytorch-lightning in a physics-related project where I need to calculate gradients during validation. At the moment, gradient calculation is disabled in the trainer's validate function. Of course, commenting out that line solves the problem; however, it took some time to figure out where everything went wrong.
Solution:
I would suggest adding another parameter to the Trainer (e.g. enable_grad_during_validation) to allow enabling gradient calculation during validation. This parameter should be set to False by default so that nothing changes for users who do not require the feature. The required changes would be to add the parameter and change the line where the gradient is disabled.
This might be a niche issue, so if no one else needs this change, I would suggest adding an extra note to the documentation of the validation loop informing users that gradient calculation is disabled during validation.
PS: thank you for the great library!