
self.log isn't logging anything when combining two pl.LightningModules into one main pl.LightningModule #10402

Closed
KinWaiCheuk opened this issue Nov 8, 2021 · 3 comments
Labels
bug · help wanted · logging · priority: 1

Comments

@KinWaiCheuk

🐛 Bug

self.log does not log anything when a pl.LightningModule is composed of sub pl.LightningModules.

To Reproduce

Sometimes we need to develop submodules (one loss output per submodule) and combine them into a final ensemble module so that everything can be trained end to end. Here's the pseudocode:

import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl


class ModelA(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(784,10)

    def forward(self, x):
        x = self.linear(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.flatten(1)
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        # Logging to TensorBoard by default
        self.log("train_loss_A", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
class ModelB(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(784,256)
        self.linear2 = nn.Linear(256,10)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

    def training_step(self, batch, batch_idx):   
        x, y = batch
        x = x.flatten(1)         
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss_B", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer    


class Ensemble(pl.LightningModule):
    def __init__(self, modelA, modelB):
        super().__init__()
        self.modelA = modelA
        self.modelB = modelB

#     def forward(self, x):
#         # in lightning, forward defines the prediction/inference actions
#         outputA = self.modelA(x)
#         outputB = self.modelB(x)
#         return outputA, outputB

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # It is independent of forward.
        x, y = batch
        loss_A = self.modelA.training_step(batch, batch_idx)
        loss_B = self.modelB.training_step(batch, batch_idx)
        # The submodules call self.log internally, but nothing reaches TensorBoard.
        return loss_A + loss_B

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


dataset = MNIST(os.getcwd(), download=False, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)

modelA = ModelA()
modelB = ModelB()
ensemble = Ensemble(modelA, modelB)
logger = pl.loggers.TensorBoardLogger(save_dir='.')

trainer = pl.Trainer(gpus=1, logger=logger)
trainer.fit(ensemble, train_loader)

When training modelA or modelB alone, self.log logs the loss to TensorBoard.
But after combining both of them into the final Ensemble model, nothing is logged to TensorBoard.
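
For reference, logging the sub-losses from the ensemble module itself does work, since only the module attached to the Trainer has its logging machinery wired up; a minimal sketch of that workaround (the submodules' internal self.log calls still have no effect):

class Ensemble(pl.LightningModule):
    def __init__(self, modelA, modelB):
        super().__init__()
        self.modelA = modelA
        self.modelB = modelB

    def training_step(self, batch, batch_idx):
        # Reuse the submodules' training_step only to compute the losses, then
        # log from the ensemble, which is the module attached to the Trainer.
        loss_A = self.modelA.training_step(batch, batch_idx)
        loss_B = self.modelB.training_step(batch, batch_idx)
        self.log("train_loss_A", loss_A)
        self.log("train_loss_B", loss_B)
        return loss_A + loss_B

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)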

Expected behavior

I would expect that when training the ensemble, both train_loss_A and train_loss_B would appear in TensorBoard.

Environment


pytorch-lightning==1.5.0
torch==1.10.0


Additional context

I think it is related to this pull request
#9733

@KinWaiCheuk added the bug and help wanted labels Nov 8, 2021
@tchaton
Contributor

tchaton commented Nov 8, 2021

Dear @KinWaiCheuk,

This is currently not supported in PyTorch Lightning, but it is in Lightning Flash.

You could do the following:

class Model(LightningModule):
    # Note: `_children` must be initialized (e.g. to [] in __init__) before any
    # child module is assigned; ModuleWrapperBase comes from Lightning Flash.

    ...

    def __setattr__(self, key, value):
        # Remember which attributes are themselves LightningModules.
        if isinstance(value, (LightningModule, ModuleWrapperBase)):
            self._children.append(key)
        # Forward the trainer and the internal logging state to every child,
        # so the children's self.log() calls are routed through the parent.
        patched_attributes = ["_current_fx_name", "_current_hook_fx_name", "_results"]
        if isinstance(value, Trainer) or key in patched_attributes:
            if hasattr(self, "_children"):
                for child in self._children:
                    setattr(getattr(self, child), key, value)
        super().__setattr__(key, value)
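
Applied to the Ensemble from the reproduction script above, the suggestion would look roughly like the sketch below. This is an untested adaptation: `_children` is initialized explicitly because the patch relies on it, and the ModuleWrapperBase check is dropped since only plain LightningModule children are involved here.

class Ensemble(pl.LightningModule):
    def __init__(self, modelA, modelB):
        super().__init__()
        self._children = []   # required before any child module is assigned
        self.modelA = modelA  # recorded in self._children by __setattr__ below
        self.modelB = modelB

    def __setattr__(self, key, value):
        # Same idea as the snippet above: remember child LightningModules and
        # mirror the trainer / internal logging state onto them.
        if isinstance(value, pl.LightningModule):
            self._children.append(key)
        patched_attributes = ["_current_fx_name", "_current_hook_fx_name", "_results"]
        if isinstance(value, pl.Trainer) or key in patched_attributes:
            if hasattr(self, "_children"):
                for child in self._children:
                    setattr(getattr(self, child), key, value)
        super().__setattr__(key, value)

    def training_step(self, batch, batch_idx):
        # With the patch in place, the submodules' own self.log calls should be
        # able to reach the logger attached to the ensemble.
        loss_A = self.modelA.training_step(batch, batch_idx)
        loss_B = self.modelB.training_step(batch, batch_idx)
        return loss_A + loss_B

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)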

@awaelchli added the logging label Nov 8, 2021
@tchaton added the priority: 1 label Nov 8, 2021
@tchaton
Copy link
Contributor

tchaton commented Nov 15, 2021

Dear @KinWaiCheuk,

We won't support nesting a LightningModule inside another LightningModule right now, as it might have more complex implications.
I will close this issue for now, since you have a temporary workaround to get unblocked.

Best,
T.C

@tchaton closed this as completed Nov 15, 2021
@simenhu

simenhu commented Apr 5, 2024

(quoting @tchaton's earlier comment and workaround snippet)

What combination of pytorch_lightning, lightning, and flash versions is needed to make this work? It would be nice to include that, to make it clear how to get this fix working.
