New logger connector code #7882

Merged · 26 commits · Jun 8, 2021
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -107,7 +107,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Refactored logging
* Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736))

* Dramatically simplify the `LoggerConnector` ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882))
* `trainer.{logged,progress_bar,callback}_metrics` are now updated on-demand ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882))
* Completely overhaul the `Result` object in favor of `ResultMetric` ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882))
* Improve epoch-level reduction time and overall memory usage ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882))

- Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))

@@ -141,7 +141,7 @@ def should_update_logs(self) -> bool:
        should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
        return should_log_every_n_steps or self.trainer.should_stop

-   def configure_logger(self, logger: LightningLoggerBase) -> None:
+   def configure_logger(self, logger: Union[bool, Iterable, LightningLoggerBase]) -> None:
        if logger is True:
            version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id)

@@ -0,0 +1,311 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pprint import pprint
from typing import Any, Dict, Iterable, Mapping, Optional, Union

import torch

import pytorch_lightning as pl
from pytorch_lightning.core import memory
from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
from pytorch_lightning.trainer.connectors.logger_connector.result_new import _METRIC, MetricSource
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities import DeviceType
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT


# TODO(@carmocca): Remove `New` suffix
class LoggerConnectorNew:
Contributor:
Super nit: for renamings I've always been a fan of `_v2` instead of `New`, in the worst-case circumstance that there are 3 or more versions around at the same time.

Contributor Author:
Thanks for the tip! Will keep it in mind for future refactors.

    def __init__(self, trainer: 'pl.Trainer', log_gpu_memory: Optional[str] = None) -> None:
        self.trainer = trainer
        self.log_gpu_memory = log_gpu_memory
        self.eval_loop_results = []
        self._val_log_step: int = 0
        self._test_log_step: int = 0
        self._progress_bar_metrics: Dict[str, float] = {}
        self._logged_metrics: Dict[str, _METRIC] = {}
        self._callback_metrics: Dict[str, _METRIC] = {}
        self._epoch_end_reached = False
        self._current_fx: Optional[str] = None
        self._batch_idx: Optional[int] = None
        self._split_idx: Optional[int] = None

    def on_trainer_init(
        self,
        logger: LightningLoggerBase,
        flush_logs_every_n_steps: int,
        log_every_n_steps: int,
        move_metrics_to_cpu: bool,
    ) -> None:
        self.configure_logger(logger)
        self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
        self.trainer.log_every_n_steps = log_every_n_steps
        self.trainer.move_metrics_to_cpu = move_metrics_to_cpu

    @property
    def should_flush_logs(self) -> bool:
        should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0
        return should_flush or self.trainer.should_stop

    @property
    def should_update_logs(self) -> bool:
        should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
        return should_log_every_n_steps or self.trainer.should_stop

    def configure_logger(self, logger: Union[bool, Iterable, LightningLoggerBase]) -> None:
        if logger is True:
            version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id)

            # default logger
            self.trainer.logger = TensorBoardLogger(
                save_dir=self.trainer.default_root_dir, version=version, name='lightning_logs'
            )
        elif logger is False:
            self.trainer.logger = None
        else:
            if isinstance(logger, Iterable):
                self.trainer.logger = LoggerCollection(logger)
            else:
                self.trainer.logger = logger
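
For context, a minimal sketch of how the three accepted `logger` forms behave from the user side, assuming the standard `Trainer` wiring (this new connector is not hooked up to the `Trainer` yet in this PR):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

# logger=True: a default TensorBoardLogger is created under `default_root_dir`
trainer = Trainer(logger=True)

# logger=False: `trainer.logger` is None and nothing is logged
trainer = Trainer(logger=False)

# an iterable of loggers is wrapped in a LoggerCollection
trainer = Trainer(logger=[TensorBoardLogger("logs"), CSVLogger("logs")])
```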

    def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) -> None:
        """Logs the metric dict passed in.
        If the `step` parameter is None and a `step` key is present in `metrics`,
        `metrics["step"]` is used as the step.

        Args:
            metrics: Metric values
            step: Step for which metrics should be logged. Default value is `self.global_step` during training or
                the total validation / test log step count during validation and testing.
        """
        # add gpu memory
        if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory:
            mem_map = memory.get_memory_profile(self.log_gpu_memory)
            metrics.update(mem_map)

        # turn all tensors to scalars
        scalar_metrics = metrics_to_scalars(metrics)

        if "step" in scalar_metrics and step is None:
            step = scalar_metrics.pop("step")

        elif step is None:
            # added metrics by Lightning for convenience
            scalar_metrics['epoch'] = self.trainer.current_epoch
            step = self.trainer.global_step
Contributor:
There are a lot of Codecov warnings.

Contributor Author:
This is expected, considering we haven't integrated the loops to use these new files yet.

You can take a peek at #7631 for the full integration.

Contributor Author:
This PR includes a few unit tests for the code in `Result`, but `LoggerConnectorNew` is completely unused.

Comment on lines +107 to +109

Contributor:
n00b questions:

  • why does step indicate we should epoch to scalar metrics?
  • what happens if epoch was already in scalar metrics as another value? if the user does self.log("epoch", ...)?

Contributor Author:
> why does step indicate we should epoch to scalar metrics?

Didn't get this, can you rephrase?

> what happens if epoch was already in scalar metrics as another value?

Well, as you see here, it gets overwritten.

It wouldn't be a bad idea to check if it's there already. Will do in a future PR.
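
A minimal, hypothetical sketch of the guard suggested above (the helper name and warning text are illustrative, not part of this PR):

```python
import warnings
from typing import Dict


def add_convenience_epoch(scalar_metrics: Dict[str, float], current_epoch: int) -> None:
    """Hypothetical helper: only add the convenience 'epoch' key if the user
    has not already logged a metric under that name."""
    if "epoch" in scalar_metrics:
        warnings.warn("'epoch' was logged by the user and will not be overwritten")
    else:
        scalar_metrics["epoch"] = current_epoch
```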


        # log actual metrics
        if self.trainer.logger is not None:
            if self.trainer.is_global_zero:
                self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
                self.trainer.logger.save()

            self._logged_metrics.update(scalar_metrics)
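
To make the `step` resolution above concrete, here is a small standalone sketch (function and argument names are assumed for illustration):

```python
from typing import Dict, Optional, Tuple


def resolve_step(
    scalar_metrics: Dict[str, float], step: Optional[int], current_epoch: int, global_step: int
) -> Tuple[Dict[str, float], int]:
    if "step" in scalar_metrics and step is None:
        # an explicitly logged "step" metric wins and is removed from the dict
        step = int(scalar_metrics.pop("step"))
    elif step is None:
        # otherwise Lightning adds the convenience 'epoch' key and uses the global step
        scalar_metrics["epoch"] = current_epoch
        step = global_step
    return scalar_metrics, step


assert resolve_step({"loss": 0.1, "step": 42}, None, 3, 100) == ({"loss": 0.1}, 42)
assert resolve_step({"loss": 0.1}, None, 3, 100) == ({"loss": 0.1, "epoch": 3}, 100)
```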

"""
Evaluation metric updates
"""

@property
def _eval_log_step(self) -> Optional[int]:
if self.trainer.state.stage is RunningStage.VALIDATING:
return self._val_log_step
elif self.trainer.state.stage is RunningStage.TESTING:
return self._test_log_step
else:
return None

def _increment_eval_log_step(self) -> None:
if self.trainer.state.stage is RunningStage.VALIDATING:
self._val_log_step += 1
elif self.trainer.state.stage is RunningStage.TESTING:
self._test_log_step += 1

def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None:
model = self.trainer.lightning_module
# set dataloader_idx only if multiple ones
model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None

# track batch_size
self.trainer.result_collection.extract_batch_size(batch)
self._batch_idx = batch_idx

def update_eval_step_metrics(self) -> None:
if self.trainer.sanity_checking:
return

# logs user requested information to logger
assert not self._epoch_end_reached
metrics = self.metrics[MetricSource.LOG]
if metrics:
self.log_metrics(metrics, step=self._eval_log_step)

# increment the step even if nothing was logged
self._increment_eval_log_step()

    def _prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None:
        if self.trainer.sanity_checking:
            return

        num_dataloaders = self.trainer.evaluation_loop.num_dataloaders
        has_been_initialized = len(self.eval_loop_results) == num_dataloaders
        for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders):
            # remove callback metrics that don't belong to this dataloader
            callback_metrics = {
                k: v
                for k, v in metrics.items() if "dataloader_idx" not in k or f"dataloader_idx_{dl_idx}" in k
            }
            if has_been_initialized:
                self.eval_loop_results[dl_idx].update(callback_metrics)
            else:
                self.eval_loop_results.append(callback_metrics)
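
An illustrative sketch of the filtering above, assuming two dataloaders and the usual `dataloader_idx_<i>` suffix that Lightning appends to per-dataloader metrics:

```python
metrics = {
    "val_loss/dataloader_idx_0": 0.25,
    "val_loss/dataloader_idx_1": 0.40,
    "val_acc": 0.90,  # no suffix: shared by every dataloader
}
per_dataloader = [
    {k: v for k, v in metrics.items() if "dataloader_idx" not in k or f"dataloader_idx_{i}" in k}
    for i in range(2)
]
# per_dataloader[0] == {"val_loss/dataloader_idx_0": 0.25, "val_acc": 0.90}
# per_dataloader[1] == {"val_loss/dataloader_idx_1": 0.40, "val_acc": 0.90}
```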

    def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT:
        assert self._epoch_end_reached
        metrics = self.metrics

        if not self.trainer.sanity_checking:
            # log all the metrics as a single dict
            log_metrics = metrics[MetricSource.LOG]
            if log_metrics:
                self.log_metrics(log_metrics)

        self._prepare_eval_loop_results(metrics[MetricSource.CALLBACK])

        # log results of evaluation
        if (
            self.trainer.state.fn != TrainerFn.FITTING and self.trainer.evaluating and self.trainer.is_global_zero
            and self.trainer.verbose_evaluate
        ):
            print('-' * 80)
            for result_idx, results in enumerate(self.eval_loop_results):
                print(f'DATALOADER:{result_idx} {self.trainer.state.stage.upper()} RESULTS')
                pprint({
                    k: (v.item() if v.numel() == 1 else v.tolist()) if isinstance(v, torch.Tensor) else v
                    for k, v in results.items()
                })
                print('-' * 80)

        results = self.eval_loop_results
        # clear mem
        self.eval_loop_results = []
        return results

"""
Train metric updates
"""

def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None:
self.trainer.results.extract_batch_size(split_batch)
self._batch_idx = batch_idx
self._split_idx = split_idx

def update_train_step_metrics(self) -> None:
if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization:
return

# when metrics should be logged
assert not self._epoch_end_reached
metrics = self.metrics[MetricSource.LOG]
if self.should_update_logs or self.trainer.fast_dev_run is True and metrics:
self.log_metrics(metrics)

def update_train_epoch_metrics(self) -> None:
# add the metrics to the loggers
assert self._epoch_end_reached
metrics = self.metrics[MetricSource.LOG]
if metrics:
self.log_metrics(metrics)

# reset result collection for next epoch
self.trainer.results.reset(metrics=True)

"""
Utilities and properties
"""

def on_epoch_start(self) -> None:
self._epoch_end_reached = False

def on_batch_start(self) -> None:
self._epoch_end_reached = False

def epoch_end_reached(self):
self.trainer.logger_connector._epoch_end_reached = True
self.trainer.logger_connector._batch_idx = None
self.trainer.logger_connector._split_idx = None

def on_epoch_end(self) -> None:
assert self._epoch_end_reached
metrics = self.metrics
self._progress_bar_metrics.update(metrics[MetricSource.PBAR])
self._callback_metrics.update(metrics[MetricSource.CALLBACK])
self._logged_metrics.update(metrics[MetricSource.LOG])
self._current_fx = None

def on_batch_end(self) -> None:
assert not self._epoch_end_reached
metrics = self.metrics
self._progress_bar_metrics.update(metrics[MetricSource.PBAR])
self._callback_metrics.update(metrics[MetricSource.CALLBACK])
self._logged_metrics.update(metrics[MetricSource.LOG])

def should_reset_tensors(self, fx: str) -> bool:
is_different_fx = self._current_fx != fx
if self._split_idx is None:
is_first_batch = self._batch_idx in (None, 0)
else:
is_first_batch = self._batch_idx + self._split_idx == 0
return is_different_fx and is_first_batch
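
A standalone sketch of the reset condition above (the arguments mirror the connector's private attributes and are assumptions for illustration): tensors are reset only when a different hook starts logging and we are at the first batch, or at the first split of the first batch when truncated BPTT is used.

```python
from typing import Optional


def should_reset_tensors(
    current_fx: Optional[str], fx: str, batch_idx: Optional[int], split_idx: Optional[int]
) -> bool:
    is_different_fx = current_fx != fx
    if split_idx is None:
        is_first_batch = batch_idx in (None, 0)
    else:
        is_first_batch = batch_idx + split_idx == 0
    return is_different_fx and is_first_batch


assert should_reset_tensors("training_step", "validation_step", 0, None)     # new hook, first batch
assert not should_reset_tensors("training_step", "training_step", 0, None)   # same hook
assert not should_reset_tensors("training_step", "validation_step", 3, None) # not the first batch
```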

    def reset(self, metrics: Optional[bool] = None) -> None:
        self.trainer.results.reset(metrics=metrics)
        self._batch_idx = None
        self._split_idx = None
        self._current_fx = None

    @property
    def metrics(self) -> Dict[MetricSource, Dict[str, _METRIC]]:
        """This function returns either batch or epoch metrics depending on ``_epoch_end_reached``."""
        on_step = not self._epoch_end_reached
        return self.trainer.results.metrics(on_step)

    @property
    def callback_metrics(self) -> Dict[str, _METRIC]:
        if self.trainer.results:
            metrics = self.metrics[MetricSource.CALLBACK]
            self._callback_metrics.update(metrics)
        return self._callback_metrics

    @property
    def logged_metrics(self) -> Dict[str, _METRIC]:
        if self.trainer.results:
            metrics = self.metrics[MetricSource.LOG]
            self._logged_metrics.update(metrics)
        return self._logged_metrics

    @property
    def progress_bar_metrics(self) -> Dict[str, float]:
        if self.trainer.results:
            metrics = self.metrics[MetricSource.PBAR]
            self._progress_bar_metrics.update(metrics)
        return self._progress_bar_metrics
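
Once this connector is integrated (it is unused in this PR, see the discussion above), the matching `Trainer` properties would be consumed on demand, rebuilt from the current `ResultCollection` on each access; an illustrative sketch:

```python
# `trainer` is a pytorch_lightning.Trainer and `model` a LightningModule,
# constructed elsewhere; shown here only to illustrate on-demand access.
trainer.fit(model)

trainer.callback_metrics      # everything logged via `self.log(...)`, as seen by callbacks
trainer.logged_metrics        # what was last sent to the logger(s)
trainer.progress_bar_metrics  # what was shown in the progress bar (`prog_bar=True`)
```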

    def teardown(self):
        self.trainer.train_loop.results.cpu()
        self.trainer.evaluation_loop._val_results.cpu()
        self.trainer.evaluation_loop._test_results.cpu()
Comment on lines +309 to +311

Contributor Author:
@awaelchli do you think each loop should run this in its own teardown, instead of the logger connector doing it?

Contributor:
These results are the ResultCollection, right? I think yes, if the loops own this results collection they should handle the teardown.

Contributor Author:
Okay. Do you think I should make that change here, or will you do it when it's all merged?

Contributor:
Yeah, we shouldn't hold this back. Added a TODO to our POC.
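
A rough sketch of the alternative discussed in this thread, where each loop owns its `ResultCollection` and moves it to CPU in its own teardown (the class and attribute names here are assumptions, not code from this PR):

```python
class EvaluationLoopSketch:
    """Hypothetical loop that owns its ResultCollection objects."""

    def __init__(self, val_results, test_results):
        self._val_results = val_results    # ResultCollection for validation
        self._test_results = test_results  # ResultCollection for testing

    def teardown(self) -> None:
        # release accelerator memory held by logged tensors when the loop finishes,
        # instead of the LoggerConnector reaching into the loop to do it
        self._val_results.cpu()
        self._test_results.cpu()
```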
