Skip to content

Commit

Permalink
Better errors for logging corner cases (#13164)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Jun 28, 2022
1 parent a475010 commit b1e38bf
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 12 deletions.
6 changes: 6 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Include a version suffix for new "last" checkpoints of later runs in the same directory ([#12902](https://github.com/PyTorchLightning/pytorch-lightning/pull/12902))


- Show a better error message when a Metric that does not return a Tensor is logged ([#13164](https://github.com/PyTorchLightning/pytorch-lightning/pull/13164))


- Added missing `predict_dataset` argument in `LightningDataModule.from_datasets` to create predict dataloaders ([#12942](https://github.com/PyTorchLightning/pytorch-lightning/pull/12942))


Expand Down Expand Up @@ -123,6 +126,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `DataLoader` instantiated inside a `*_dataloader` hook will not set the passed arguments as attributes anymore ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981))


- When a multi-element tensor is logged, an error is now raised instead of silently taking the mean of all elements ([#13164](https://github.com/PyTorchLightning/pytorch-lightning/pull/13164))


- The `WandbLogger` will now use the run name in the logs folder if it is provided, and otherwise the project name ([#12604](https://github.com/PyTorchLightning/pytorch-lightning/pull/12604))


Expand Down
12 changes: 10 additions & 2 deletions src/pytorch_lightning/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ def log(
)

value = apply_to_collection(value, numbers.Number, self.__to_tensor)
apply_to_collection(value, torch.Tensor, self.__check_numel_1, name)

if self.trainer._logger_connector.should_reset_tensors(self._current_fx_name):
# if we started a new epoch (running its first batch) the hook name has changed
Expand Down Expand Up @@ -518,11 +519,10 @@ def log_dict(
)

@staticmethod
def __check_not_nested(value: dict, name: str) -> dict:
def __check_not_nested(value: dict, name: str) -> None:
# self-imposed restriction. for simplicity
if any(isinstance(v, dict) for v in value.values()):
raise ValueError(f"`self.log({name}, {value})` was called, but nested dictionaries cannot be logged")
return value

@staticmethod
def __check_allowed(v: Any, name: str, value: Any) -> None:
Expand All @@ -531,6 +531,14 @@ def __check_allowed(v: Any, name: str, value: Any) -> None:
def __to_tensor(self, value: numbers.Number) -> Tensor:
    """Wrap a plain number in a tensor created on ``self.device``."""
    target_device = self.device
    return torch.tensor(value, device=target_device)

@staticmethod
def __check_numel_1(value: torch.Tensor, name: str) -> None:
    """Ensure that a tensor passed for logging holds exactly one element.

    Args:
        value: the tensor being logged.
        name: the key it is logged under; only used to build the error message.

    Raises:
        ValueError: if ``value`` does not contain exactly one element.
    """
    if value.numel() == 1:
        return
    logged_call = f"`self.log({name}, {value})`"
    suggestion = f"`self.log({name}, {value}.mean())`"
    raise ValueError(
        f"{logged_call} was called, but the tensor must have a single element." f" You can try doing {suggestion}"
    )

def log_grad_norm(self, grad_norm_dict: Dict[str, float]) -> None:
"""Override this method to change the default behaviour of ``log_grad_norm``.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,12 @@ def update(self, value: _IN_METRIC, batch_size: int) -> None: # type: ignore[ov
# perform accumulation with reduction
if self.meta.is_mean_reduction:
# do not use `+=` as it doesn't do type promotion
self.value = self.value + value.mean() * batch_size
self.value = self.value + value * batch_size
self.cumulated_batch_size = self.cumulated_batch_size + batch_size
elif self.meta.is_max_reduction or self.meta.is_min_reduction:
self.value = self.meta.reduce_fx(self.value, value.mean())
self.value = self.meta.reduce_fx(self.value, value)
elif self.meta.is_sum_reduction:
self.value = self.value + value.mean()
self.value = self.value + value
else:
value = cast(Metric, value)
self.value = value
Expand Down Expand Up @@ -528,8 +528,14 @@ def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
result_metric.compute()
result_metric.meta.sync.should = should
cache = result_metric._computed
if cache is not None and not result_metric.meta.enable_graph:
return cache.detach()
if cache is not None:
if not isinstance(cache, torch.Tensor):
raise ValueError(
f"The `.compute()` return of the metric logged as {result_metric.meta.name!r} must be a tensor."
f" Found {cache}"
)
if not result_metric.meta.enable_graph:
return cache.detach()
return cache

def valid_items(self) -> Generator:
Expand Down
30 changes: 30 additions & 0 deletions tests/tests_pytorch/core/test_metric_result_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,3 +629,33 @@ def test_result_metric_max_min(reduce_fx, expected):
rm = _ResultMetric(metadata, is_tensor=True)
rm.update(torch.tensor(expected), 1)
assert rm.compute() == expected


def test_compute_not_a_tensor_raises():
    """Logging a ``Metric`` whose ``compute()`` does not return a tensor must raise a ``ValueError``."""

    class TupleMetric(Metric):
        # `compute` deliberately returns a tuple of tensors rather than a single tensor
        def update(self):
            pass

        def compute(self):
            return torch.tensor(1.0), torch.tensor(2.0)

    class LoggingModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.metric = TupleMetric()

        def on_train_start(self):
            self.log("foo", self.metric)

    model = LoggingModel()
    # a single training batch, no validation, and every optional feature switched off
    trainer_kwargs = dict(
        limit_train_batches=1,
        limit_val_batches=0,
        max_epochs=1,
        enable_progress_bar=False,
        enable_checkpointing=False,
        logger=False,
        enable_model_summary=False,
    )
    trainer = Trainer(**trainer_kwargs)
    expected_message = r"compute\(\)` return of.*foo' must be a tensor"
    with pytest.raises(ValueError, match=expected_message):
        trainer.fit(model)
23 changes: 18 additions & 5 deletions tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,16 @@ class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
self.log("foo/dataloader_idx_0", -1)

trainer = Trainer(default_root_dir=tmpdir)
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=0,
max_epochs=1,
enable_progress_bar=False,
enable_checkpointing=False,
logger=False,
enable_model_summary=False,
)
model = TestModel()
with pytest.raises(MisconfigurationException, match="`self.log` with the key `foo/dataloader_idx_0`"):
trainer.fit(model)
Expand All @@ -640,7 +649,6 @@ class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
self.log("foo", Accuracy())

trainer = Trainer(default_root_dir=tmpdir)
model = TestModel()
with pytest.raises(MisconfigurationException, match="fix this by setting an attribute for the metric in your"):
trainer.fit(model)
Expand All @@ -653,7 +661,6 @@ def __init__(self):
def training_step(self, batch, batch_idx):
self.log("foo", Accuracy())

trainer = Trainer(default_root_dir=tmpdir)
model = TestModel()
with pytest.raises(
MisconfigurationException,
Expand All @@ -667,7 +674,6 @@ def training_step(self, *args):
self.log("foo", -1, prog_bar=True)
return super().training_step(*args)

trainer = Trainer(default_root_dir=tmpdir)
model = TestModel()
with pytest.raises(MisconfigurationException, match=r"self.log\(foo, ...\)` twice in `training_step`"):
trainer.fit(model)
Expand All @@ -677,11 +683,18 @@ def training_step(self, *args):
self.log("foo", -1, reduce_fx=torch.argmax)
return super().training_step(*args)

trainer = Trainer(default_root_dir=tmpdir)
model = TestModel()
with pytest.raises(MisconfigurationException, match=r"reduce_fx={min,max,mean,sum}\)` are supported"):
trainer.fit(model)

class TestModel(BoringModel):
def on_train_start(self):
self.log("foo", torch.tensor([1.0, 2.0]))

model = TestModel()
with pytest.raises(ValueError, match="tensor must have a single element"):
trainer.fit(model)


def test_sanity_metrics_are_reset(tmpdir):
class TestModel(BoringModel):
Expand Down

0 comments on commit b1e38bf

Please sign in to comment.