
Callback Metric Dict getting overwritten by Log and Progress Bar Dict #1800

Closed
CHANGELOG.md — 4 changes: 4 additions & 0 deletions

@@ -82,6 +82,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))

- Fixed callback metric getting overwritten by progress bar or log metric ([#1800](https://github.com/PyTorchLightning/pytorch-lightning/pull/1800))

- Fixed shuffle argument for distributed sampler ([#2789](https://github.com/PyTorchLightning/pytorch-lightning/pull/2789))

- Fixed logging interval ([#2694](https://github.com/PyTorchLightning/pytorch-lightning/pull/2694))
@@ -96,6 +98,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed shell injection vulnerability in subprocess call ([#2786](https://github.com/PyTorchLightning/pytorch-lightning/pull/2786))

- Fixed callback metric getting overwritten by progress bar or log metric ([#1800](https://github.com/PyTorchLightning/pytorch-lightning/pull/1800))

## [0.8.5] - 2020-07-09

### Added
pytorch_lightning/trainer/logging.py — 15 changes: 12 additions & 3 deletions

@@ -184,9 +184,18 @@ def process_output(self, output, train=False):
# ---------------
hiddens = output.get('hiddens')

-# use every metric passed in as a candidate for callback
-callback_metrics.update(progress_bar_metrics)
Contributor:

Why do we need to remove this? Without it, log metrics and progress bar metrics won't be candidates for the callbacks.

Contributor Author:

In #1727 @kessido reported that a progress bar or log metric overwrites the callback metric of the top-level dict. An example was also given by @kessido, see COLAB.

I wasn't sure whether this needs to be fixed, which is why I asked for more opinions in the issue. Only @awaelchli responded, and he thinks it should be fixed as well.

Since no one started a PR, I opened this one to initiate a discussion. I have several ideas for how this could be fixed and mentioned some in the issue above, but this was the easiest and quickest solution. I didn't want to spend too much effort on a solution that might then be discarded.

Comment:

@williamFalcon @olineumann, in the current update, with that line removed and using the Result object, we cannot save the model checkpoint with a filename like {val_loss}: it produces epoch=1-val_loss=0 because val_loss cannot be resolved, since the filename parameters are based on callback_metrics. Is there another way to connect a Result/TrainResult/EvalResult object with callback_metrics?

-callback_metrics.update(log_metrics)
+# iterate over the log metric and progress bar metric values
+# and add them to the callback metric dict, because every
+# logging or progress bar metric value could be a candidate
+# for early stopping or similar callbacks
+#
+# NOTE: the dict looping order defines a priority:
+# log metric values are added first if not already present,
+# then progress bar values not already present in the callback
+# or log metrics
+for metric_dict in [log_metrics, progress_bar_metrics]:
+    for key in metric_dict.keys():
+        if key not in callback_metrics.keys():
+            callback_metrics[key] = metric_dict[key]

# detach all metrics for callbacks to prevent memory leaks
# no .item() because it will slow things down
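The merge priority established by the loop above can be sketched standalone. This is a minimal illustration with hypothetical metric values (the function name `merge_callback_metrics` and the sample dicts are invented for this sketch; only the variable names mirror `process_output`): keys already in the callback dict win, then log metrics, then progress bar metrics.

```python
def merge_callback_metrics(callback_metrics, log_metrics, progress_bar_metrics):
    """Merge metric dicts without overwriting existing callback metrics."""
    merged = dict(callback_metrics)
    # looping order defines priority: log metrics first, then progress bar
    for metric_dict in [log_metrics, progress_bar_metrics]:
        for key, value in metric_dict.items():
            merged.setdefault(key, value)  # keep existing value if key present
    return merged

callback = {'val_loss': 0.25}
logs = {'val_loss': 999.0, 'lr': 1e-3}   # would previously have overwritten val_loss
pbar = {'val_loss': 123.0, 'acc': 0.9}

merged = merge_callback_metrics(callback, logs, pbar)
print(merged)  # {'val_loss': 0.25, 'lr': 0.001, 'acc': 0.9}
```

With the old `callback_metrics.update(...)` calls, `val_loss` would have ended up as `123.0`; with the guarded merge, the original callback value survives, which is what EarlyStopping and ModelCheckpoint rely on.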