Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor codebase to use trainer.loggers over trainer.logger when needed #11920

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ def on_train_epoch_end(self):
# log sampled images
sample_imgs = self(z)
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image("generated_images", grid, self.current_epoch)
for logger in self.loggers:
logger.experiment.add_image("generated_images", grid, self.current_epoch)


def main(args: Namespace) -> None:
Expand Down
20 changes: 11 additions & 9 deletions pytorch_lightning/callbacks/device_stats_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class DeviceStatsMonitor(Callback):
"""

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
if not trainer.logger:
if not trainer.loggers:
akashkw marked this conversation as resolved.
Show resolved Hide resolved
raise MisconfigurationException("Cannot use DeviceStatsMonitor callback with Trainer that has no logger.")

def on_train_batch_start(
Expand All @@ -55,17 +55,18 @@ def on_train_batch_start(
batch_idx: int,
unused: Optional[int] = 0,
) -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")

if not trainer.logger_connector.should_update_logs:
return

device = trainer.strategy.root_device
device_stats = trainer.accelerator.get_device_stats(device)
separator = trainer.logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
for logger in trainer.loggers:
separator = logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

def on_train_batch_end(
self,
Expand All @@ -76,17 +77,18 @@ def on_train_batch_end(
batch_idx: int,
unused: Optional[int] = 0,
) -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")

if not trainer.logger_connector.should_update_logs:
return

device = trainer.strategy.root_device
device_stats = trainer.accelerator.get_device_stats(device)
separator = trainer.logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
for logger in trainer.loggers:
separator = logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
logger.log_metrics(prefixed_device_stats, step=trainer.global_step)


def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
Expand Down
10 changes: 5 additions & 5 deletions pytorch_lightning/callbacks/gpu_stats_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(
self._gpu_ids: List[str] = [] # will be assigned later in setup()

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use GPUStatsMonitor callback with Trainer that has no logger.")

if trainer.strategy.root_device.type != "cuda":
Expand Down Expand Up @@ -161,8 +161,8 @@ def on_train_batch_start(
# First log at beginning of second step
logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000

assert trainer.logger is not None
trainer.logger.log_metrics(logs, step=trainer.global_step)
for logger in trainer.loggers:
logger.log_metrics(logs, step=trainer.global_step)

@rank_zero_only
def on_train_batch_end(
Expand All @@ -186,8 +186,8 @@ def on_train_batch_end(
if self._log_stats.intra_step_time and self._snap_intra_step_time:
logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000

assert trainer.logger is not None
trainer.logger.log_metrics(logs, step=trainer.global_step)
for logger in trainer.loggers:
logger.log_metrics(logs, step=trainer.global_step)

@staticmethod
def _get_gpu_ids(device_ids: List[int]) -> List[str]:
Expand Down
10 changes: 5 additions & 5 deletions pytorch_lightning/callbacks/lr_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> No
MisconfigurationException:
If ``Trainer`` has no ``logger``.
"""
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException(
"Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger."
)
Expand Down Expand Up @@ -149,7 +149,6 @@ def _check_no_key(key: str) -> bool:
self.last_momentum_values = {name + "-momentum": None for name in names_flatten}

def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
assert trainer.logger is not None
if not trainer.logger_connector.should_update_logs:
return

Expand All @@ -158,16 +157,17 @@ def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any)
latest_stat = self._extract_stats(trainer, interval)

if latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
for logger in trainer.loggers:
logger.log_metrics(latest_stat, step=trainer.global_step)

def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
assert trainer.logger is not None
if self.logging_interval != "step":
interval = "epoch" if self.logging_interval is None else "any"
latest_stat = self._extract_stats(trainer, interval)

if latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
for logger in trainer.loggers:
logger.log_metrics(latest_stat, step=trainer.global_step)

def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]:
latest_stat = {}
Expand Down
24 changes: 13 additions & 11 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.logger import _name, _version
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache
Expand Down Expand Up @@ -379,8 +380,9 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None:
self._save_last_checkpoint(trainer, monitor_candidates)

# notify loggers
if trainer.is_global_zero and trainer.logger:
trainer.logger.after_save_checkpoint(proxy(self))
if trainer.is_global_zero:
for logger in trainer.loggers:
logger.after_save_checkpoint(proxy(self))

def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
from pytorch_lightning.trainer.states import TrainerFn
Expand Down Expand Up @@ -572,20 +574,20 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
"""
if self.dirpath is not None:
return # short circuit

if trainer.logger is not None:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
if trainer.loggers:
akashkw marked this conversation as resolved.
Show resolved Hide resolved
if trainer.weights_save_path != trainer.default_root_dir:
# the user has changed weights_save_path, it overrides anything
save_dir = trainer.weights_save_path
else:
elif len(trainer.loggers) == 1:
save_dir = trainer.logger.save_dir or trainer.default_root_dir
else:
save_dir = trainer.default_root_dir

version = (
trainer.logger.version
if isinstance(trainer.logger.version, str)
else f"version_{trainer.logger.version}"
)
ckpt_path = os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints")
name = _name(trainer.loggers)
akashkw marked this conversation as resolved.
Show resolved Hide resolved
version = _version(trainer.loggers)
version = version if isinstance(version, str) else f"version_{version}"

ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
else:
ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")

Expand Down
14 changes: 8 additions & 6 deletions pytorch_lightning/callbacks/progress/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.logger import _version
from pytorch_lightning.utilities.rank_zero import rank_zero_warn


Expand Down Expand Up @@ -213,11 +214,12 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule")
if pl_module.truncated_bptt_steps > 0:
items_dict["split_idx"] = trainer.fit_loop.split_idx

if trainer.logger is not None and trainer.logger.version is not None:
version = trainer.logger.version
akashkw marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(version, str):
# show last 4 places of long version strings
version = version[-4:]
items_dict["v_num"] = version
if trainer.loggers:
version = _version(trainer.loggers)
if version is not None:
if isinstance(version, str):
# show last 4 places of long version strings
version = version[-4:]
items_dict["v_num"] = version
akashkw marked this conversation as resolved.
Show resolved Hide resolved

return items_dict
13 changes: 7 additions & 6 deletions pytorch_lightning/callbacks/xla_stats_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self, verbose: bool = True) -> None:
self._verbose = verbose

def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

if isinstance(trainer.accelerator, TPUAccelerator):
Expand All @@ -88,7 +88,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
self._start_time = time.time()

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

device = trainer.strategy.root_device
Expand All @@ -102,10 +102,11 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
epoch_time = trainer.strategy.reduce(epoch_time)

trainer.logger.log_metrics(
{"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
step=trainer.current_epoch,
)
for logger in trainer.loggers:
logger.log_metrics(
{"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
step=trainer.current_epoch,
)

if self._verbose:
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def logger(self) -> Optional[LightningLoggerBase]:

@property
def loggers(self) -> List[LightningLoggerBase]:
"""Reference to the loggers object in the Trainer."""
"""Reference to the list of loggers in the Trainer."""
return self.trainer.loggers if self.trainer else []

def _apply_batch_transfer_handler(
Expand Down
6 changes: 4 additions & 2 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,10 @@ def _save_loggers_on_train_batch_end(self) -> None:
"""Flushes loggers to disk."""
# when loggers should save to disk
should_flush_logs = self.trainer.logger_connector.should_flush_logs
if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
self.trainer.logger.save()
# TODO: is_global_zero check should be moved to logger.save() implementation
if should_flush_logs and self.trainer.is_global_zero:
akashkw marked this conversation as resolved.
Show resolved Hide resolved
for logger in self.trainer.loggers:
logger.save()

def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> None:
if self._dataloader_state_dict:
Expand Down
6 changes: 5 additions & 1 deletion pytorch_lightning/plugins/precision/precision_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ def optimizer_step(
def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
if trainer.track_grad_norm == -1:
return
kwargs = {"group_separator": trainer.logger.group_separator} if trainer.logger is not None else {}

kwargs = {}
if len(trainer.loggers) == 1:
kwargs["group_separator"] = trainer.loggers[0].group_separator

grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, **kwargs)
if grad_norm_dict:
prev_fx = trainer.lightning_module._current_fx_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import pytorch_lightning as pl
from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
from pytorch_lightning.trainer.states import RunningStage
Expand Down Expand Up @@ -90,15 +90,15 @@ def should_update_logs(self) -> bool:
def configure_logger(self, logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]]) -> None:
if isinstance(logger, bool):
# default logger
self.trainer.logger = (
TensorBoardLogger(save_dir=self.trainer.default_root_dir, version=SLURMEnvironment.job_id())
self.trainer.loggers = (
[TensorBoardLogger(save_dir=self.trainer.default_root_dir, version=SLURMEnvironment.job_id())]
if logger
else None
else []
)
elif isinstance(logger, Iterable):
self.trainer.logger = LoggerCollection(logger)
self.trainer.loggers = list(logger)
else:
self.trainer.logger = logger
self.trainer.loggers = [logger]

def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
"""Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, uses
Expand All @@ -109,7 +109,7 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
step: Step for which metrics should be logged. Default value is `self.global_step` during training or
the total validation / test log step count during validation and testing.
"""
if self.trainer.logger is None or not metrics:
if not self.trainer.loggers or not metrics:
return

self._logged_metrics.update(metrics)
Expand All @@ -126,11 +126,12 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
step = self.trainer.global_step

# log actual metrics
if self._override_agg_and_log_metrics:
self.trainer.logger.agg_and_log_metrics(metrics=scalar_metrics, step=step)
else:
self.trainer.logger.log_metrics(metrics=scalar_metrics, step=step)
self.trainer.logger.save()
for logger in self.trainer.loggers:
if self._override_agg_and_log_metrics:
logger.agg_and_log_metrics(metrics=scalar_metrics, step=step)
else:
logger.log_metrics(metrics=scalar_metrics, step=step)
logger.save()

"""
Evaluation metric updates
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/connectors/signal_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def slurm_sigusr1_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
rank_zero_info("handling SIGUSR1")

# save logger to make sure we get all the metrics
if self.trainer.logger:
self.trainer.logger.finalize("finished")
for logger in self.trainer.loggers:
logger.finalize("finished")
hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(self.trainer.weights_save_path)
self.trainer.save_checkpoint(hpc_save_path)

Expand Down
Loading