-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Closed
Labels
3rd party (Related to a 3rd-party), bug (Something isn't working), lightningcli (pl.cli.LightningCLI)
Milestone
Description
🐛 Bug
To Reproduce
Save the following script as `boring.py`:
import os
from typing import Callable
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from pytorch_lightning.utilities.cli import LightningCLI
from pytorch_lightning.profiler import PyTorchProfiler
from jsonargparse import Namespace
class RandomDataset(Dataset):
    """Map-style dataset of ``length`` random vectors with ``size`` features each.

    The data tensor is drawn once with ``torch.randn`` in ``__init__``, so
    indexing the same position repeatedly returns the same sample.
    """

    def __init__(self, size: int, length: int):
        """Create ``length`` samples, each a vector of ``size`` standard-normal values."""
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        """Return row ``index`` of the pre-generated data tensor."""
        return self.data[index]

    def __len__(self) -> int:
        """Return the number of samples given at construction."""
        return self.len
class BoringModel(LightningModule):
    """Minimal LightningModule: a single ``Linear(32, 2)`` layer.

    Every step uses the sum of the layer's output as the loss and logs it.

    Args:
        x: Not used in the body. Presumably present so LightningCLI exposes
            an ``int`` option; confirm intent.
        a_func: Not used in the body. Its default is ``torch.softmax``.
            Presumably present to exercise LightningCLI's handling of a
            ``Callable`` argument; confirm intent.
        **mykwargs: Extra keyword arguments. The value under ``"b"`` is
            printed, or ``None`` if it was not supplied.
    """

    def __init__(self, x: int = 10, a_func: Callable = torch.softmax, **mykwargs):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # Indexing mykwargs['b'] directly raised KeyError whenever the model
        # was built without a ``b`` keyword; fall back to None instead.
        print('b', mykwargs.get('b'))

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        """Log the summed output as ``train_loss`` and return it for backprop."""
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Log the summed output as ``valid_loss``."""
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        """Log the summed output as ``test_loss``."""
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        """Return plain SGD (lr=0.1) over the linear layer's parameters."""
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
class MyDataModule(LightningDataModule):
    """DataModule serving identical random-data loaders for train, val and test.

    Each loader wraps a fresh ``RandomDataset(32, 64)`` with ``batch_size=2``.

    The transform and ``dims`` arguments are passed straight to
    ``LightningDataModule.__init__``. NOTE(review): these arguments appear to
    be deprecated in Lightning 1.5; confirm against the installed version.
    """

    def __init__(self, train_transforms=None, val_transforms=None, test_transforms=None, dims=None):
        super().__init__(
            train_transforms=train_transforms,
            val_transforms=val_transforms,
            test_transforms=test_transforms,
            dims=dims,
        )

    def train_dataloader(self) -> DataLoader:
        """Return a loader over 64 random 32-feature samples, batch size 2."""
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def val_dataloader(self) -> DataLoader:
        """Return a loader over 64 random 32-feature samples, batch size 2."""
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def test_dataloader(self) -> DataLoader:
        """Return a loader over 64 random 32-feature samples, batch size 2."""
        return DataLoader(RandomDataset(32, 64), batch_size=2)
if __name__ == "__main__":
    # Build the command-line interface from the signatures of BoringModel and
    # MyDataModule. The issue runs this with ``--print_config`` and expects
    # the assembled config to be printed.
    LightningCLI(
        BoringModel,
        MyDataModule,
        seed_everything_default=None,
        save_config_overwrite=True,
    )
Then run: `python boring.py --print_config`
Expected behavior
The config is printed successfully.
Environment
- CUDA:
- GPU:
- NVIDIA A100-SXM4-40GB
- NVIDIA A100-SXM4-40GB
- NVIDIA A100-SXM4-40GB
- NVIDIA A100-SXM4-40GB
- NVIDIA A100-SXM4-40GB
- NVIDIA A100-SXM4-40GB
- NVIDIA A100-SXM4-40GB
- NVIDIA A100-SXM4-40GB
- available: True
- version: 11.3
- Packages:
- numpy: 1.21.2
- pyTorch_debug: False
- pyTorch_version: 1.10.2
- pytorch-lightning: 1.5.10
- tqdm: 4.62.3
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.9.7
- version: #1 SMP Mon Oct 19 16:18:59 UTC 2020
Additional context
Metadata
Metadata
Assignees
Labels
3rd party (Related to a 3rd-party), bug (Something isn't working), lightningcli (pl.cli.LightningCLI)