Refactor codebase to use trainer.loggers over trainer.logger when needed #11920

Merged
Changes from 6 commits (25 commits total)
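
Every change below follows the same pattern: code that previously consulted the single (possibly None) trainer.logger now iterates over the trainer.loggers list. A minimal sketch of the pattern, with an illustrative metric dict that is not taken from the PR:

    # Old style: guard against a missing logger, then use it directly.
    if trainer.logger is not None:
        trainer.logger.log_metrics({"loss": 0.1}, step=trainer.global_step)

    # New style: trainer.loggers is a list (empty when logging is disabled),
    # so the guard becomes a loop and every configured logger receives the metrics.
    for logger in trainer.loggers:
        logger.log_metrics({"loss": 0.1}, step=trainer.global_step)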
@@ -206,7 +206,8 @@ def on_train_epoch_end(self):
         # log sampled images
         sample_imgs = self(z)
         grid = torchvision.utils.make_grid(sample_imgs)
-        self.logger.experiment.add_image("generated_images", grid, self.current_epoch)
+        for logger in self.loggers:
+            logger.experiment.add_image("generated_images", grid, self.current_epoch)


 def main(args: Namespace) -> None:
20 changes: 11 additions & 9 deletions pytorch_lightning/callbacks/device_stats_monitor.py
@@ -44,7 +44,7 @@ class DeviceStatsMonitor(Callback):
     """

     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use DeviceStatsMonitor callback with Trainer that has no logger.")

     def on_train_batch_start(
@@ -55,17 +55,18 @@ def on_train_batch_start(
         batch_idx: int,
         unused: Optional[int] = 0,
     ) -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")

         if not trainer.logger_connector.should_update_logs:
             return

         device = trainer.strategy.root_device
         device_stats = trainer.accelerator.get_device_stats(device)
-        separator = trainer.logger.group_separator
-        prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
-        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
+        for logger in trainer.loggers:
+            separator = logger.group_separator
+            prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
+            logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

     def on_train_batch_end(
         self,
@@ -76,17 +77,18 @@ def on_train_batch_end(
         batch_idx: int,
         unused: Optional[int] = 0,
     ) -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")

         if not trainer.logger_connector.should_update_logs:
             return

         device = trainer.strategy.root_device
         device_stats = trainer.accelerator.get_device_stats(device)
-        separator = trainer.logger.group_separator
-        prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
-        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
+        for logger in trainer.loggers:
+            separator = logger.group_separator
+            prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
+            logger.log_metrics(prefixed_device_stats, step=trainer.global_step)


 def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
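
The body of _prefix_metric_keys is not part of this hunk; consistent with how it is called above, a plausible sketch (illustrative, not necessarily the file's actual implementation) is:

    from typing import Dict

    def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
        # e.g. {"temperature": 55.0} with prefix "on_train_batch_start" and separator "/"
        # becomes {"on_train_batch_start/temperature": 55.0}
        return {prefix + separator + key: value for key, value in metrics_dict.items()}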
12 changes: 7 additions & 5 deletions pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -124,7 +124,7 @@ def __init__(
         self._gpu_ids: List[str] = []  # will be assigned later in setup()

     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use GPUStatsMonitor callback with Trainer that has no logger.")

         if trainer._device_type != _AcceleratorType.GPU:
@@ -162,8 +162,9 @@ def on_train_batch_start(
             # First log at beginning of second step
             logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000

-        assert trainer.logger is not None
-        trainer.logger.log_metrics(logs, step=trainer.global_step)
+        assert trainer.loggers
+        for logger in trainer.loggers:
+            logger.log_metrics(logs, step=trainer.global_step)

     @rank_zero_only
     def on_train_batch_end(
@@ -187,8 +188,9 @@ def on_train_batch_end(
         if self._log_stats.intra_step_time and self._snap_intra_step_time:
             logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000

-        assert trainer.logger is not None
-        trainer.logger.log_metrics(logs, step=trainer.global_step)
+        assert trainer.loggers
+        for logger in trainer.loggers:
+            logger.log_metrics(logs, step=trainer.global_step)

     @staticmethod
     def _get_gpu_ids(device_ids: List[int]) -> List[str]:
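
The assert trainer.loggers form works because an empty list is falsy, so one truthiness check covers the old "no logger" case and the new "no loggers" case — a trivial illustration with stand-in values:

    loggers = []                 # logging disabled: falsy, `assert loggers` would fail
    assert not loggers
    loggers = ["tb", "csv"]      # one or more loggers configured: truthy
    assert loggers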
12 changes: 7 additions & 5 deletions pytorch_lightning/callbacks/lr_monitor.py
@@ -104,7 +104,7 @@ def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
             MisconfigurationException:
                 If ``Trainer`` has no ``logger``.
         """
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException(
                 "Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger."
             )
@@ -149,7 +149,7 @@ def _check_no_key(key: str) -> bool:
         self.last_momentum_values = {name + "-momentum": None for name in names_flatten}

     def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
-        assert trainer.logger is not None
+        assert trainer.loggers
         if not trainer.logger_connector.should_update_logs:
             return

@@ -158,16 +158,18 @@ def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
             latest_stat = self._extract_stats(trainer, interval)

             if latest_stat:
-                trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
+                for logger in trainer.loggers:
+                    logger.log_metrics(latest_stat, step=trainer.global_step)

     def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
-        assert trainer.logger is not None
+        assert trainer.loggers
         if self.logging_interval != "step":
             interval = "epoch" if self.logging_interval is None else "any"
             latest_stat = self._extract_stats(trainer, interval)

             if latest_stat:
-                trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
+                for logger in trainer.loggers:
+                    logger.log_metrics(latest_stat, step=trainer.global_step)

     def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]:
         latest_stat = {}
31 changes: 19 additions & 12 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -385,8 +385,9 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None:
         self._save_last_checkpoint(trainer, monitor_candidates)

         # notify loggers
-        if trainer.is_global_zero and trainer.logger:
-            trainer.logger.after_save_checkpoint(proxy(self))
+        if trainer.is_global_zero:
+            for logger in trainer.loggers:
+                logger.after_save_checkpoint(proxy(self))

     def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
         from pytorch_lightning.trainer.states import TrainerFn
@@ -578,20 +579,26 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
         """
         if self.dirpath is not None:
             return  # short circuit

-        if trainer.logger is not None:
+        if trainer.loggers:
             if trainer.weights_save_path != trainer.default_root_dir:
                 # the user has changed weights_save_path, it overrides anything
                 save_dir = trainer.weights_save_path
             else:
-                save_dir = trainer.logger.save_dir or trainer.default_root_dir
-
-            version = (
-                trainer.logger.version
-                if isinstance(trainer.logger.version, str)
-                else f"version_{trainer.logger.version}"
-            )
-            ckpt_path = os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints")
+                if len(trainer.loggers) == 1:
+                    save_dir = trainer.logger.save_dir or trainer.default_root_dir
+                else:
+                    save_dir = trainer.default_root_dir
+
+            if len(trainer.loggers) == 1:
+                version = (
+                    trainer.logger.version
+                    if isinstance(trainer.logger.version, str)
+                    else f"version_{trainer.logger.version}"
+                )
+                # TODO: Find out what ckpt_path should be with multiple loggers
+                ckpt_path = os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints")
+            else:
+                ckpt_path = os.path.join(save_dir, "checkpoints")
         else:
             ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")

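
To make the new branching concrete, here is how the checkpoint directory would resolve in each case — a worked sketch with made-up paths and logger attributes, not values taken from the PR:

    import os

    # Exactly one logger, e.g. save_dir="logs", name="default", version=3
    # (non-string versions get a "version_" prefix):
    ckpt_path = os.path.join("logs", "default", "version_3", "checkpoints")
    # -> logs/default/version_3/checkpoints

    # Multiple loggers: no per-logger name/version segment, so checkpoints land
    # directly under the save dir until the TODO above is resolved.
    ckpt_path = os.path.join("/path/to/default_root_dir", "checkpoints")
    # -> /path/to/default_root_dir/checkpoints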
5 changes: 3 additions & 2 deletions pytorch_lightning/callbacks/progress/base.py
@@ -213,8 +213,9 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule")
     if pl_module.truncated_bptt_steps > 0:
         items_dict["split_idx"] = trainer.fit_loop.split_idx

-    if trainer.logger is not None and trainer.logger.version is not None:
-        version = trainer.logger.version
+    # TODO: Find out if this is the correct approach
+    if len(trainer.loggers) == 1 and trainer.loggers[0].version is not None:
+        version = trainer.loggers[0].version
         if isinstance(version, str):
             # show last 4 places of long version strings
             version = version[-4:]
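
A small illustration of the version handling above, with a made-up run identifier:

    version = "9c1b1cbdee3b4f8d"   # long version strings are common for run ids
    if isinstance(version, str):
        version = version[-4:]     # the progress bar only shows the last 4 characters
    # version is now "4f8d"; integer versions are left untouched by this branch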
13 changes: 7 additions & 6 deletions pytorch_lightning/callbacks/xla_stats_monitor.py
@@ -69,7 +69,7 @@ def __init__(self, verbose: bool = True) -> None:
         self._verbose = verbose

     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

         if trainer._device_type != _AcceleratorType.TPU:
@@ -87,7 +87,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._start_time = time.time()

     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

         device = trainer.strategy.root_device
@@ -101,10 +101,11 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
         epoch_time = trainer.strategy.reduce(epoch_time)

-        trainer.logger.log_metrics(
-            {"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
-            step=trainer.current_epoch,
-        )
+        for logger in trainer.loggers:
+            logger.log_metrics(
+                {"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
+                step=trainer.current_epoch,
+            )

         if self._verbose:
             rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
5 changes: 3 additions & 2 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -503,8 +503,9 @@ def _save_loggers_on_train_batch_end(self) -> None:
         """Flushes loggers to disk."""
         # when loggers should save to disk
         should_flush_logs = self.trainer.logger_connector.should_flush_logs
-        if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
-            self.trainer.logger.save()
+        if should_flush_logs and self.trainer.is_global_zero:
+            for logger in self.trainer.loggers:
+                logger.save()

     def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> None:
         if self._dataloader_state_dict:
9 changes: 8 additions & 1 deletion pytorch_lightning/plugins/precision/precision_plugin.py
@@ -155,7 +155,14 @@ def optimizer_step(
     def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
         if trainer.track_grad_norm == -1:
             return
-        kwargs = {"group_separator": trainer.logger.group_separator} if trainer.logger is not None else {}
+
+        if not trainer.loggers:
+            kwargs = {}
+        elif len(trainer.loggers) == 1:
+            kwargs = {"group_separator": trainer.loggers[0].group_separator}
+        else:
+            kwargs = {"group_separator": "/"}
+
         grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, **kwargs)
         if grad_norm_dict:
             prev_fx = trainer.lightning_module._current_fx_name
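
A stand-alone sketch of the separator selection above, using a stand-in logger class (nothing here is Lightning API):

    class _FakeLogger:
        group_separator = "/"

    def _grad_norm_kwargs(loggers):
        # Single logger: reuse its own separator; several loggers: fall back to a default.
        if not loggers:
            return {}
        if len(loggers) == 1:
            return {"group_separator": loggers[0].group_separator}
        return {"group_separator": "/"}

    assert _grad_norm_kwargs([]) == {}
    assert _grad_norm_kwargs([_FakeLogger()]) == {"group_separator": "/"}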
@@ -16,7 +16,7 @@
 import torch

 import pytorch_lightning as pl
-from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
+from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger
 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
 from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
 from pytorch_lightning.trainer.states import RunningStage
@@ -78,15 +78,15 @@ def should_update_logs(self) -> bool:
     def configure_logger(self, logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]]) -> None:
         if isinstance(logger, bool):
             # default logger
-            self.trainer.logger = (
-                TensorBoardLogger(save_dir=self.trainer.default_root_dir, version=SLURMEnvironment.job_id())
+            self.trainer.loggers = (
+                [TensorBoardLogger(save_dir=self.trainer.default_root_dir, version=SLURMEnvironment.job_id())]
                 if logger
-                else None
+                else []
             )
         elif isinstance(logger, Iterable):
-            self.trainer.logger = LoggerCollection(logger)
+            self.trainer.loggers = list(logger)
         else:
-            self.trainer.logger = logger
+            self.trainer.loggers = [logger]

     def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
         """Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, uses
@@ -97,7 +97,7 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
             step: Step for which metrics should be logged. Default value is `self.global_step` during training or
                 the total validation / test log step count during validation and testing.
         """
-        if self.trainer.logger is None or not metrics:
+        if not self.trainer.loggers or not metrics:
             return

         self._logged_metrics.update(metrics)
@@ -114,8 +114,9 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
             step = self.trainer.global_step

         # log actual metrics
-        self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
-        self.trainer.logger.save()
+        for logger in self.trainer.loggers:
+            logger.agg_and_log_metrics(scalar_metrics, step=step)
+            logger.save()

     """
     Evaluation metric updates
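
Restated outside the connector, the mapping configure_logger now performs from the logger argument to trainer.loggers looks roughly like this (the default-TensorBoard case is replaced by a placeholder string so the sketch stays self-contained):

    from collections.abc import Iterable
    from typing import Any, List, Union

    def normalize_loggers(logger: Union[bool, Any, Iterable]) -> List[Any]:
        if isinstance(logger, bool):
            # In the real code True yields one default TensorBoardLogger; False disables logging.
            return ["<default TensorBoardLogger>"] if logger else []
        if isinstance(logger, Iterable):
            return list(logger)
        return [logger]

    assert normalize_loggers(False) == []
    assert normalize_loggers(("a", "b")) == ["a", "b"]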
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/connectors/signal_connector.py
@@ -66,8 +66,8 @@ def slurm_sigusr1_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
         rank_zero_info("handling SIGUSR1")

         # save logger to make sure we get all the metrics
-        if self.trainer.logger:
-            self.trainer.logger.finalize("finished")
+        for logger in self.trainer.loggers:
+            logger.finalize("finished")
         hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(self.trainer.weights_save_path)
         self.trainer.save_checkpoint(hpc_save_path)

32 changes: 16 additions & 16 deletions pytorch_lightning/trainer/trainer.py
@@ -614,7 +614,7 @@ def _init_debugging_flags(
             self.fit_loop.max_epochs = 1
             val_check_interval = 1.0
             self.check_val_every_n_epoch = 1
-            self.logger = DummyLogger() if self.logger is not None else None
+            self.loggers = [DummyLogger()] if self.loggers else []

             rank_zero_info(
                 "Running in fast_dev_run mode: will run a full train,"
@@ -1212,10 +1212,10 @@ def _log_hyperparams(self) -> None:
         # log hyper-parameters
         hparams_initial = None

-        if self.logger is not None:
-            # save exp to get started (this is where the first experiment logs are written)
-            datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False
+        # save exp to get started (this is where the first experiment logs are written)
+        datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False

+        if self.loggers:
             if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
                 datamodule_hparams = self.datamodule.hparams_initial
                 lightning_hparams = self.lightning_module.hparams_initial
@@ -1240,10 +1240,11 @@ def _log_hyperparams(self) -> None:
             elif datamodule_log_hyperparams:
                 hparams_initial = self.datamodule.hparams_initial

+        for logger in self.loggers:
             if hparams_initial is not None:
-                self.logger.log_hyperparams(hparams_initial)
-            self.logger.log_graph(self.lightning_module)
-            self.logger.save()
+                logger.log_hyperparams(hparams_initial)
+            logger.log_graph(self.lightning_module)
+            logger.save()

     def _teardown(self):
         """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and
@@ -1472,8 +1473,8 @@ def _call_teardown_hook(self) -> None:

         # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
         # It might be related to xla tensors blocked when moving the cpu kill loggers.
-        if self.logger is not None:
-            self.logger.finalize("success")
+        for logger in self.loggers:
+            logger.finalize("success")

         # summarize profile results
         self.profiler.describe()
@@ -2096,14 +2097,13 @@ def model(self, model: torch.nn.Module) -> None:

     @property
     def log_dir(self) -> Optional[str]:
-        if self.logger is None:
-            dirpath = self.default_root_dir
-        elif isinstance(self.logger, TensorBoardLogger):
-            dirpath = self.logger.log_dir
-        elif isinstance(self.logger, LoggerCollection):
-            dirpath = self.default_root_dir
+        if len(self.loggers) == 1:
+            if isinstance(self.logger, TensorBoardLogger):
+                dirpath = self.logger.log_dir
+            else:
+                dirpath = self.logger.save_dir
         else:
-            dirpath = self.logger.save_dir
+            dirpath = self.default_root_dir

         dirpath = self.strategy.broadcast(dirpath)
         return dirpath
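
A simplified restatement of the new log_dir branching (it keys on a log_dir attribute instead of the isinstance(TensorBoardLogger) check used in the real property):

    def resolve_log_dir(loggers, default_root_dir):
        if len(loggers) == 1:
            logger = loggers[0]
            # TensorBoard-style loggers expose a per-version log_dir; others fall back to save_dir.
            return getattr(logger, "log_dir", None) or logger.save_dir
        # Zero loggers or multiple loggers: use the trainer's default root dir.
        return default_root_dir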