Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Dec 24, 2022
1 parent 5927185 commit 292fa2c
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
16 changes: 8 additions & 8 deletions src/pytorch_lightning/callbacks/progress/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def get_metrics(self, trainer, model):

def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
r"""
Returns several standard metrics displayed in the progress bar, including the average loss value,
Returns several standard metrics displayed in the progress bar, including the latest loss value,
split index of BPTT (if used) and the version of the experiment when using a logger.
.. code-block::
Expand All @@ -268,16 +268,16 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule")
Dictionary with the standard metrics to be displayed in the progress bar.
"""
# call .item() only once but store elements without graphs
running_train_loss = trainer.fit_loop.running_loss.mean()
avg_training_loss = None
if running_train_loss is not None:
avg_training_loss = running_train_loss.cpu().item()
loss_metric = trainer.fit_loop._results.get("training_step.train_loss")
loss_value = None
if loss_metric is not None:
loss_value = loss_metric.value.cpu().item()
elif pl_module.automatic_optimization:
avg_training_loss = float("NaN")
loss_value = float("NaN")

items_dict: Dict[str, Union[int, str]] = {}
if avg_training_loss is not None:
items_dict["loss"] = f"{avg_training_loss:.3g}"
if loss_value is not None:
items_dict["loss"] = f"{loss_value:.3g}"

if trainer.loggers:
version = _version(trainer.loggers)
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.supporters import CombinedLoader, TensorRunningAccum
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import (
AbstractDataFetcher,
Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_lightning/tuner/lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,8 @@ def on_train_batch_end(
if self.progress_bar:
self.progress_bar.update()

# TODO: How to access the loss in the tuner?
loss_tensor = trainer.fit_loop.running_loss.last()
# TODO: should we read it from the local variable "outputs"?
loss_tensor = trainer.fit_loop._results["training_step.train_loss"].value
assert loss_tensor is not None
current_loss = loss_tensor.item()
current_step = trainer.global_step
Expand Down

0 comments on commit 292fa2c

Please sign in to comment.