Unable to enable grad during validation #13948

Closed
Ilykuleshov opened this issue Jul 31, 2022 · 11 comments
Labels: question (Further information is requested), won't fix (This will not be worked on)

Comments

Ilykuleshov commented Jul 31, 2022

🐛 Bug

Even with @torch.enable_grad(), gradients aren't enabled during validation.

To Reproduce

import os

import numpy as np
import torch
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer, Callback
from pytorch_lightning.demos.boring_classes import RandomDataset, BoringModel


class BoringCallback(Callback):
    def __init__(self):
        super().__init__()
        self.saved_batch = np.random.rand(32).astype(np.float32)
        self.out_ch = 2

    @torch.enable_grad()
    def _init_and_train_layer(self, feats):
        feats.requires_grad_()
        layer = torch.nn.Linear(2, self.out_ch, device=feats.device)
        logits = layer(feats[None])
        assert logits.requires_grad

    def _pred_feats(self, pl_module):
        saved_batch = torch.tensor(self.saved_batch, device=pl_module.device)
        return pl_module(saved_batch[None])[0]

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx) -> None:
        self._init_and_train_layer(self._pred_feats(pl_module))


val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
model = BoringModel()
trainer = Trainer(
    default_root_dir=os.getcwd(),
    limit_val_batches=1,
    max_epochs=1,
    enable_model_summary=False,
    callbacks=[BoringCallback()],
    accelerator='cpu'
)
trainer.validate(model, val_data)

Expected behavior

Expected to finish without error (the assert logits.requires_grad should pass).

Environment

  • CUDA:
    • GPU:
      • Tesla T4
    • available: True
    • version: 11.3
  • Packages:
    • lightning: 2022.7.31
    • lightning_app: 0.6.0dev
    • numpy: 1.21.6
    • pyTorch_debug: False
    • pyTorch_version: 1.12.0+cu113
    • pytorch-lightning: 1.7.0rc1
    • tqdm: 4.64.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.7.13
    • version: 1 SMP Sun Apr 24 10:03:06 PDT 2022

Additional context

@Ilykuleshov Ilykuleshov added the needs triage Waiting to be triaged by maintainers label Jul 31, 2022
@carmocca carmocca added question Further information is requested and removed needs triage Waiting to be triaged by maintainers labels Aug 5, 2022
carmocca (Contributor) commented Aug 5, 2022

The decorator is working properly; the issue is that you are asserting logits.requires_grad when you actually set feats.requires_grad_().
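
For illustration, a minimal plain-PyTorch sketch of that distinction (outside Lightning, shapes made up for the example): requires_grad_() marks the input tensor itself, while the output's requires_grad depends on the mode active when the op runs.

import torch

feats = torch.randn(2)
feats.requires_grad_()                        # the leaf tensor now requires grad
layer = torch.nn.Linear(2, 2)

with torch.enable_grad():
    print(layer(feats[None]).requires_grad)   # True: grad mode on and the input requires grad

with torch.inference_mode():
    print(layer(feats[None]).requires_grad)   # False: inference mode disables autograd tracking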

@SinzyShen

Same issue here, after updating to version 1.7.

chrisliu298 commented Aug 14, 2022

Same issue. In my case, I perform FGSM/PGD attacks in the validation and test steps.

Following #201, gradients work fine with trainer.fit() by setting torch.set_grad_enabled(True) or with torch.enable_grad(): in validation_step() (when validation is performed in the fit loop). However, with trainer.validate(), gradients are disabled even with the above setting, and the same applies to trainer.test().

When the error happens, it produces the message below. Does this indicate torch.set_grad_enabled(True) is somehow ignored in trainer.validate() and trainer.test()?

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Update: This only occurs in 1.7 and does not happen with 1.6.5.
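
For context, the pattern involved is roughly the following (a minimal FGSM sketch inside validation_step; the model, loss, epsilon, and input range are illustrative, not from this issue):

def validation_step(self, batch, batch_idx):
    x, y = batch
    with torch.enable_grad():                        # works in fit-loop validation, but not under trainer.validate() on 1.7 (inference mode)
        x_req = x.clone().requires_grad_()
        loss = torch.nn.functional.cross_entropy(self(x_req), y)
        (grad,) = torch.autograd.grad(loss, x_req)
    x_adv = (x + 8 / 255 * grad.sign()).clamp(0, 1)  # FGSM step, epsilon = 8/255, inputs assumed in [0, 1]
    adv_acc = (self(x_adv).argmax(dim=1) == y).float().mean()
    self.log("adv_acc", adv_acc)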

@carmocca (Contributor)

Oh, I know the cause. It's because validate and test now use torch.inference_mode after #12715 (cc @rohitgr7).

You will need to do something like:

with torch.inference_mode(False):
    grad_feats = feats.clone().requires_grad_()
    layer = torch.nn.Linear(2, self.out_ch, device=feats.device)
    logits = layer(grad_feats[None])
assert logits.requires_grad
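
Applied to the _init_and_train_layer callback from the report above, a rough sketch (keeping the original @torch.enable_grad() decorator) would be:

@torch.enable_grad()
def _init_and_train_layer(self, feats):
    # feats was produced under inference mode, so clone it with inference mode
    # disabled to get a normal tensor that autograd can track
    with torch.inference_mode(False):
        grad_feats = feats.clone().requires_grad_()
        layer = torch.nn.Linear(2, self.out_ch, device=feats.device)
        logits = layer(grad_feats[None])
        assert logits.requires_grad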

Waybaba commented Sep 13, 2022

Oh, I know the cause. It's because validate and test now use torch.inference_mode after #12715 (cc @rohitgr7).

You will need to do something like:

with torch.inference_mode(False):
    grad_feats = feats.clone().requires_grad_()
    layer = torch.nn.Linear(2, self.out_ch, device=feats.device)
    logits = layer(grad_feats[None])
assert logits.requires_grad

Is there any nice way to update the net (whose parameters are predefined) in inference mode?

@torch.enable_grad() 
def forward_and_adapt(self, x):
    grad_x = x.clone().requires_grad_()
    outputs = self.net(grad_x) # ?
    entropys = softmax_entropy(outputs)
    loss = entropys.mean()
    loss.backward()
    ...

@carmocca (Contributor)

Why do you want to update the net during evaluation? You should use training_step for that.

Waybaba commented Sep 14, 2022

Why do you want to update the net during evaluation? You should use training_step for that.

We are working on Test-time Domain Adaptation (link), where the network needs to be updated after each batch is fed, in an unsupervised fashion.

Implementing this in training_step could be a solution, but would it cause other problems such as augmentation, random shuffling, or dropout?

Nevertheless, I am using the following approach, and it seems to work well.

@torch.enable_grad() 
@torch.inference_mode(False)
def forward_and_adapt(self, x):
    grad_x = x.clone().requires_grad_()
    outputs = self.net(grad_x) # ?
    entropys = softmax_entropy(outputs)
    loss = entropys.mean()
    loss.backward()
    ...
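
(As far as I understand, the extra @torch.inference_mode(False) is needed because the evaluation loop runs under inference mode: tensors created there are inference tensors and cannot participate in autograd even if grad is enabled later, so the clone has to be made with inference mode disabled to get a normal tensor that can require grad.)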

Thanks for your reply.

@carmocca (Contributor)

All of that is meant to be in the training_step. The forward method should only contain self.net(x).

but would it cause other problems such as augmentation, random shuffling, or dropout?

Why would it?

stale bot commented Oct 16, 2022

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, PyTorch Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Oct 16, 2022
@rohitgr7 (Contributor)

fixed by: #15034

you can now set Trainer(inference_mode=False) and enable grad manually.
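
For example, a minimal sketch of that usage, reusing the boring classes from the report above (the subclass and loss here are illustrative):

import torch
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset


class GradValModel(BoringModel):
    def validation_step(self, batch, batch_idx):
        with torch.enable_grad():                 # eval runs under no_grad when inference_mode=False, so re-enable grad
            x = batch.clone().requires_grad_()    # a normal tensor, not an inference tensor
            loss = self.layer(x).sum()
            (input_grad,) = torch.autograd.grad(loss, x)
        assert input_grad is not None
        return {"x": loss}


trainer = Trainer(inference_mode=False, limit_val_batches=1, accelerator="cpu")
trainer.validate(GradValModel(), DataLoader(RandomDataset(32, 64), batch_size=2))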

@carmocca carmocca added this to the v1.8 milestone Oct 18, 2022
kugwzk commented Oct 19, 2022

fixed by: #15034

you can now set Trainer(inference_mode=False) and enable grad manually.

Hi~ When will v1.8 be released?
