Sanity check a LightningModule without creating a Trainer #9357

Closed
gungui98 opened this issue Sep 7, 2021 · 1 comment
Labels: bug (Something isn't working), help wanted (Open to be worked on)

gungui98 commented Sep 7, 2021

🐛 Bug

Hi, I want to run a sanity check on a LightningModule. With the previous version, 1.3.8, the following code works fine, but with the latest version, 1.4.5, it fails because the module's trainer is None. I do not want to create a whole Trainer just to assert some properties of my model.

To Reproduce

import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl
import torchmetrics
from torchvision.models import resnet50


class MyResNet(nn.Module):
    def __init__(self, num_classes, pretrained=True):
        super(MyResNet, self).__init__()
        model = resnet50(pretrained=pretrained)
        self.num_ftrs = model.fc.in_features  # 2048 for resnet50

        self.shared = nn.Sequential(*list(model.children())[:-1])
        self.target = torch.nn.Sequential(nn.Linear(self.num_ftrs, num_classes))

    def forward(self, x):
        x = self.shared(x)
        x = torch.squeeze(x)  # remove dimensions of size 1
        return self.target(x)


class Classifier(pl.LightningModule):

    def __init__(self, num_classes):
        super().__init__()
        self.save_hyperparameters()
        self.backbone = MyResNet(num_classes)
        self.train_accuracy = torchmetrics.Accuracy()

    def forward(self, x):
        embedding = self.backbone(x)
        return embedding

    def training_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self.backbone(x.float())
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self.backbone(x.float())
        loss = F.cross_entropy(y_hat, y)
        acc = self.train_accuracy(y_hat, y)
        self.log('valid_loss', loss, on_step=True)
        self.log('accuracy', acc, on_epoch=True, prog_bar=True)
        return acc

    def test_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self.backbone(x.float())
        loss = F.cross_entropy(y_hat, y)
        self.log('test_loss', loss)
        return loss


if __name__ == '__main__':
    model = Classifier(10)
    batch = (torch.ones(16, 3, 30, 20), torch.zeros(16).long(), None)
    print(model.training_step(batch, 0))

Expected behavior

The code should return a scalar loss. Instead, it raises:

AttributeError: 'NoneType' object has no attribute '_results'

Environment

  • CUDA:
    • GPU:
      • GeForce RTX 2070 with Max-Q Design
    • available: True
    • version: 10.2
  • Packages:
    • numpy: 1.21.1
    • pyTorch_debug: False
    • pyTorch_version: 1.7.1
    • pytorch-lightning: 1.4.5
    • tqdm: 4.62.0
  • System:
    • OS: Windows
    • architecture:
      • 64bit
      • WindowsPE
    • processor: Intel64 Family 6 Model 158 Stepping 10, GenuineIntel
    • python: 3.8.8
    • version: 10.0.19041

Additional context

@gungui98 added the bug and help wanted labels on Sep 7, 2021
carmocca (Contributor) commented Sep 7, 2021

You didn't provide a stacktrace but I assume this is where it's failing:

https://github.com/PyTorchLightning/pytorch-lightning/blob/645eabe11055b210f0865aad142d1c26b7827012/pytorch_lightning/core/lightning.py#L407

The message for this error has improved in master:

https://github.com/PyTorchLightning/pytorch-lightning/blob/6892d533ea1c743f7e05171846a28e685db85f51/pytorch_lightning/core/lightning.py#L394-L398

This is not supported right now; there's a feature request for it in #8509. You'll have to avoid the log call until then.
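
A minimal sketch of a workaround until then, assuming PL 1.4.x behavior where self.trainer is simply None when the module is used outside a Trainer-managed run (the guard below is a suggestion, not an official API):

def training_step(self, batch, batch_idx):
    x, y, _ = batch
    y_hat = self.backbone(x.float())
    loss = F.cross_entropy(y_hat, y)
    # Assumption: self.trainer is None for a standalone module in 1.4.x,
    # so skipping self.log here avoids the AttributeError.
    if self.trainer is not None:
        self.log('train_loss', loss, on_epoch=True)
    return loss

With this guard, model.training_step(batch, 0) returns the scalar loss as it did in 1.3.8, while logging still works normally when the module runs under Trainer.fit.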

Closing, feel free to comment your thoughts on the feature request.

@carmocca closed this as completed on Sep 7, 2021