fast_dev_run fail on log_hyperparams #6395
Comments
Minimal repro example:

import torch
from torch.utils.data import Dataset
from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        return output.sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def on_train_start(self):
        if self.logger:
            self.logger.log_hyperparams(self.hparams, {"x": 0})


if __name__ == '__main__':
    train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
    model = BoringModel()
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model, train_data)
Labels
🐛 Bug
Running with fast_dev_run=True fails with:

TypeError: log_hyperparams() takes 2 positional arguments but 3 were given
To Reproduce
When using the following, where self.hp_metrics is a list of strings, each naming a metric that is being logged, for example "accuracy/val".
Expected behavior
I assume the unit test is wrong, since the documentation says that self.logger.log_hyperparams takes one positional argument plus a dictionary of metrics. The code runs fine without fast_dev_run=True, and everything is logged correctly to TensorBoard.
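The failure mode, and one possible workaround, can be illustrated with a minimal sketch. The stub classes below are assumptions for illustration, not the real pytorch_lightning API: one logger accepts (params, metrics) as the documentation describes, while the other accepts only params, mirroring the logger substituted during fast_dev_run in this report. The workaround simply falls back to the single-argument call:

```python
class FullLogger:
    # Accepts params plus an optional metrics dict, as documented.
    def log_hyperparams(self, params, metrics=None):
        return (params, metrics)


class StubDummyLogger:
    # Accepts only params, like the logger used under fast_dev_run
    # (an assumption for this sketch, based on the TypeError above).
    def log_hyperparams(self, params):
        return (params, None)


def safe_log_hyperparams(logger, params, metrics):
    # Hypothetical workaround: try the two-argument form first,
    # fall back to the single-argument form if the logger rejects it.
    try:
        return logger.log_hyperparams(params, metrics)
    except TypeError:
        return logger.log_hyperparams(params)


safe_log_hyperparams(FullLogger(), {"lr": 0.1}, {"x": 0})
safe_log_hyperparams(StubDummyLogger(), {"lr": 0.1}, {"x": 0})
```

With this guard in on_train_start, both the normal run and a fast_dev_run smoke test would complete, at the cost of silently dropping the initial metric values in the latter case.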
Environment
pytorch_lightning 1.2.2