Skip to content

Cannot use torch.softmax as a Callable parameter in CLI #13092

@quancs

Description

@quancs

🐛 Bug

To Reproduce

boring.py

import os
from typing import Callable

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from pytorch_lightning.utilities.cli import LightningCLI
from pytorch_lightning.profiler import PyTorchProfiler

from jsonargparse import Namespace


class RandomDataset(Dataset):
    """Dataset of ``length`` random rows, each with ``size`` features."""

    def __init__(self, size, length):
        """Draw the backing tensor once and remember the row count."""
        # Shape is (length, size): one row per dataset item.
        samples = torch.randn(length, size)
        self.data = samples
        self.len = length

    def __len__(self):
        """Return the row count given at construction."""
        return self.len

    def __getitem__(self, index):
        """Return whatever ``self.data`` yields for ``index``."""
        return self.data[index]


class BoringModel(LightningModule):
    """Minimal LightningModule used by the reproduction script.

    The network is a single ``Linear(32, 2)`` layer, and every step uses the
    sum of the layer output as its loss.
    """

    def __init__(self, x: int = 10, a_func: Callable = torch.softmax, **mykwargs):
        """Build the layer and print the optional ``b`` keyword argument.

        Args:
            x: Not referenced by this class; present only in the signature.
            a_func: Callable defaulting to ``torch.softmax``; also not
                referenced by this class.
            **mykwargs: Extra keyword arguments. Only ``b`` is read.
        """
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # ``b`` is optional: indexing ``mykwargs['b']`` raised KeyError whenever
        # the model was constructed without it, so fall back to None instead.
        print('b', mykwargs.get('b'))

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        """Log and return the summed model output as the training loss."""
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Log the summed model output under ``valid_loss``."""
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        """Log the summed model output under ``test_loss``."""
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        """Return an SGD optimizer (lr=0.1) over the layer's parameters."""
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


class MyDataModule(LightningDataModule):
    """Data module whose train, val and test loaders all serve random data."""

    def __init__(self, train_transforms=None, val_transforms=None, test_transforms=None, dims=None):
        """Pass every argument through, unchanged, to the parent initializer."""
        super().__init__(
            train_transforms=train_transforms,
            val_transforms=val_transforms,
            test_transforms=test_transforms,
            dims=dims,
        )

    def train_dataloader(self) -> DataLoader:
        """Return a batch-size-2 loader over a new ``RandomDataset(32, 64)``."""
        dataset = RandomDataset(32, 64)
        return DataLoader(dataset, batch_size=2)

    def val_dataloader(self) -> DataLoader:
        """Return a batch-size-2 loader over a new ``RandomDataset(32, 64)``."""
        dataset = RandomDataset(32, 64)
        return DataLoader(dataset, batch_size=2)

    def test_dataloader(self) -> DataLoader:
        """Return a batch-size-2 loader over a new ``RandomDataset(32, 64)``."""
        dataset = RandomDataset(32, 64)
        return DataLoader(dataset, batch_size=2)


if __name__ == "__main__":
    # Script entry point: hand the model and data module classes to LightningCLI.
    LightningCLI(
        BoringModel,
        MyDataModule,
        seed_everything_default=None,
        save_config_overwrite=True,
    )
python boring.py --print_config

Expected behavior

The config is printed successfully.

Environment

  • CUDA:
    - GPU:
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - available: True
    - version: 11.3
  • Packages:
    - numpy: 1.21.2
    - pyTorch_debug: False
    - pyTorch_version: 1.10.2
    - pytorch-lightning: 1.5.10
    - tqdm: 4.62.3
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.9.7
    - version: #1 SMP Mon Oct 19 16:18:59 UTC 2020

Additional context

cc @carmocca @mauvilsa

Metadata

Metadata

Assignees

Labels

3rd party (Related to a 3rd-party) · bug (Something isn't working) · lightningcli (pl.cli.LightningCLI)

Type

No type

Projects

No projects

Milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions