Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix references for ResultCollection.extra and improve str and repr #8622

Merged
merged 2 commits into from
Jul 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Saved checkpoints will no longer use the type of a `Callback` as the key to avoid issues with unpickling ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))


-
- Improved string conversion for `ResultCollection` ([#8622](https://github.com/PyTorchLightning/pytorch-lightning/pull/8622))


-
Expand Down Expand Up @@ -74,7 +74,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed references for `ResultCollection.extra` ([#8622](https://github.com/PyTorchLightning/pytorch-lightning/pull/8622))


-
Expand Down
19 changes: 14 additions & 5 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,9 +377,8 @@ def minimize(self) -> Optional[torch.Tensor]:

@minimize.setter
def minimize(self, loss: Optional[torch.Tensor]) -> None:
if loss is not None:
if not isinstance(loss, torch.Tensor):
raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}")
if loss is not None and not isinstance(loss, torch.Tensor):
raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}")
self._minimize = loss

@property
Expand All @@ -388,7 +387,8 @@ def extra(self) -> Dict[str, Any]:
Extras are any keys other than the loss returned by
:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`
"""
return self.get("_extra", {})
self.setdefault("_extra", {})
return self["_extra"]

@extra.setter
def extra(self, extra: Dict[str, Any]) -> None:
Expand Down Expand Up @@ -605,7 +605,16 @@ def cpu(self) -> "ResultCollection":
return self.to(device="cpu")

def __str__(self) -> str:
return f"{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})"
# sample output: `ResultCollection(minimize=1.23, {})`
minimize = f"minimize={self.minimize}, " if self.minimize is not None else ""
# remove empty values
self_str = str({k: v for k, v in self.items() if v})
return f"{self.__class__.__name__}({minimize}{self_str})"

def __repr__(self):
# sample output: `{True, cpu, minimize=tensor(1.23 grad_fn=<SumBackward0>), {'_extra': {}}}`
minimize = f"minimize={repr(self.minimize)}, " if self.minimize is not None else ""
return f"{{{self.training}, {repr(self.device)}, " + minimize + f"{super().__repr__()}}}"

def __getstate__(self, drop_value: bool = True) -> dict:
d = self.__dict__.copy()
Expand Down
23 changes: 22 additions & 1 deletion tests/core/test_metric_result_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,28 @@ def test_result_metric_integration():

assert epoch_log == {"b": cumulative_sum, "a_epoch": cumulative_sum}

result.minimize = torch.tensor(1.0)
result.extra = {}
assert str(result) == (
"ResultCollection(True, cpu, {"
"ResultCollection("
"minimize=1.0, "
carmocca marked this conversation as resolved.
Show resolved Hide resolved
"{"
"'h.a': ResultMetric('a', value=DummyMetric()), "
"'h.b': ResultMetric('b', value=DummyMetric()), "
"'h.c': ResultMetric('c', value=DummyMetric())"
"})"
)
assert repr(result) == (
"{"
"True, "
"device(type='cpu'), "
"minimize=tensor(1.), "
"{'h.a': ResultMetric('a', value=DummyMetric()), "
"'h.b': ResultMetric('b', value=DummyMetric()), "
"'h.c': ResultMetric('c', value=DummyMetric()), "
"'_extra': {}}"
"}"
)


def test_result_collection_simple_loop():
Expand Down Expand Up @@ -332,3 +347,9 @@ def on_save_checkpoint(self, checkpoint) -> None:
gpus=1 if device == "cuda" else 0,
)
trainer.fit(model)


def test_result_collection_extra_reference():
"""Unit-test to check that the `extra` dict reference is properly set."""
rc = ResultCollection(True)
assert rc.extra is rc["_extra"]