LightningModule should support Dataclass #12506

Closed
dinhanhx opened this issue Mar 29, 2022 · 7 comments
Labels
feature (Is an improvement or enhancement) · won't fix (This will not be worked on)

Comments

@dinhanhx

dinhanhx commented Mar 29, 2022

Please read this comment #12506 (comment)

🚀 Feature

In issue #8272 we can see that LightningDataModule is compatible with dataclasses. LightningModule should be compatible as well.

Motivation

  • reduce the number of lines of code (when assigning attributes)
  • be able to export the hyperparameter configuration

Code example

```python
from dataclasses import dataclass

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn


@dataclass
class SimpleModel(pl.LightningModule):
    num_class: int = 10
    pixel_side_length: int = 28
    num_heads: int = 28

    def __init__(self):
        super().__init__()  # `self` must not be passed a second time here
        # Layers
        self.attn = nn.MultiheadAttention(embed_dim=self.pixel_side_length, num_heads=self.num_heads, batch_first=True)
        self.norm = nn.LayerNorm(self.pixel_side_length)

        # A very tight ffn
        self.ffn = nn.Sequential(nn.Linear(self.pixel_side_length*self.pixel_side_length, self.pixel_side_length),
                                 nn.ReLU(),
                                 nn.Linear(self.pixel_side_length, self.num_class))
        # note: F.cross_entropy below expects unnormalized logits, not softmax outputs
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = torch.squeeze(x, 1)
        attn_out, _ = self.attn(x, x, x)
        norm_out = self.norm(x + attn_out)

        ffn_out = self.ffn(norm_out.reshape(-1, self.pixel_side_length*self.pixel_side_length))
        return self.softmax(ffn_out)

    def training_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        loss = F.cross_entropy(out, y)
        acc = (out.argmax(dim=-1) == y).float().mean()
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        loss = F.cross_entropy(out, y)
        acc = (out.argmax(dim=-1) == y).float().mean()
        return acc

    def validation_epoch_end(self, outputs):
        self.log("val_acc", torch.stack(outputs).mean())

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)
```


cc @Borda

@dinhanhx dinhanhx added the needs triage Waiting to be triaged by maintainers label Mar 29, 2022
@dinhanhx
Author

dinhanhx commented Mar 29, 2022

🐛 Bug

When I add @dataclass to a LightningModule and then build the layers in __post_init__, the model becomes unhashable:

```
Global seed set to 0
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-12-2363c7c2ca1d> in <module>()
      2 simple_model = SimpleModel()
      3 trainer = Trainer(max_epochs=1, gpus=1)
----> 4 trainer.fit(simple_model, mnist_train, mnist_val)

10 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in named_modules(self, memo, prefix, remove_duplicate)
   1685         if memo is None:
   1686             memo = set()
-> 1687         if self not in memo:
   1688             if remove_duplicate:
   1689                 memo.add(self)

TypeError: unhashable type: 'SimpleModel'
```

To Reproduce

https://colab.research.google.com/drive/1aqHDBUaEf7DWKWIo86uoTtxPqCybLLb9?usp=sharing

```python
from dataclasses import dataclass

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn


@dataclass
class SimpleModel(pl.LightningModule):
    num_class: int = 10
    pixel_side_length: int = 28
    num_heads: int = 28

    def __post_init__(self):
        # nn.Module state must be initialized before any sub-module is assigned
        pl.LightningModule.__init__(self)
        # Layers
        self.attn = nn.MultiheadAttention(embed_dim=self.pixel_side_length, num_heads=self.num_heads, batch_first=True)
        self.norm = nn.LayerNorm(self.pixel_side_length)

        # A very tight ffn
        self.ffn = nn.Sequential(nn.Linear(self.pixel_side_length*self.pixel_side_length, self.pixel_side_length),
                                 nn.ReLU(),
                                 nn.Linear(self.pixel_side_length, self.num_class))
        # note: F.cross_entropy below expects unnormalized logits, not softmax outputs
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = torch.squeeze(x, 1)
        attn_out, _ = self.attn(x, x, x)
        norm_out = self.norm(x + attn_out)

        ffn_out = self.ffn(norm_out.reshape(-1, self.pixel_side_length*self.pixel_side_length))
        return self.softmax(ffn_out)

    def training_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        loss = F.cross_entropy(out, y)
        acc = (out.argmax(dim=-1) == y).float().mean()
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        loss = F.cross_entropy(out, y)
        acc = (out.argmax(dim=-1) == y).float().mean()
        return acc

    def validation_epoch_end(self, outputs):
        self.log("val_acc", torch.stack(outputs).mean())

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)
```

Environment

  • CUDA:
    • GPU:
      • Tesla K80
    • available: True
    • version: 11.1
  • Packages:
    • numpy: 1.21.5
    • pyTorch_debug: False
    • pyTorch_version: 1.10.0+cu111
    • pytorch-lightning: 1.5.10
    • tqdm: 4.63.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.7.13
    • version: #1 SMP Tue Dec 7 09:58:10 PST 2021


@dinhanhx
Author

I also added my attempt at using a dataclass with a LightningModule, in the form of a bug report.

@dinhanhx
Author

dinhanhx commented Mar 29, 2022

There is this post on PyTorch. When using @dataclass, __hash__ is overridden (with the default eq=True and frozen=False it is set to None), hence the program yields TypeError: unhashable type: 'SimpleModel' from PyTorch code, not from Lightning.

So should I close this issue? However, I think Lightning could support this feature to be consistent with LightningDataModule.
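The mechanism described above can be reproduced with the standard library alone. This is a minimal sketch, not code from the linked post: Config is a hypothetical stand-in for SimpleModel and is_hashable is a hypothetical helper that mimics the `self not in memo` set lookup in named_modules. With the defaults eq=True and frozen=False, @dataclass generates __eq__ and sets __hash__ to None, so that lookup raises TypeError.

```python
from dataclasses import dataclass


@dataclass
class Config:
    """Hypothetical stand-in for SimpleModel (no torch needed)."""

    num_class: int = 10
    num_heads: int = 28


def is_hashable(obj) -> bool:
    """Hypothetical helper mimicking the `self not in memo` check in nn.Module.named_modules."""
    memo = set()
    try:
        _ = obj in memo  # a set membership test hashes obj first
    except TypeError:
        return False
    return True


# @dataclass (eq=True, frozen=False) writes __hash__ = None into the class.
print(Config.__hash__)        # None
print(is_hashable(Config()))  # False
```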

@justusschock
Member

justusschock commented Mar 30, 2022

Hi, have you tried @dataclass(eq=False)?

Other than that, I think it's very unlikely that we will support this if it overrides __hash__, since both PyTorch and we (explicitly, and implicitly through PyTorch) require the model to be hashable, and we cannot tell which parts of PyTorch would be affected and would have to be patched. So IMO this is too dangerous.
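For comparison, a minimal stdlib sketch of the eq=False suggestion, under the same toy assumptions: HashableConfig and visit_once are hypothetical names, and __post_init__ stands in for the layer construction in SimpleModel. With eq=False, @dataclass still generates __init__ and calls __post_init__, but adds neither __eq__ nor __hash__, so instances keep the identity-based hashing they inherit.

```python
from dataclasses import dataclass


@dataclass(eq=False)
class HashableConfig:
    """Hypothetical stand-in for SimpleModel declared with @dataclass(eq=False)."""

    num_class: int = 10
    num_heads: int = 28

    def __post_init__(self):
        # Runs right after the generated __init__ has assigned the fields,
        # which is where the issue builds its layers.
        self.hidden = self.num_class * self.num_heads


def visit_once(*objs):
    """Hypothetical helper: deduplicates objects by hash, like named_modules' memo set."""
    memo = set()
    for obj in objs:
        if obj not in memo:
            memo.add(obj)
    return memo


cfg = HashableConfig(num_class=5)
print(cfg.hidden)                           # 140
print(len(visit_once(cfg, cfg)))            # 1: the same object is seen once
print(cfg == HashableConfig(num_class=5))   # False: equality stays identity-based
```

One side effect: two instances with identical field values no longer compare equal, which matches how nn.Module instances already behave.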

@dinhanhx
Author

@justusschock I have tried it with @dataclass(eq=False). The model trains as normal.

I understand what you are saying.

@dinhanhx
Author

dinhanhx commented Mar 30, 2022

@justusschock I would like to propose something else. It's not off-topic; it's still about LightningModule working with dataclasses. Users could write the config for their LightningModule subclass as follows:

```python
from dataclasses import dataclass
from dataclass_wizard import YAMLWizard

@dataclass
class HyperConfig(YAMLWizard):
    feature_in: int = 10
    feature_out: int = 2
```

Then they can define their model:

```python
import torch
from pytorch_lightning import LightningModule


class BoringModel(LightningModule):
    def __init__(self, hyper_config: HyperConfig = HyperConfig()):
        super().__init__()
        self.hyper_config = hyper_config
        self.layer = torch.nn.Linear(hyper_config.feature_in, hyper_config.feature_out)
        # proposed new method (does not exist yet), named so it does not conflict with self.save_hyperparameters()
        self.save_hyperconfig()
```

Defined like this, code completion (IntelliSense) in Jupyter/VS Code knows which attributes self.hyper_config has, unlike self.hparams.

self.save_hyperconfig() should use YAMLWizard to handle nested dataclasses and export them to a YAML file, i.e. behave like self.save_hyperparameters() does when saving to a YAML file.

So, what do you think?
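A rough sketch of what the proposed save_hyperconfig() could do with nested dataclasses, using only the standard library. JSON stands in for YAMLWizard's YAML output, and OptimConfig, to_json and from_json are hypothetical names, not part of Lightning or dataclass_wizard. dataclasses.asdict recurses into nested dataclasses, so exporting is a single call; loading has to rebuild the nested section explicitly.

```python
import json
from dataclasses import asdict, dataclass, field


@dataclass
class OptimConfig:  # hypothetical nested section
    lr: float = 1e-3


@dataclass
class HyperConfig:
    feature_in: int = 10
    feature_out: int = 2
    # Nested dataclass defaults need default_factory, not a shared instance.
    optim: OptimConfig = field(default_factory=OptimConfig)

    def to_json(self) -> str:
        # asdict() converts nested dataclasses into plain dicts recursively.
        return json.dumps(asdict(self), indent=2)

    @classmethod
    def from_json(cls, text: str) -> "HyperConfig":
        data = json.loads(text)
        data["optim"] = OptimConfig(**data["optim"])  # rebuild the nested part
        return cls(**data)


cfg = HyperConfig(feature_in=16)
print(HyperConfig.from_json(cfg.to_json()) == cfg)  # True
```

Separately, the default hyper_config: HyperConfig = HyperConfig() in the BoringModel signature is evaluated once, so every model built without an argument shares one mutable HyperConfig instance; a None default followed by `if hyper_config is None: hyper_config = HyperConfig()` avoids that.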

However, I think it would be better if there were a way to make self.save_hyperparameters() work with dataclasses.

When I save hyper_config with self.save_hyperparameters(hyper_config), the YAML file created by Lightning looks like this:

```yaml
hyper_config: !!python/object:__main__.HyperConfig
  feature_in: 10
  feature_out: 2
```
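The !!python/object tag appears because the dataclass instance itself is stored as a hyperparameter. One possible workaround, sketched here as an assumption and not tested against 1.5.10, is to convert it with dataclasses.asdict first: save_hyperparameters accepts a dict, and a dict of built-in types should be written as a plain YAML mapping. The runnable part below only covers the conversion and rebuilding the typed config from the saved values; the Lightning call is shown as a comment.

```python
from dataclasses import asdict, dataclass


@dataclass
class HyperConfig:
    feature_in: int = 10
    feature_out: int = 2


hyper_config = HyperConfig(feature_in=16)

# Plain dict of built-in types: nothing here needs a !!python/object tag.
hparams = asdict(hyper_config)
print(hparams)  # {'feature_in': 16, 'feature_out': 2}

# Inside a LightningModule.__init__ (not executed here):
#     self.save_hyperparameters(asdict(hyper_config))
# The typed config can later be rebuilt from the saved values:
restored = HyperConfig(**hparams)
print(restored == hyper_config)  # True
```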

@dinhanhx dinhanhx changed the title LightningModule should support Dataclass like LightningDataModule LightningModule should support Dataclass Mar 30, 2022
@justusschock justusschock added feature Is an improvement or enhancement and removed needs triage Waiting to be triaged by maintainers labels Mar 30, 2022
@stale

stale bot commented May 2, 2022

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix This will not be worked on label May 2, 2022
@stale stale bot closed this as completed Jun 6, 2022