From 3cc43a70f7b97e27ebedac9cbbfb71295938e25b Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Fri, 11 Feb 2022 19:01:03 -0800 Subject: [PATCH 01/22] Update internal codebase to use loggers --- .../generative_adversarial_net.py | 3 +- .../callbacks/device_stats_monitor.py | 20 ++--- .../callbacks/gpu_stats_monitor.py | 12 +-- pytorch_lightning/callbacks/lr_monitor.py | 12 +-- .../callbacks/model_checkpoint.py | 9 ++- pytorch_lightning/callbacks/progress/base.py | 1 + .../callbacks/xla_stats_monitor.py | 13 ++-- .../loops/epoch/training_epoch_loop.py | 5 +- .../plugins/precision/precision_plugin.py | 9 ++- .../logger_connector/logger_connector.py | 17 +++-- .../trainer/connectors/signal_connector.py | 4 +- pytorch_lightning/trainer/trainer.py | 73 +++++++++---------- pytorch_lightning/tuner/batch_size_scaling.py | 2 +- pytorch_lightning/tuner/lr_finder.py | 2 +- 14 files changed, 100 insertions(+), 82 deletions(-) diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index fd2cf69f14259..a1637c2dfde04 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -206,7 +206,8 @@ def on_train_epoch_end(self): # log sampled images sample_imgs = self(z) grid = torchvision.utils.make_grid(sample_imgs) - self.logger.experiment.add_image("generated_images", grid, self.current_epoch) + for logger in self.loggers: + self.logger.experiment.add_image("generated_images", grid, self.current_epoch) def main(args: Namespace) -> None: diff --git a/pytorch_lightning/callbacks/device_stats_monitor.py b/pytorch_lightning/callbacks/device_stats_monitor.py index eaf67db9189e2..f9cb3cf623c1b 100644 --- a/pytorch_lightning/callbacks/device_stats_monitor.py +++ b/pytorch_lightning/callbacks/device_stats_monitor.py @@ -44,7 +44,7 @@ class DeviceStatsMonitor(Callback): """ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: - if not trainer.logger: + if not trainer.loggers: raise MisconfigurationException("Cannot use DeviceStatsMonitor callback with Trainer that has no logger.") def on_train_batch_start( @@ -55,7 +55,7 @@ def on_train_batch_start( batch_idx: int, unused: Optional[int] = 0, ) -> None: - if not trainer.logger: + if not trainer.loggers: raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.") if not trainer.logger_connector.should_update_logs: @@ -63,9 +63,10 @@ def on_train_batch_start( device = trainer.strategy.root_device device_stats = trainer.accelerator.get_device_stats(device) - separator = trainer.logger.group_separator - prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator) - trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step) + for logger in trainer.loggers: + separator = logger.group_separator + prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator) + logger.log_metrics(prefixed_device_stats, step=trainer.global_step) def on_train_batch_end( self, @@ -76,7 +77,7 @@ def on_train_batch_end( batch_idx: int, unused: Optional[int] = 0, ) -> None: - if not trainer.logger: + if not trainer.loggers: raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.") if not trainer.logger_connector.should_update_logs: @@ -84,9 +85,10 @@ def on_train_batch_end( device = trainer.strategy.root_device 
device_stats = trainer.accelerator.get_device_stats(device) - separator = trainer.logger.group_separator - prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator) - trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step) + for logger in trainer.loggers: + separator = logger.group_separator + prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator) + logger.log_metrics(prefixed_device_stats, step=trainer.global_step) def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]: diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py index f2aa17e1118cf..6c6f05238c6d9 100644 --- a/pytorch_lightning/callbacks/gpu_stats_monitor.py +++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py @@ -124,7 +124,7 @@ def __init__( self._gpu_ids: List[str] = [] # will be assigned later in setup() def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: - if not trainer.logger: + if not trainer.loggers: raise MisconfigurationException("Cannot use GPUStatsMonitor callback with Trainer that has no logger.") if trainer._device_type != _AcceleratorType.GPU: @@ -162,8 +162,9 @@ def on_train_batch_start( # First log at beginning of second step logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000 - assert trainer.logger is not None - trainer.logger.log_metrics(logs, step=trainer.global_step) + assert trainer.loggers + for logger in trainer.loggers: + logger.log_metrics(logs, step=trainer.global_step) @rank_zero_only def on_train_batch_end( @@ -187,8 +188,9 @@ def on_train_batch_end( if self._log_stats.intra_step_time and self._snap_intra_step_time: logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000 - assert trainer.logger is not None - trainer.logger.log_metrics(logs, step=trainer.global_step) + assert trainer.loggers + for logger in trainer.loggers: + logger.log_metrics(logs, step=trainer.global_step) @staticmethod def _get_gpu_ids(device_ids: List[int]) -> List[str]: diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 0f3519d8fe8c2..25dff2418c72b 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -104,7 +104,7 @@ def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> No MisconfigurationException: If ``Trainer`` has no ``logger``. """ - if not trainer.logger: + if not trainer.loggers: raise MisconfigurationException( "Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger." 
) @@ -149,7 +149,7 @@ def _check_no_key(key: str) -> bool: self.last_momentum_values = {name + "-momentum": None for name in names_flatten} def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: - assert trainer.logger is not None + assert trainer.loggers if not trainer.logger_connector.should_update_logs: return @@ -158,16 +158,18 @@ def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) latest_stat = self._extract_stats(trainer, interval) if latest_stat: - trainer.logger.log_metrics(latest_stat, step=trainer.global_step) + for logger in trainer.loggers: + logger.log_metrics(latest_stat, step=trainer.global_step) def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: - assert trainer.logger is not None + assert trainer.loggers if self.logging_interval != "step": interval = "epoch" if self.logging_interval is None else "any" latest_stat = self._extract_stats(trainer, interval) if latest_stat: - trainer.logger.log_metrics(latest_stat, step=trainer.global_step) + for logger in trainer.loggers: + logger.log_metrics(latest_stat, step=trainer.global_step) def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]: latest_stat = {} diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 278094dc7bff0..48454bb7e024e 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -385,8 +385,9 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None: self._save_last_checkpoint(trainer, monitor_candidates) # notify loggers - if trainer.is_global_zero and trainer.logger: - trainer.logger.after_save_checkpoint(proxy(self)) + if trainer.is_global_zero: + for logger in trainer.loggers: + logger.after_save_checkpoint(proxy(self)) def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool: from pytorch_lightning.trainer.states import TrainerFn @@ -578,8 +579,8 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None: """ if self.dirpath is not None: return # short circuit - - if trainer.logger is not None: + # TODO: Adapt for trainer.loggers + if trainer.logger: if trainer.weights_save_path != trainer.default_root_dir: # the user has changed weights_save_path, it overrides anything save_dir = trainer.weights_save_path diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py index 291fb495a81c9..ed874878f3ee9 100644 --- a/pytorch_lightning/callbacks/progress/base.py +++ b/pytorch_lightning/callbacks/progress/base.py @@ -213,6 +213,7 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule") if pl_module.truncated_bptt_steps > 0: items_dict["split_idx"] = trainer.fit_loop.split_idx + # TODO: Adapt for trainer.loggers if trainer.logger is not None and trainer.logger.version is not None: version = trainer.logger.version if isinstance(version, str): diff --git a/pytorch_lightning/callbacks/xla_stats_monitor.py b/pytorch_lightning/callbacks/xla_stats_monitor.py index 20555f5228e0a..d470e6d180598 100644 --- a/pytorch_lightning/callbacks/xla_stats_monitor.py +++ b/pytorch_lightning/callbacks/xla_stats_monitor.py @@ -69,7 +69,7 @@ def __init__(self, verbose: bool = True) -> None: self._verbose = verbose def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if not trainer.logger: + if not trainer.loggers: raise MisconfigurationException("Cannot use 
XLAStatsMonitor callback with Trainer that has no logger.") if trainer._device_type != _AcceleratorType.TPU: @@ -87,7 +87,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo self._start_time = time.time() def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if not trainer.logger: + if not trainer.loggers: raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.") device = trainer.strategy.root_device @@ -101,10 +101,11 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu peak_memory = trainer.strategy.reduce(peak_memory) * 0.001 epoch_time = trainer.strategy.reduce(epoch_time) - trainer.logger.log_metrics( - {"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)}, - step=trainer.current_epoch, - ) + for logger in trainer.loggers: + logger.log_metrics( + {"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)}, + step=trainer.current_epoch, + ) if self._verbose: rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 97f80cc7e4c7e..34ba4e0fcb262 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -503,8 +503,9 @@ def _save_loggers_on_train_batch_end(self) -> None: """Flushes loggers to disk.""" # when loggers should save to disk should_flush_logs = self.trainer.logger_connector.should_flush_logs - if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: - self.trainer.logger.save() + if should_flush_logs and self.trainer.is_global_zero: + for logger in self.trainer.loggers: + logger.save() def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> None: if self._dataloader_state_dict: diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 6bdfe12dc2a2d..e858080eb6550 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -155,7 +155,14 @@ def optimizer_step( def _track_grad_norm(self, trainer: "pl.Trainer") -> None: if trainer.track_grad_norm == -1: return - kwargs = {"group_separator": trainer.logger.group_separator} if trainer.logger is not None else {} + + if not trainer.loggers: + kwargs = {} + elif len(trainer.loggers) == 1: + kwargs = {"group_separator": trainer.logger.group_separator} + else: + kwargs = {"group_separator": "/"} + grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, **kwargs) if grad_norm_dict: prev_fx = trainer.lightning_module._current_fx_name diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 28c1d17772915..99f41d219d47e 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -78,15 +78,15 @@ def should_update_logs(self) -> bool: def configure_logger(self, logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]]) -> None: if isinstance(logger, bool): # default logger - self.trainer.logger = ( - TensorBoardLogger(save_dir=self.trainer.default_root_dir, 
version=SLURMEnvironment.job_id()) + self.trainer.loggers = ( + [TensorBoardLogger(save_dir=self.trainer.default_root_dir, version=SLURMEnvironment.job_id())] if logger - else None + else [] ) elif isinstance(logger, Iterable): - self.trainer.logger = LoggerCollection(logger) + self.trainer.loggers = list(logger) else: - self.trainer.logger = logger + self.trainer.loggers = [logger] def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None: """Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, uses @@ -97,7 +97,7 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None: step: Step for which metrics should be logged. Default value is `self.global_step` during training or the total validation / test log step count during validation and testing. """ - if self.trainer.logger is None or not metrics: + if not self.trainer.loggers or not metrics: return self._logged_metrics.update(metrics) @@ -114,8 +114,9 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None: step = self.trainer.global_step # log actual metrics - self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step) - self.trainer.logger.save() + for logger in self.trainer.loggers: + logger.agg_and_log_metrics(scalar_metrics, step=step) + logger.save() """ Evaluation metric updates diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index 3fe6a75df1e3d..9dda9ad21d649 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -66,8 +66,8 @@ def slurm_sigusr1_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None: rank_zero_info("handling SIGUSR1") # save logger to make sure we get all the metrics - if self.trainer.logger: - self.trainer.logger.finalize("finished") + for logger in self.trainer.loggers: + logger.finalize("finished") hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(self.trainer.weights_save_path) self.trainer.save_checkpoint(hpc_save_path) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4b68adbdf31b7..2707852c2c3c9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -614,7 +614,7 @@ def _init_debugging_flags( self.fit_loop.max_epochs = 1 val_check_interval = 1.0 self.check_val_every_n_epoch = 1 - self.logger = DummyLogger() if self.logger is not None else None + self.loggers = [DummyLogger()] if self.loggers else [] rank_zero_info( "Running in fast_dev_run mode: will run a full train," @@ -1212,34 +1212,34 @@ def _log_hyperparams(self) -> None: # log hyper-parameters hparams_initial = None - if self.logger is not None: - # save exp to get started (this is where the first experiment logs are written) - datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False - - if self.lightning_module._log_hyperparams and datamodule_log_hyperparams: - datamodule_hparams = self.datamodule.hparams_initial - lightning_hparams = self.lightning_module.hparams_initial - inconsistent_keys = [] - for key in lightning_hparams.keys() & datamodule_hparams.keys(): - lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key] - if type(lm_val) != type(dm_val): - inconsistent_keys.append(key) - elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val): - inconsistent_keys.append(key) - elif lm_val != dm_val: - 
inconsistent_keys.append(key) - if inconsistent_keys: - raise MisconfigurationException( - f"Error while merging hparams: the keys {inconsistent_keys} are present " - "in both the LightningModule's and LightningDataModule's hparams " - "but have different values." - ) - hparams_initial = {**lightning_hparams, **datamodule_hparams} - elif self.lightning_module._log_hyperparams: - hparams_initial = self.lightning_module.hparams_initial - elif datamodule_log_hyperparams: - hparams_initial = self.datamodule.hparams_initial + # save exp to get started (this is where the first experiment logs are written) + datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False + + if self.lightning_module._log_hyperparams and datamodule_log_hyperparams: + datamodule_hparams = self.datamodule.hparams_initial + lightning_hparams = self.lightning_module.hparams_initial + inconsistent_keys = [] + for key in lightning_hparams.keys() & datamodule_hparams.keys(): + lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key] + if type(lm_val) != type(dm_val): + inconsistent_keys.append(key) + elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val): + inconsistent_keys.append(key) + elif lm_val != dm_val: + inconsistent_keys.append(key) + if inconsistent_keys: + raise MisconfigurationException( + f"Error while merging hparams: the keys {inconsistent_keys} are present " + "in both the LightningModule's and LightningDataModule's hparams " + "but have different values." + ) + hparams_initial = {**lightning_hparams, **datamodule_hparams} + elif self.lightning_module._log_hyperparams: + hparams_initial = self.lightning_module.hparams_initial + elif datamodule_log_hyperparams: + hparams_initial = self.datamodule.hparams_initial + for logger in trainer.loggers: if hparams_initial is not None: self.logger.log_hyperparams(hparams_initial) self.logger.log_graph(self.lightning_module) @@ -1472,8 +1472,8 @@ def _call_teardown_hook(self) -> None: # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers. # It might be related to xla tensors blocked when moving the cpu kill loggers. 
- if self.logger is not None: - self.logger.finalize("success") + for logger in self.loggers: + logger.finalize("success") # summarize profile results self.profiler.describe() @@ -2096,14 +2096,13 @@ def model(self, model: torch.nn.Module) -> None: @property def log_dir(self) -> Optional[str]: - if self.logger is None: - dirpath = self.default_root_dir - elif isinstance(self.logger, TensorBoardLogger): - dirpath = self.logger.log_dir - elif isinstance(self.logger, LoggerCollection): - dirpath = self.default_root_dir + if len(self.loggers) == 1: + if isinstance(self.logger, TensorBoardLogger): + dirpath = self.logger.log_dir + else: + dirpath = self.logger.save_dir else: - dirpath = self.logger.save_dir + dirpath = self.default_root_dir dirpath = self.strategy.broadcast(dirpath) return dirpath diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 1526e570dabe5..3d5916e3f8bd9 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -109,7 +109,7 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> N trainer.auto_scale_batch_size = None # prevent recursion trainer.auto_lr_find = False # avoid lr find being called multiple times trainer.fit_loop.max_steps = steps_per_trial # take few steps - trainer.logger = DummyLogger() if trainer.logger is not None else None + trainer.loggers = [DummyLogger()] if trainer.loggers else [] trainer.callbacks = [] # not needed before full run trainer.limit_train_batches = 1.0 diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 876ff7823b2dc..d929bbe2f87c7 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -267,7 +267,7 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto # Use special lr logger callback trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)] # No logging - trainer.logger = DummyLogger() if trainer.logger is not None else None + trainer.loggers = [DummyLogger()] if trainer.loggers else [] # Max step set to number of iterations trainer.fit_loop.max_steps = num_training From 3c4f25891fb56a06b6026edd5024380441dfe0ea Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Mon, 14 Feb 2022 15:43:33 -0800 Subject: [PATCH 02/22] refactor unit tests --- pytorch_lightning/trainer/trainer.py | 2 +- tests/loggers/test_base.py | 7 ++++--- tests/profiler/test_profiler.py | 5 ++--- tests/trainer/properties/test_log_dir.py | 5 ++--- tests/trainer/properties/test_loggers.py | 4 +++- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2707852c2c3c9..ed5e934dd6248 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1239,7 +1239,7 @@ def _log_hyperparams(self) -> None: elif datamodule_log_hyperparams: hparams_initial = self.datamodule.hparams_initial - for logger in trainer.loggers: + for logger in self.loggers: if hparams_initial is not None: self.logger.log_hyperparams(hparams_initial) self.logger.log_graph(self.lightning_module) diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index 41a3aec103ad8..762973d386f63 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -186,10 +186,11 @@ def test_multiple_loggers_pickle(tmpdir): trainer = Trainer(logger=[logger1, logger2]) pkl_bytes = 
pickle.dumps(trainer) trainer2 = pickle.loads(pkl_bytes) - trainer2.logger.log_metrics({"acc": 1.0}, 0) + for logger in trainer2.loggers: + logger.log_metrics({"acc": 1.0}, 0) - assert trainer2.logger[0].metrics_logged == {"acc": 1.0} - assert trainer2.logger[1].metrics_logged == {"acc": 1.0} + for logger in trainer2.loggers: + assert logger.metrics_logged == {"acc": 1.0} def test_adding_step_key(tmpdir): diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py index e4eef57c6a153..9093b84ef75b4 100644 --- a/tests/profiler/test_profiler.py +++ b/tests/profiler/test_profiler.py @@ -24,8 +24,7 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging -from pytorch_lightning.loggers.base import DummyLogger, LoggerCollection -from pytorch_lightning.loggers.tensorboard import TensorBoardLogger +from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler from pytorch_lightning.profiler.pytorch import RegisterRecordFunction, warning_cache from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -494,7 +493,7 @@ def look_for_trace(trace_dir): model = BoringModel() # Wrap the logger in a list so it becomes a LoggerCollection - logger = [TensorBoardLogger(save_dir=tmpdir), DummyLogger()] + logger = [TensorBoardLogger(save_dir=tmpdir), CSVLogger(tmpdir)] trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=logger, limit_train_batches=5, max_epochs=1) assert isinstance(trainer.logger, LoggerCollection) trainer.fit(model) diff --git a/tests/trainer/properties/test_log_dir.py b/tests/trainer/properties/test_log_dir.py index 71920a6b079bf..6777ec8183737 100644 --- a/tests/trainer/properties/test_log_dir.py +++ b/tests/trainer/properties/test_log_dir.py @@ -15,8 +15,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger -from pytorch_lightning.loggers.base import DummyLogger +from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger from tests.helpers.boring_model import BoringModel @@ -118,7 +117,7 @@ def test_logdir_logger_collection(tmpdir): trainer = Trainer( default_root_dir=default_root_dir, max_steps=2, - logger=[TensorBoardLogger(save_dir=save_dir, name="custom_logs"), DummyLogger()], + logger=[TensorBoardLogger(save_dir=save_dir, name="custom_logs"), CSVLogger(tmpdir)], ) assert isinstance(trainer.logger, LoggerCollection) assert trainer.log_dir == default_root_dir diff --git a/tests/trainer/properties/test_loggers.py b/tests/trainer/properties/test_loggers.py index 606c7b641ae1d..d3db78986f361 100644 --- a/tests/trainer/properties/test_loggers.py +++ b/tests/trainer/properties/test_loggers.py @@ -30,18 +30,20 @@ def test_trainer_loggers_property(): # trainer.loggers should create a list of size 1 trainer = Trainer(logger=logger1) + assert trainer.logger == logger1 assert trainer.loggers == [logger1] # trainer.loggers should be an empty list trainer = Trainer(logger=False) + assert trainer.logger is None assert trainer.loggers == [] # trainer.loggers should be a list of size 1 holding the default logger trainer = Trainer(logger=True) assert trainer.loggers == [trainer.logger] - assert type(trainer.loggers[0]) == TensorBoardLogger + assert isinstance(trainer.logger, TensorBoardLogger) def 
test_trainer_loggers_setters(): From 6cda5b5e64bd72288b29296cf5594d8869994315 Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Mon, 14 Feb 2022 16:13:14 -0800 Subject: [PATCH 03/22] Flake fixes --- .../plugins/precision/precision_plugin.py | 4 +- .../logger_connector/logger_connector.py | 2 +- pytorch_lightning/trainer/trainer.py | 47 ++++++++++--------- 3 files changed, 27 insertions(+), 26 deletions(-) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index e858080eb6550..1d365a5274efe 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -158,8 +158,8 @@ def _track_grad_norm(self, trainer: "pl.Trainer") -> None: if not trainer.loggers: kwargs = {} - elif len(trainer.loggers) == 1: - kwargs = {"group_separator": trainer.logger.group_separator} + elif len(trainer.loggers) > 1: + kwargs = {"group_separator": trainer.loggers[0].group_separator} else: kwargs = {"group_separator": "/"} diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 99f41d219d47e..a0882c80d4e43 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -16,7 +16,7 @@ import torch import pytorch_lightning as pl -from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger +from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT from pytorch_lightning.trainer.states import RunningStage diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ed5e934dd6248..f7b172436eb67 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1215,29 +1215,30 @@ def _log_hyperparams(self) -> None: # save exp to get started (this is where the first experiment logs are written) datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False - if self.lightning_module._log_hyperparams and datamodule_log_hyperparams: - datamodule_hparams = self.datamodule.hparams_initial - lightning_hparams = self.lightning_module.hparams_initial - inconsistent_keys = [] - for key in lightning_hparams.keys() & datamodule_hparams.keys(): - lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key] - if type(lm_val) != type(dm_val): - inconsistent_keys.append(key) - elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val): - inconsistent_keys.append(key) - elif lm_val != dm_val: - inconsistent_keys.append(key) - if inconsistent_keys: - raise MisconfigurationException( - f"Error while merging hparams: the keys {inconsistent_keys} are present " - "in both the LightningModule's and LightningDataModule's hparams " - "but have different values." 
- ) - hparams_initial = {**lightning_hparams, **datamodule_hparams} - elif self.lightning_module._log_hyperparams: - hparams_initial = self.lightning_module.hparams_initial - elif datamodule_log_hyperparams: - hparams_initial = self.datamodule.hparams_initial + if self.loggers: + if self.lightning_module._log_hyperparams and datamodule_log_hyperparams: + datamodule_hparams = self.datamodule.hparams_initial + lightning_hparams = self.lightning_module.hparams_initial + inconsistent_keys = [] + for key in lightning_hparams.keys() & datamodule_hparams.keys(): + lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key] + if type(lm_val) != type(dm_val): + inconsistent_keys.append(key) + elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val): + inconsistent_keys.append(key) + elif lm_val != dm_val: + inconsistent_keys.append(key) + if inconsistent_keys: + raise MisconfigurationException( + f"Error while merging hparams: the keys {inconsistent_keys} are present " + "in both the LightningModule's and LightningDataModule's hparams " + "but have different values." + ) + hparams_initial = {**lightning_hparams, **datamodule_hparams} + elif self.lightning_module._log_hyperparams: + hparams_initial = self.lightning_module.hparams_initial + elif datamodule_log_hyperparams: + hparams_initial = self.datamodule.hparams_initial for logger in self.loggers: if hparams_initial is not None: From 44b2fb4662631338112a3ed2ee16f5d1fc6a7e09 Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Tue, 15 Feb 2022 11:43:53 -0800 Subject: [PATCH 04/22] fix bug in log_hyperparams --- pytorch_lightning/callbacks/model_checkpoint.py | 1 + pytorch_lightning/trainer/trainer.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 48454bb7e024e..0cb36b54ad26d 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -592,6 +592,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None: if isinstance(trainer.logger.version, str) else f"version_{trainer.logger.version}" ) + # TODO: Find out what ckpt_path should be with multiple loggers ckpt_path = os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints") else: ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints") diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f7b172436eb67..72b6d311e0272 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1242,9 +1242,9 @@ def _log_hyperparams(self) -> None: for logger in self.loggers: if hparams_initial is not None: - self.logger.log_hyperparams(hparams_initial) - self.logger.log_graph(self.lightning_module) - self.logger.save() + logger.log_hyperparams(hparams_initial) + logger.log_graph(self.lightning_module) + logger.save() def _teardown(self): """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and From bc28576753fc41e097d41dea820e518f2ac396e7 Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Tue, 15 Feb 2022 11:52:04 -0800 Subject: [PATCH 05/22] Resolve a TODO --- pytorch_lightning/callbacks/model_checkpoint.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 0cb36b54ad26d..4f2a5425f426b 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ 
b/pytorch_lightning/callbacks/model_checkpoint.py @@ -579,13 +579,12 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None: """ if self.dirpath is not None: return # short circuit - # TODO: Adapt for trainer.loggers - if trainer.logger: + if trainer.loggers: if trainer.weights_save_path != trainer.default_root_dir: # the user has changed weights_save_path, it overrides anything save_dir = trainer.weights_save_path else: - save_dir = trainer.logger.save_dir or trainer.default_root_dir + save_dir = trainer.logger.save_dir if len(trainer.loggers) == 1 else trainer.default_root_dir version = ( trainer.logger.version From 649260d0e216a5e4b4bbd98eb66fd09e2ab130e8 Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Tue, 15 Feb 2022 15:28:01 -0800 Subject: [PATCH 06/22] Possible fixes for todos --- .../callbacks/model_checkpoint.py | 24 ++++++++++++------- pytorch_lightning/callbacks/progress/base.py | 6 ++--- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 4f2a5425f426b..78a22995e76fe 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -584,15 +584,21 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None: # the user has changed weights_save_path, it overrides anything save_dir = trainer.weights_save_path else: - save_dir = trainer.logger.save_dir if len(trainer.loggers) == 1 else trainer.default_root_dir - - version = ( - trainer.logger.version - if isinstance(trainer.logger.version, str) - else f"version_{trainer.logger.version}" - ) - # TODO: Find out what ckpt_path should be with multiple loggers - ckpt_path = os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints") + if len(trainer.loggers) == 1: + save_dir = trainer.logger.save_dir or trainer.default_root_dir + else: + save_dir = trainer.default_root_dir + + if len(trainer.loggers) == 1: + version = ( + trainer.logger.version + if isinstance(trainer.logger.version, str) + else f"version_{trainer.logger.version}" + ) + # TODO: Find out what ckpt_path should be with multiple loggers + ckpt_path = os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints") + else: + ckpt_path = os.path.join(save_dir, "checkpoints") else: ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints") diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py index ed874878f3ee9..7b0a6e5a3685b 100644 --- a/pytorch_lightning/callbacks/progress/base.py +++ b/pytorch_lightning/callbacks/progress/base.py @@ -213,9 +213,9 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule") if pl_module.truncated_bptt_steps > 0: items_dict["split_idx"] = trainer.fit_loop.split_idx - # TODO: Adapt for trainer.loggers - if trainer.logger is not None and trainer.logger.version is not None: - version = trainer.logger.version + # TODO: Find out if this is the correct approach + if len(trainer.loggers) == 1 and trainer.loggers[0].version is not None: + version = trainer.loggers[0].version if isinstance(version, str): # show last 4 places of long version strings version = version[-4:] From ed7bdc65e6b6c52f80df38bb2c32634f5437cbfd Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Wed, 16 Feb 2022 14:57:38 -0800 Subject: [PATCH 07/22] fix typo --- pytorch_lightning/plugins/precision/precision_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 1d365a5274efe..3af7f703e6f22 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -158,7 +158,7 @@ def _track_grad_norm(self, trainer: "pl.Trainer") -> None: if not trainer.loggers: kwargs = {} - elif len(trainer.loggers) > 1: + elif len(trainer.loggers) == 1: kwargs = {"group_separator": trainer.loggers[0].group_separator} else: kwargs = {"group_separator": "/"} From 231a8c5756ac9b239614e40c0abbc5fc3fe2b489 Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Fri, 18 Feb 2022 13:07:27 -0800 Subject: [PATCH 08/22] Remove and insert asserts for mypy as needed --- pytorch_lightning/callbacks/gpu_stats_monitor.py | 2 -- pytorch_lightning/callbacks/lr_monitor.py | 2 -- pytorch_lightning/callbacks/progress/base.py | 15 ++++++++------- .../plugins/precision/precision_plugin.py | 3 ++- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py index 6c6f05238c6d9..7552e3a5c2a2b 100644 --- a/pytorch_lightning/callbacks/gpu_stats_monitor.py +++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py @@ -162,7 +162,6 @@ def on_train_batch_start( # First log at beginning of second step logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000 - assert trainer.loggers for logger in trainer.loggers: logger.log_metrics(logs, step=trainer.global_step) @@ -188,7 +187,6 @@ def on_train_batch_end( if self._log_stats.intra_step_time and self._snap_intra_step_time: logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000 - assert trainer.loggers for logger in trainer.loggers: logger.log_metrics(logs, step=trainer.global_step) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 25dff2418c72b..00ff007af5e41 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -149,7 +149,6 @@ def _check_no_key(key: str) -> bool: self.last_momentum_values = {name + "-momentum": None for name in names_flatten} def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: - assert trainer.loggers if not trainer.logger_connector.should_update_logs: return @@ -162,7 +161,6 @@ def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) logger.log_metrics(latest_stat, step=trainer.global_step) def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: - assert trainer.loggers if self.logging_interval != "step": interval = "epoch" if self.logging_interval is None else "any" latest_stat = self._extract_stats(trainer, interval) diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py index 7b0a6e5a3685b..73d5f535af2ee 100644 --- a/pytorch_lightning/callbacks/progress/base.py +++ b/pytorch_lightning/callbacks/progress/base.py @@ -213,12 +213,13 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule") if pl_module.truncated_bptt_steps > 0: items_dict["split_idx"] = trainer.fit_loop.split_idx - # TODO: Find out if this is the correct approach - if len(trainer.loggers) == 1 and trainer.loggers[0].version is not None: - version = trainer.loggers[0].version - if isinstance(version, str): - # show last 4 places of long version strings - 
version = version[-4:] - items_dict["v_num"] = version + if len(trainer.loggers) == 1: + assert trainer.logger + version = trainer.logger.version + if version is not None: + if isinstance(version, str): + # show last 4 places of long version strings + version = version[-4:] + items_dict["v_num"] = version return items_dict diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 3af7f703e6f22..38d70dcafa076 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -159,7 +159,8 @@ def _track_grad_norm(self, trainer: "pl.Trainer") -> None: if not trainer.loggers: kwargs = {} elif len(trainer.loggers) == 1: - kwargs = {"group_separator": trainer.loggers[0].group_separator} + assert trainer.logger + kwargs = {"group_separator": trainer.logger.group_separator} else: kwargs = {"group_separator": "/"} From d347d94f4b5102ab31d13bff657114294fc9295a Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Fri, 18 Feb 2022 13:41:20 -0800 Subject: [PATCH 09/22] Remove TODO and fix mypy error --- pytorch_lightning/callbacks/model_checkpoint.py | 1 - .../connectors/logger_connector/logger_connector.py | 9 ++++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 78a22995e76fe..f910b7dd81da7 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -595,7 +595,6 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None: if isinstance(trainer.logger.version, str) else f"version_{trainer.logger.version}" ) - # TODO: Find out what ckpt_path should be with multiple loggers ckpt_path = os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints") else: ckpt_path = os.path.join(save_dir, "checkpoints") diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 66a58ee61e049..ded301621dfda 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -125,13 +125,12 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None: step = self.trainer.global_step # log actual metrics - if self._override_agg_and_log_metrics: - for logger in self.trainer.loggers: + for logger in self.trainer.loggers: + if self._override_agg_and_log_metrics: logger.agg_and_log_metrics(metrics=scalar_metrics, step=step) - else: - for logger in self.trainer.loggers: + else: logger.log_metrics(metrics=scalar_metrics, step=step) - self.trainer.logger.save() + logger.save() """ Evaluation metric updates From 515087ce025bb7151d762b509a90e21549912502 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 19 Feb 2022 00:14:14 +0100 Subject: [PATCH 10/22] documentation update --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 098956a703a8a..a39b15bfbfdc3 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -255,7 +255,7 @@ def logger(self) -> Optional[LightningLoggerBase]: @property def loggers(self) -> List[LightningLoggerBase]: - """Reference to the loggers object in the Trainer.""" 
+ """Reference to the list of loggers in the Trainer.""" return self.trainer.loggers if self.trainer else [] def _apply_batch_transfer_handler( From 69ebc3146e35286e0f888fcf2797bdbf7240a984 Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Fri, 18 Feb 2022 16:37:39 -0800 Subject: [PATCH 11/22] implement suggestions --- pytorch_lightning/callbacks/model_checkpoint.py | 7 +++---- pytorch_lightning/callbacks/progress/base.py | 3 +-- pytorch_lightning/plugins/precision/precision_plugin.py | 3 +-- pytorch_lightning/trainer/trainer.py | 2 +- 4 files changed, 6 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index f910b7dd81da7..b3b216717f939 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -583,11 +583,10 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None: if trainer.weights_save_path != trainer.default_root_dir: # the user has changed weights_save_path, it overrides anything save_dir = trainer.weights_save_path + elif len(trainer.loggers) == 1: + save_dir = trainer.logger.save_dir or trainer.default_root_dir else: - if len(trainer.loggers) == 1: - save_dir = trainer.logger.save_dir or trainer.default_root_dir - else: - save_dir = trainer.default_root_dir + save_dir = trainer.default_root_dir if len(trainer.loggers) == 1: version = ( diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py index 73d5f535af2ee..dd900c4bbdbfd 100644 --- a/pytorch_lightning/callbacks/progress/base.py +++ b/pytorch_lightning/callbacks/progress/base.py @@ -214,8 +214,7 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule") items_dict["split_idx"] = trainer.fit_loop.split_idx if len(trainer.loggers) == 1: - assert trainer.logger - version = trainer.logger.version + version = trainer.loggers[0].version if version is not None: if isinstance(version, str): # show last 4 places of long version strings diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 1f43e6c134c14..32d942de46873 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -159,8 +159,7 @@ def _track_grad_norm(self, trainer: "pl.Trainer") -> None: if not trainer.loggers: kwargs = {} elif len(trainer.loggers) == 1: - assert trainer.logger - kwargs = {"group_separator": trainer.logger.group_separator} + kwargs = {"group_separator": trainer.loggers[0].group_separator} else: kwargs = {"group_separator": "/"} diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 22be31b8f297b..bcd0468924541 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1843,7 +1843,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - self.val_check_batch = int(self.num_training_batches * self.val_check_interval) self.val_check_batch = max(1, self.val_check_batch) - if self.logger and self.num_training_batches < self.log_every_n_steps: + if self.loggers and self.num_training_batches < self.log_every_n_steps: rank_zero_warn( f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval" f" Trainer(log_every_n_steps={self.log_every_n_steps}). 
Set a lower value for log_every_n_steps if" From 6bd8d59028da6a206910bd206912834cbe039349 Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Fri, 18 Feb 2022 17:03:52 -0800 Subject: [PATCH 12/22] Put concatenation logic in resolve_ckpt_dir --- .../callbacks/model_checkpoint.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index b3b216717f939..efd82943ff69f 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -588,15 +588,19 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None: else: save_dir = trainer.default_root_dir - if len(trainer.loggers) == 1: - version = ( - trainer.logger.version - if isinstance(trainer.logger.version, str) - else f"version_{trainer.logger.version}" - ) - ckpt_path = os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints") - else: - ckpt_path = os.path.join(save_dir, "checkpoints") + name = ( + trainer.logger.name + if len(trainer.loggers) == 1 + else "_".join(dict.fromkeys(str(logger.name) for logger in trainer.loggers)) + ) + version = ( + trainer.logger.version + if len(trainer.loggers) == 1 + else "_".join(dict.fromkeys(str(logger.version) for logger in trainer.loggers)) + ) + version = version if isinstance(version, str) else f"version_{version}" + + ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints") else: ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints") From b722da2411687d69b9b7712549d2d6242c930666 Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Wed, 23 Feb 2022 09:36:12 -0800 Subject: [PATCH 13/22] Update pytorch_lightning/plugins/precision/precision_plugin.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/plugins/precision/precision_plugin.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 32d942de46873..5a5606c5f2576 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -156,12 +156,9 @@ def _track_grad_norm(self, trainer: "pl.Trainer") -> None: if trainer.track_grad_norm == -1: return - if not trainer.loggers: - kwargs = {} - elif len(trainer.loggers) == 1: - kwargs = {"group_separator": trainer.loggers[0].group_separator} - else: - kwargs = {"group_separator": "/"} + kwargs = {} + if len(trainer.loggers) == 1: + kwargs["group_separator"] = trainer.loggers[0].group_separator grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, **kwargs) if grad_norm_dict: From d54ea35277f076100b2ba9dff233f8e5571e7774 Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Wed, 23 Feb 2022 09:44:39 -0800 Subject: [PATCH 14/22] Change v_num to pure refactor --- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- pytorch_lightning/callbacks/progress/base.py | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index efd82943ff69f..bfdab612b8a1e 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -589,12 +589,12 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None: save_dir = 
trainer.default_root_dir name = ( - trainer.logger.name + trainer.loggers[0].name if len(trainer.loggers) == 1 else "_".join(dict.fromkeys(str(logger.name) for logger in trainer.loggers)) ) version = ( - trainer.logger.version + trainer.loggers[0].version if len(trainer.loggers) == 1 else "_".join(dict.fromkeys(str(logger.version) for logger in trainer.loggers)) ) diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py index dd900c4bbdbfd..f20968c4cfa3e 100644 --- a/pytorch_lightning/callbacks/progress/base.py +++ b/pytorch_lightning/callbacks/progress/base.py @@ -213,8 +213,11 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule") if pl_module.truncated_bptt_steps > 0: items_dict["split_idx"] = trainer.fit_loop.split_idx - if len(trainer.loggers) == 1: - version = trainer.loggers[0].version + if trainer.loggers: + version = ( + trainer.loggers[0].version + if len(trainer.loggers) == 1 + else "_".join(dict.fromkeys(str(logger.version) for logger in trainer.loggers)) if version is not None: if isinstance(version, str): # show last 4 places of long version strings From bdd1757416192d4912c2bd61bedb1d47c3fa44bd Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Thu, 24 Feb 2022 01:56:39 -0800 Subject: [PATCH 15/22] return early if no loggers --- pytorch_lightning/trainer/trainer.py | 50 +++++++++++++++------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9b76c2a1428fd..0e3d8b8d20995 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1245,36 +1245,38 @@ def _run( return results def _log_hyperparams(self) -> None: + if not self.loggers: + return + # log hyper-parameters hparams_initial = None # save exp to get started (this is where the first experiment logs are written) datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False - if self.loggers: - if self.lightning_module._log_hyperparams and datamodule_log_hyperparams: - datamodule_hparams = self.datamodule.hparams_initial - lightning_hparams = self.lightning_module.hparams_initial - inconsistent_keys = [] - for key in lightning_hparams.keys() & datamodule_hparams.keys(): - lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key] - if type(lm_val) != type(dm_val): - inconsistent_keys.append(key) - elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val): - inconsistent_keys.append(key) - elif lm_val != dm_val: - inconsistent_keys.append(key) - if inconsistent_keys: - raise MisconfigurationException( - f"Error while merging hparams: the keys {inconsistent_keys} are present " - "in both the LightningModule's and LightningDataModule's hparams " - "but have different values." 
- ) - hparams_initial = {**lightning_hparams, **datamodule_hparams} - elif self.lightning_module._log_hyperparams: - hparams_initial = self.lightning_module.hparams_initial - elif datamodule_log_hyperparams: - hparams_initial = self.datamodule.hparams_initial + if self.lightning_module._log_hyperparams and datamodule_log_hyperparams: + datamodule_hparams = self.datamodule.hparams_initial + lightning_hparams = self.lightning_module.hparams_initial + inconsistent_keys = [] + for key in lightning_hparams.keys() & datamodule_hparams.keys(): + lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key] + if type(lm_val) != type(dm_val): + inconsistent_keys.append(key) + elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val): + inconsistent_keys.append(key) + elif lm_val != dm_val: + inconsistent_keys.append(key) + if inconsistent_keys: + raise MisconfigurationException( + f"Error while merging hparams: the keys {inconsistent_keys} are present " + "in both the LightningModule's and LightningDataModule's hparams " + "but have different values." + ) + hparams_initial = {**lightning_hparams, **datamodule_hparams} + elif self.lightning_module._log_hyperparams: + hparams_initial = self.lightning_module.hparams_initial + elif datamodule_log_hyperparams: + hparams_initial = self.datamodule.hparams_initial for logger in self.loggers: if hparams_initial is not None: From ef4d520ee8ac1b66b88eda5450b70d3e039cebd3 Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Thu, 24 Feb 2022 18:41:07 -0800 Subject: [PATCH 16/22] Revert "return early if no loggers" This reverts commit bdd1757416192d4912c2bd61bedb1d47c3fa44bd. --- pytorch_lightning/trainer/trainer.py | 50 +++++++++++++--------------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0e3d8b8d20995..9b76c2a1428fd 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1245,38 +1245,36 @@ def _run( return results def _log_hyperparams(self) -> None: - if not self.loggers: - return - # log hyper-parameters hparams_initial = None # save exp to get started (this is where the first experiment logs are written) datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False - if self.lightning_module._log_hyperparams and datamodule_log_hyperparams: - datamodule_hparams = self.datamodule.hparams_initial - lightning_hparams = self.lightning_module.hparams_initial - inconsistent_keys = [] - for key in lightning_hparams.keys() & datamodule_hparams.keys(): - lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key] - if type(lm_val) != type(dm_val): - inconsistent_keys.append(key) - elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val): - inconsistent_keys.append(key) - elif lm_val != dm_val: - inconsistent_keys.append(key) - if inconsistent_keys: - raise MisconfigurationException( - f"Error while merging hparams: the keys {inconsistent_keys} are present " - "in both the LightningModule's and LightningDataModule's hparams " - "but have different values." 
- ) - hparams_initial = {**lightning_hparams, **datamodule_hparams} - elif self.lightning_module._log_hyperparams: - hparams_initial = self.lightning_module.hparams_initial - elif datamodule_log_hyperparams: - hparams_initial = self.datamodule.hparams_initial + if self.loggers: + if self.lightning_module._log_hyperparams and datamodule_log_hyperparams: + datamodule_hparams = self.datamodule.hparams_initial + lightning_hparams = self.lightning_module.hparams_initial + inconsistent_keys = [] + for key in lightning_hparams.keys() & datamodule_hparams.keys(): + lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key] + if type(lm_val) != type(dm_val): + inconsistent_keys.append(key) + elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val): + inconsistent_keys.append(key) + elif lm_val != dm_val: + inconsistent_keys.append(key) + if inconsistent_keys: + raise MisconfigurationException( + f"Error while merging hparams: the keys {inconsistent_keys} are present " + "in both the LightningModule's and LightningDataModule's hparams " + "but have different values." + ) + hparams_initial = {**lightning_hparams, **datamodule_hparams} + elif self.lightning_module._log_hyperparams: + hparams_initial = self.lightning_module.hparams_initial + elif datamodule_log_hyperparams: + hparams_initial = self.datamodule.hparams_initial for logger in self.loggers: if hparams_initial is not None: From 17caf7683fbd61f69c6943429f12fecfdec6bbcb Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Thu, 24 Feb 2022 19:07:56 -0800 Subject: [PATCH 17/22] Changes based on suggestions --- pl_examples/domain_templates/generative_adversarial_net.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index a1637c2dfde04..cef2107550bdb 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -207,7 +207,7 @@ def on_train_epoch_end(self): sample_imgs = self(z) grid = torchvision.utils.make_grid(sample_imgs) for logger in self.loggers: - self.logger.experiment.add_image("generated_images", grid, self.current_epoch) + logger.experiment.add_image("generated_images", grid, self.current_epoch) def main(args: Namespace) -> None: From dfe078fb06afe7654fa96f0eec67757c2eba32e8 Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Fri, 25 Feb 2022 12:07:46 -0800 Subject: [PATCH 18/22] Move _name and _version to utility file --- .../callbacks/model_checkpoint.py | 13 ++--- pytorch_lightning/callbacks/progress/base.py | 7 +-- pytorch_lightning/trainer/trainer.py | 49 ++++++++++--------- pytorch_lightning/utilities/logger.py | 18 +++++++ 4 files changed, 48 insertions(+), 39 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index bfdab612b8a1e..79756c0477654 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -35,6 +35,7 @@ from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.logger import _name, _version from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT from pytorch_lightning.utilities.warnings import 
WarningCache @@ -588,16 +589,8 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None: else: save_dir = trainer.default_root_dir - name = ( - trainer.loggers[0].name - if len(trainer.loggers) == 1 - else "_".join(dict.fromkeys(str(logger.name) for logger in trainer.loggers)) - ) - version = ( - trainer.loggers[0].version - if len(trainer.loggers) == 1 - else "_".join(dict.fromkeys(str(logger.version) for logger in trainer.loggers)) - ) + name = _name(trainer.loggers) + version = _version(trainer.loggers) version = version if isinstance(version, str) else f"version_{version}" ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints") diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py index 7892fc0e77781..3ee1f83a547c5 100644 --- a/pytorch_lightning/callbacks/progress/base.py +++ b/pytorch_lightning/callbacks/progress/base.py @@ -15,6 +15,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback +from pytorch_lightning.utilities.logger import _version from pytorch_lightning.utilities.rank_zero import rank_zero_warn @@ -214,11 +215,7 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule") items_dict["split_idx"] = trainer.fit_loop.split_idx if trainer.loggers: - version = ( - trainer.loggers[0].version - if len(trainer.loggers) == 1 - else "_".join(dict.fromkeys(str(logger.version) for logger in trainer.loggers)) - ) + version = _version(trainer.loggers) if version is not None: if isinstance(version, str): # show last 4 places of long version strings diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9b76c2a1428fd..20a5951e24535 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1245,36 +1245,37 @@ def _run( return results def _log_hyperparams(self) -> None: + if not self.loggers: + return # log hyper-parameters hparams_initial = None # save exp to get started (this is where the first experiment logs are written) datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False - if self.loggers: - if self.lightning_module._log_hyperparams and datamodule_log_hyperparams: - datamodule_hparams = self.datamodule.hparams_initial - lightning_hparams = self.lightning_module.hparams_initial - inconsistent_keys = [] - for key in lightning_hparams.keys() & datamodule_hparams.keys(): - lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key] - if type(lm_val) != type(dm_val): - inconsistent_keys.append(key) - elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val): - inconsistent_keys.append(key) - elif lm_val != dm_val: - inconsistent_keys.append(key) - if inconsistent_keys: - raise MisconfigurationException( - f"Error while merging hparams: the keys {inconsistent_keys} are present " - "in both the LightningModule's and LightningDataModule's hparams " - "but have different values." 
- ) - hparams_initial = {**lightning_hparams, **datamodule_hparams} - elif self.lightning_module._log_hyperparams: - hparams_initial = self.lightning_module.hparams_initial - elif datamodule_log_hyperparams: - hparams_initial = self.datamodule.hparams_initial + if self.lightning_module._log_hyperparams and datamodule_log_hyperparams: + datamodule_hparams = self.datamodule.hparams_initial + lightning_hparams = self.lightning_module.hparams_initial + inconsistent_keys = [] + for key in lightning_hparams.keys() & datamodule_hparams.keys(): + lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key] + if type(lm_val) != type(dm_val): + inconsistent_keys.append(key) + elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val): + inconsistent_keys.append(key) + elif lm_val != dm_val: + inconsistent_keys.append(key) + if inconsistent_keys: + raise MisconfigurationException( + f"Error while merging hparams: the keys {inconsistent_keys} are present " + "in both the LightningModule's and LightningDataModule's hparams " + "but have different values." + ) + hparams_initial = {**lightning_hparams, **datamodule_hparams} + elif self.lightning_module._log_hyperparams: + hparams_initial = self.lightning_module.hparams_initial + elif datamodule_log_hyperparams: + hparams_initial = self.datamodule.hparams_initial for logger in self.loggers: if hparams_initial is not None: diff --git a/pytorch_lightning/utilities/logger.py b/pytorch_lightning/utilities/logger.py index a66582fd8466f..578fbefcd0566 100644 --- a/pytorch_lightning/utilities/logger.py +++ b/pytorch_lightning/utilities/logger.py @@ -19,6 +19,8 @@ import numpy as np import torch +from pytorch_lightning.loggers import LightningLoggerBase + def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]: """Ensure parameters are a dict or convert to dict if necessary. @@ -146,3 +148,19 @@ def _add_prefix(metrics: Dict[str, float], prefix: str, separator: str) -> Dict[ metrics = {f"{prefix}{separator}{k}": v for k, v in metrics.items()} return metrics + + +def _name(loggers: List[LightningLoggerBase]) -> str: + if len(loggers) == 1: + return loggers[0].name + else: + # Concatenate names together, removing duplicates and preserving order + return "_".join(dict.fromkeys(str(logger.name) for logger in loggers)) + + +def _version(loggers: List[LightningLoggerBase]) -> Union[int, str]: + if len(loggers) == 1: + return loggers[0].version + else: + # Concatenate versions together, removing duplicates and preserving order + return "_".join(dict.fromkeys(str(logger.version) for logger in loggers)) From 6cf1eb411d12c3aa0527899dfef44af8af8c60fb Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Fri, 25 Feb 2022 12:23:15 -0800 Subject: [PATCH 19/22] tmp fix to run unit tests --- pytorch_lightning/utilities/logger.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/logger.py b/pytorch_lightning/utilities/logger.py index 578fbefcd0566..2bb541819eb1a 100644 --- a/pytorch_lightning/utilities/logger.py +++ b/pytorch_lightning/utilities/logger.py @@ -19,8 +19,6 @@ import numpy as np import torch -from pytorch_lightning.loggers import LightningLoggerBase - def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]: """Ensure parameters are a dict or convert to dict if necessary. 
@@ -150,7 +148,7 @@ def _add_prefix(metrics: Dict[str, float], prefix: str, separator: str) -> Dict[ return metrics -def _name(loggers: List[LightningLoggerBase]) -> str: +def _name(loggers: List[Any]) -> str: if len(loggers) == 1: return loggers[0].name else: @@ -158,7 +156,7 @@ def _name(loggers: List[LightningLoggerBase]) -> str: return "_".join(dict.fromkeys(str(logger.name) for logger in loggers)) -def _version(loggers: List[LightningLoggerBase]) -> Union[int, str]: +def _version(loggers: List[Any]) -> Union[int, str]: if len(loggers) == 1: return loggers[0].version else: From c277c90f12ce901d7e27cf3deb54706a9fe3364f Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Fri, 25 Feb 2022 13:55:19 -0800 Subject: [PATCH 20/22] Add unit tests --- pytorch_lightning/utilities/logger.py | 8 +++--- tests/utilities/test_logger.py | 37 +++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/logger.py b/pytorch_lightning/utilities/logger.py index 2bb541819eb1a..ef27761a2ec6e 100644 --- a/pytorch_lightning/utilities/logger.py +++ b/pytorch_lightning/utilities/logger.py @@ -148,17 +148,17 @@ def _add_prefix(metrics: Dict[str, float], prefix: str, separator: str) -> Dict[ return metrics -def _name(loggers: List[Any]) -> str: +def _name(loggers: List[Any], separator: str = "_") -> str: if len(loggers) == 1: return loggers[0].name else: # Concatenate names together, removing duplicates and preserving order - return "_".join(dict.fromkeys(str(logger.name) for logger in loggers)) + return separator.join(dict.fromkeys(str(logger.name) for logger in loggers)) -def _version(loggers: List[Any]) -> Union[int, str]: +def _version(loggers: List[Any], separator: str = "_") -> Union[int, str]: if len(loggers) == 1: return loggers[0].version else: # Concatenate versions together, removing duplicates and preserving order - return "_".join(dict.fromkeys(str(logger.version) for logger in loggers)) + return separator.join(dict.fromkeys(str(logger.version) for logger in loggers)) diff --git a/tests/utilities/test_logger.py b/tests/utilities/test_logger.py index 8d9b495fb96bf..02d2795defb50 100644 --- a/tests/utilities/test_logger.py +++ b/tests/utilities/test_logger.py @@ -17,12 +17,15 @@ import torch from pytorch_lightning import Trainer +from pytorch_lightning.loggers import CSVLogger from pytorch_lightning.utilities.logger import ( _add_prefix, _convert_params, _flatten_dict, + _name, _sanitize_callable_params, _sanitize_params, + _version, ) @@ -172,3 +175,37 @@ def test_add_prefix(): assert "prefix-metric2" not in metrics assert metrics["prefix2_prefix-metric1"] == 1 assert metrics["prefix2_prefix-metric2"] == 2 + + +def test_name(tmpdir): + """Verify names of loggers are concatenated properly.""" + logger1 = CSVLogger(tmpdir, name="foo") + logger2 = CSVLogger(tmpdir, name="bar") + logger3 = CSVLogger(tmpdir, name="foo") + logger4 = CSVLogger(tmpdir, name="baz") + loggers = [logger1, logger2, logger3, logger4] + name = _name([]) + assert name == "" + name = _name([logger3]) + assert name == "foo" + name = _name(loggers) + assert name == "foo_bar_baz" + name = _name(loggers, "-") + assert name == "foo-bar-baz" + + +def test_version(tmpdir): + """Verify names of loggers are concatenated properly.""" + logger1 = CSVLogger(tmpdir, version=0) + logger2 = CSVLogger(tmpdir, version=2) + logger3 = CSVLogger(tmpdir, version=1) + logger4 = CSVLogger(tmpdir, version=0) + loggers = [logger1, logger2, logger3, logger4] + version = _version([]) + assert version == "" 
+ version = _version([logger3]) + assert version == 1 + version = _version(loggers) + assert version == "0_2_1" + version = _version(loggers, "-") + assert version == "0-2-1" From 961faef50d7b8f2feb9ed472a1fe2e9301f67bac Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Fri, 25 Feb 2022 14:40:47 -0800 Subject: [PATCH 21/22] fix typo --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 1 + tests/utilities/test_logger.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index d1c12ebb69e37..87216944893ef 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -504,6 +504,7 @@ def _save_loggers_on_train_batch_end(self) -> None: """Flushes loggers to disk.""" # when loggers should save to disk should_flush_logs = self.trainer.logger_connector.should_flush_logs + # TODO: is_global_zero check should be moved to logger.save() implementation if should_flush_logs and self.trainer.is_global_zero: for logger in self.trainer.loggers: logger.save() diff --git a/tests/utilities/test_logger.py b/tests/utilities/test_logger.py index 02d2795defb50..6b67272289fc3 100644 --- a/tests/utilities/test_logger.py +++ b/tests/utilities/test_logger.py @@ -195,7 +195,7 @@ def test_name(tmpdir): def test_version(tmpdir): - """Verify names of loggers are concatenated properly.""" + """Verify versions of loggers are concatenated properly.""" logger1 = CSVLogger(tmpdir, version=0) logger2 = CSVLogger(tmpdir, version=2) logger3 = CSVLogger(tmpdir, version=1) From 15b378d2f5629dbff2eb24642fe7970307e4bfdf Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Fri, 25 Feb 2022 14:56:37 -0800 Subject: [PATCH 22/22] Fix for failing unit test --- tests/checkpointing/test_model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index a5c3aae5b14d7..2c65426534aff 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1273,7 +1273,7 @@ def test_none_monitor_saves_correct_best_model_path(tmpdir): def test_last_global_step_saved(): # this should not save anything model_checkpoint = ModelCheckpoint(save_top_k=0, save_last=False, monitor="foo") - trainer = Mock() + trainer = MagicMock() trainer.callback_metrics = {"foo": 123} model_checkpoint.save_checkpoint(trainer) assert model_checkpoint._last_global_step_saved == -1
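
Usage note (not part of the patches themselves): a minimal sketch of the pattern this series converges on — iterate `trainer.loggers` instead of reaching for `trainer.logger`, and derive a combined name/version with the new `_name`/`_version` helpers from `pytorch_lightning.utilities.logger`. The logger names, versions, save directory, and metric key below are illustrative assumptions, not values taken from the patches.

    from pytorch_lightning.loggers import CSVLogger
    from pytorch_lightning.utilities.logger import _name, _version

    # Two illustrative loggers; any logger exposing `name`/`version` behaves the same way.
    loggers = [
        CSVLogger("logs", name="foo", version=0),
        CSVLogger("logs", name="bar", version=1),
    ]

    # The helpers concatenate duplicate-free names/versions, preserving order
    # (a single logger short-circuits to its own attributes).
    assert _name(loggers) == "foo_bar"
    assert _version(loggers) == "0_1"
    assert _version([loggers[0]]) == 0

    # Callbacks and examples touched in this series log to every attached logger:
    for logger in loggers:
        logger.log_metrics({"train/loss": 0.123}, step=0)
        logger.save()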