diff --git a/CHANGELOG.md b/CHANGELOG.md index ae07f5420ec0b..1535ab1b15606 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -107,7 +107,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Refactored logging * Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736)) - + * Dramatically simplify the `LoggerConnector` ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882)) + * `trainer.{logged,progress_bar,callback}_metrics` are now updated on-demand ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882)) + * Completely overhaul the `Result` object in favor of `ResultMetric` ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882)) + * Improve epoch-level reduction time and overall memory usage ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882)) - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/)) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index a05e273ece8aa..93bcc5527643e 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -141,7 +141,7 @@ def should_update_logs(self) -> bool: should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 return should_log_every_n_steps or self.trainer.should_stop - def configure_logger(self, logger: LightningLoggerBase) -> None: + def configure_logger(self, logger: Union[bool, Iterable, LightningLoggerBase]) -> None: if logger is True: version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py new file mode 100644 index 0000000000000..069a7f5183c70 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py @@ -0,0 +1,311 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
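+
+# NOTE: `logged_metrics`, `progress_bar_metrics` and `callback_metrics` are
+# computed on demand from the loops' `ResultCollection` (see `result_new.py`)
+# instead of being cached eagerly by the old `Result` object.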
+import os
+from pprint import pprint
+from typing import Any, Dict, Iterable, Mapping, Optional, Union
+
+import torch
+
+import pytorch_lightning as pl
+from pytorch_lightning.core import memory
+from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
+from pytorch_lightning.trainer.connectors.logger_connector.result_new import _METRIC, MetricSource
+from pytorch_lightning.trainer.states import RunningStage, TrainerFn
+from pytorch_lightning.utilities import DeviceType
+from pytorch_lightning.utilities.metrics import metrics_to_scalars
+from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT
+
+
+# TODO(@carmocca): Remove `New` suffix
+class LoggerConnectorNew:
+
+    def __init__(self, trainer: 'pl.Trainer', log_gpu_memory: Optional[str] = None) -> None:
+        self.trainer = trainer
+        self.log_gpu_memory = log_gpu_memory
+        self.eval_loop_results = []
+        self._val_log_step: int = 0
+        self._test_log_step: int = 0
+        self._progress_bar_metrics: Dict[str, float] = {}
+        self._logged_metrics: Dict[str, _METRIC] = {}
+        self._callback_metrics: Dict[str, _METRIC] = {}
+        self._epoch_end_reached = False
+        self._current_fx: Optional[str] = None
+        self._batch_idx: Optional[int] = None
+        self._split_idx: Optional[int] = None
+
+    def on_trainer_init(
+        self,
+        logger: Union[bool, Iterable, LightningLoggerBase],
+        flush_logs_every_n_steps: int,
+        log_every_n_steps: int,
+        move_metrics_to_cpu: bool,
+    ) -> None:
+        self.configure_logger(logger)
+        self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
+        self.trainer.log_every_n_steps = log_every_n_steps
+        self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
+
+    @property
+    def should_flush_logs(self) -> bool:
+        should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0
+        return should_flush or self.trainer.should_stop
+
+    @property
+    def should_update_logs(self) -> bool:
+        should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
+        return should_log_every_n_steps or self.trainer.should_stop
+
+    def configure_logger(self, logger: Union[bool, Iterable, LightningLoggerBase]) -> None:
+        if logger is True:
+            version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id)
+
+            # default logger
+            self.trainer.logger = TensorBoardLogger(
+                save_dir=self.trainer.default_root_dir, version=version, name='lightning_logs'
+            )
+        elif logger is False:
+            self.trainer.logger = None
+        else:
+            if isinstance(logger, Iterable):
+                self.trainer.logger = LoggerCollection(logger)
+            else:
+                self.trainer.logger = logger
+
+    def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) -> None:
+        """Logs the metric dict passed in.
+        If the `step` parameter is ``None`` and a `step` key is present in ``metrics``,
+        ``metrics["step"]`` is used as the step.
+
+        Args:
+            metrics: Metric values
+            step: Step for which metrics should be logged. Default value is `self.trainer.global_step` during
+                training or the total validation / test log step count during validation and testing.
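+
+        Example::
+
+            # illustrative values only: a "step" key in the dict is popped and
+            # used as the logging step when `step` is None
+            self.log_metrics({'train_loss': 0.25, 'step': 10})  # logged at step 10
+            # without it, `self.trainer.global_step` is used and an "epoch" key
+            # is added for convenience
+            self.log_metrics({'train_loss': 0.25})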
+ """ + # add gpu memory + if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory: + mem_map = memory.get_memory_profile(self.log_gpu_memory) + metrics.update(mem_map) + + # turn all tensors to scalars + scalar_metrics = metrics_to_scalars(metrics) + + if "step" in scalar_metrics and step is None: + step = scalar_metrics.pop("step") + + elif step is None: + # added metrics by Lightning for convenience + scalar_metrics['epoch'] = self.trainer.current_epoch + step = self.trainer.global_step + + # log actual metrics + if self.trainer.logger is not None: + if self.trainer.is_global_zero: + self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step) + self.trainer.logger.save() + + self._logged_metrics.update(scalar_metrics) + + """ + Evaluation metric updates + """ + + @property + def _eval_log_step(self) -> Optional[int]: + if self.trainer.state.stage is RunningStage.VALIDATING: + return self._val_log_step + elif self.trainer.state.stage is RunningStage.TESTING: + return self._test_log_step + else: + return None + + def _increment_eval_log_step(self) -> None: + if self.trainer.state.stage is RunningStage.VALIDATING: + self._val_log_step += 1 + elif self.trainer.state.stage is RunningStage.TESTING: + self._test_log_step += 1 + + def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: + model = self.trainer.lightning_module + # set dataloader_idx only if multiple ones + model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None + + # track batch_size + self.trainer.result_collection.extract_batch_size(batch) + self._batch_idx = batch_idx + + def update_eval_step_metrics(self) -> None: + if self.trainer.sanity_checking: + return + + # logs user requested information to logger + assert not self._epoch_end_reached + metrics = self.metrics[MetricSource.LOG] + if metrics: + self.log_metrics(metrics, step=self._eval_log_step) + + # increment the step even if nothing was logged + self._increment_eval_log_step() + + def _prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None: + if self.trainer.sanity_checking: + return + + num_dataloaders = self.trainer.evaluation_loop.num_dataloaders + has_been_initialized = len(self.eval_loop_results) == num_dataloaders + for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders): + # remove callback metrics that don't belong to this dataloader + callback_metrics = { + k: v + for k, v in metrics.items() if "dataloader_idx" not in k or f"dataloader_idx_{dl_idx}" in k + } + if has_been_initialized: + self.eval_loop_results[dl_idx].update(callback_metrics) + else: + self.eval_loop_results.append(callback_metrics) + + def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT: + assert self._epoch_end_reached + metrics = self.metrics + + if not self.trainer.sanity_checking: + # log all the metrics as a single dict + log_metrics = metrics[MetricSource.LOG] + if log_metrics: + self.log_metrics(log_metrics) + + self._prepare_eval_loop_results(metrics[MetricSource.CALLBACK]) + + # log results of evaluation + if ( + self.trainer.state.fn != TrainerFn.FITTING and self.trainer.evaluating and self.trainer.is_global_zero + and self.trainer.verbose_evaluate + ): + print('-' * 80) + for result_idx, results in enumerate(self.eval_loop_results): + print(f'DATALOADER:{result_idx} {self.trainer.state.stage.upper()} RESULTS') + pprint({ + k: (v.item() if v.numel() == 1 else v.tolist()) if isinstance(v, torch.Tensor) else v + for k, v in results.items() + }) + 
print('-' * 80)
+
+        results = self.eval_loop_results
+        # clear mem
+        self.eval_loop_results = []
+        return results
+
+    """
+    Train metric updates
+    """
+
+    def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None:
+        self.trainer.results.extract_batch_size(split_batch)
+        self._batch_idx = batch_idx
+        self._split_idx = split_idx
+
+    def update_train_step_metrics(self) -> None:
+        if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization:
+            return
+
+        # when metrics should be logged
+        assert not self._epoch_end_reached
+        metrics = self.metrics[MetricSource.LOG]
+        if (self.should_update_logs or self.trainer.fast_dev_run is True) and metrics:
+            self.log_metrics(metrics)
+
+    def update_train_epoch_metrics(self) -> None:
+        # add the metrics to the loggers
+        assert self._epoch_end_reached
+        metrics = self.metrics[MetricSource.LOG]
+        if metrics:
+            self.log_metrics(metrics)
+
+        # reset result collection for next epoch
+        self.trainer.results.reset(metrics=True)
+
+    """
+    Utilities and properties
+    """
+
+    def on_epoch_start(self) -> None:
+        self._epoch_end_reached = False
+
+    def on_batch_start(self) -> None:
+        self._epoch_end_reached = False
+
+    def epoch_end_reached(self) -> None:
+        # this *is* the logger connector, so set the flags directly
+        self._epoch_end_reached = True
+        self._batch_idx = None
+        self._split_idx = None
+
+    def on_epoch_end(self) -> None:
+        assert self._epoch_end_reached
+        metrics = self.metrics
+        self._progress_bar_metrics.update(metrics[MetricSource.PBAR])
+        self._callback_metrics.update(metrics[MetricSource.CALLBACK])
+        self._logged_metrics.update(metrics[MetricSource.LOG])
+        self._current_fx = None
+
+    def on_batch_end(self) -> None:
+        assert not self._epoch_end_reached
+        metrics = self.metrics
+        self._progress_bar_metrics.update(metrics[MetricSource.PBAR])
+        self._callback_metrics.update(metrics[MetricSource.CALLBACK])
+        self._logged_metrics.update(metrics[MetricSource.LOG])
+
+    def should_reset_tensors(self, fx: str) -> bool:
+        is_different_fx = self._current_fx != fx
+        if self._split_idx is None:
+            is_first_batch = self._batch_idx in (None, 0)
+        else:
+            is_first_batch = self._batch_idx + self._split_idx == 0
+        return is_different_fx and is_first_batch
+
+    def reset(self, metrics: Optional[bool] = None) -> None:
+        self.trainer.results.reset(metrics=metrics)
+        self._batch_idx = None
+        self._split_idx = None
+        self._current_fx = None
+
+    @property
+    def metrics(self) -> Dict[MetricSource, Dict[str, _METRIC]]:
+        """This function returns either batch or epoch metrics depending on ``_epoch_end_reached``."""
+        on_step = not self._epoch_end_reached
+        return self.trainer.results.metrics(on_step)
+
+    @property
+    def callback_metrics(self) -> Dict[str, _METRIC]:
+        if self.trainer.results:
+            metrics = self.metrics[MetricSource.CALLBACK]
+            self._callback_metrics.update(metrics)
+        return self._callback_metrics
+
+    @property
+    def logged_metrics(self) -> Dict[str, _METRIC]:
+        if self.trainer.results:
+            metrics = self.metrics[MetricSource.LOG]
+            self._logged_metrics.update(metrics)
+        return self._logged_metrics
+
+    @property
+    def progress_bar_metrics(self) -> Dict[str, float]:
+        if self.trainer.results:
+            metrics = self.metrics[MetricSource.PBAR]
+            self._progress_bar_metrics.update(metrics)
+        return self._progress_bar_metrics
+
+    def teardown(self):
+        self.trainer.train_loop.results.cpu()
+        self.trainer.evaluation_loop._val_results.cpu()
+        self.trainer.evaluation_loop._test_results.cpu()
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result_new.py b/pytorch_lightning/trainer/connectors/logger_connector/result_new.py
new file mode 100644
index 0000000000000..03a7c78e11175
--- /dev/null
+++ b/pytorch_lightning/trainer/connectors/logger_connector/result_new.py
@@ -0,0 +1,499 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections.abc import Generator
+from dataclasses import dataclass, field
+from functools import partial, wraps
+from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
+
+import torch
+from torchmetrics import Metric
+
+from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
+from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
+from pytorch_lightning.utilities.enums import LightningEnum
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.metrics import metrics_to_scalars
+
+# re-define the ones from pytorch_lightning.utilities.types without the `Number` type
+_METRIC = Union[Metric, torch.Tensor]
+_METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]]
+
+
+class MetricSource(LightningEnum):
+    CALLBACK = "callback"
+    PBAR = "pbar"
+    LOG = "log"
+
+
+@dataclass
+class _Sync:
+    fn: Optional[Callable] = None
+    should: bool = False
+    op: Union[Any, str] = 'mean'
+    group: Optional[Any] = None
+
+    def __post_init__(self) -> None:
+        # default to a no-op so that `_Metadata`'s `field(default_factory=_Sync)` is constructible
+        if self.fn is None:
+            self.fn = self.no_op
+
+    @property
+    def __call__(self) -> Any:
+        return partial(self.fn, reduce_op=self.op, group=self.group) if self.should else self.no_op
+
+    @staticmethod
+    def no_op(value: Any, *_, **__) -> Any:
+        return value
+
+
+@dataclass
+class _Metadata:
+    fx: str
+    name: str
+    prog_bar: bool = False
+    logger: bool = True
+    on_step: bool = False
+    on_epoch: bool = True
+    reduce_fx: Callable = torch.mean
+    enable_graph: bool = False
+    dataloader_idx: Optional[int] = None
+    metric_attribute: Optional[str] = None
+    sync: _Sync = field(default_factory=_Sync)
+
+    @property
+    def forked(self) -> bool:
+        return self.on_step and self.on_epoch
+
+    def forked_name(self, on_step: bool) -> str:
+        if self.forked:
+            return f'{self.name}_{"step" if on_step else "epoch"}'
+        return self.name
+
+    @property
+    def is_mean_reduction(self) -> bool:
+        return self.reduce_fx == torch.mean
+
+    @property
+    def is_max_reduction(self) -> bool:
+        return self.reduce_fx in (torch.max, max)
+
+    @property
+    def is_min_reduction(self) -> bool:
+        return self.reduce_fx in (torch.min, min)
+
+    @property
+    def is_custom_reduction(self) -> bool:
+        return not (self.is_mean_reduction or self.is_max_reduction or self.is_min_reduction)
+
+
+class ResultMetric(Metric, DeviceDtypeModuleMixin):
+    """Wraps the value provided to :meth:`~pytorch_lightning.core.lightning.LightningModule.log`"""
+
+    def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
+        super().__init__()
+        self.is_tensor = is_tensor
+        self.meta = metadata
+        self.has_reset = False
+        if is_tensor:
+            # use a neutral default for the configured reduction so that `min`/`max`
+            # accumulation does not get stuck at the initial 0
+            if self.meta.is_max_reduction:
+                default = float('-inf')
+            elif self.meta.is_min_reduction:
+                default = float('inf')
+            else:
+                default = 0.
+            self.add_state("value", torch.tensor(default, dtype=torch.float))
+            if self.meta.is_mean_reduction:
+                self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float))
+
+    def update(self, value: _METRIC, batch_size: torch.Tensor) -> None:
+        if self.is_tensor:
+            value = value.float()
+            self._forward_cache = value
+            # performance: no need to accumulate on values only logged on_step
+            if self.meta.on_step and not self.meta.on_epoch:
+                self.value = self.meta.sync(value)
+                return
+            # perform accumulation with reduction
+            if self.meta.is_mean_reduction:
+                self.value += value.mean() * batch_size
+                self.cumulated_batch_size += batch_size
+            elif self.meta.is_max_reduction or self.meta.is_min_reduction:
+                self.value = self.meta.reduce_fx(self.value, value.mean())
+        else:
+            self.value = value  # noqa: attribute-defined-outside-init
+            self._forward_cache = value._forward_cache
+
+    def compute(self) -> torch.Tensor:
+        if self.is_tensor:
+            value = self.meta.sync(self.value)
+            if self.meta.is_mean_reduction:
+                cumulated_batch_size = self.meta.sync(self.cumulated_batch_size)
+                return value / cumulated_batch_size
+            elif self.meta.is_max_reduction or self.meta.is_min_reduction:
+                return value
+            raise MisconfigurationException(
+                f"Only [min, max, mean] reductions are supported. Found {self.meta.reduce_fx}"
+            )
+        return self.value.compute()
+
+    def reset(self) -> None:
+        if self.is_tensor:
+            super().reset()
+        else:
+            self.value.reset()
+        self.has_reset = True
+
+    def forward(self, value: _METRIC, batch_size: torch.Tensor) -> None:
+        if self.meta.enable_graph:
+            with torch.no_grad():
+                self.update(value, batch_size)
+        else:
+            # performance: skip the `torch.no_grad` context manager by calling `update` directly
+            self.update(value, batch_size)
+
+    def _wrap_compute(self, compute: Any) -> Any:
+        # Override to avoid syncing - we handle it ourselves.
+        @wraps(compute)
+        def wrapped_func(*args, **kwargs):
+            if not self._update_called:
+                rank_zero_warn(
+                    f"The ``compute`` method of metric {self.__class__.__name__}"
+                    " was called before the ``update`` method which may lead to errors,"
+                    " as metric states have not yet been updated.", UserWarning
+                )
+
+            # return cached value
+            if self._computed is not None:
+                return self._computed
+            self._computed = compute(*args, **kwargs)
+            return self._computed
+
+        return wrapped_func
+
+    def __setattr__(self, key: str, value: Any) -> None:
+        # performance: skip the `torch.nn.Module.__setattr__` checks
+        object.__setattr__(self, key, value)
+
+    def __repr__(self) -> str:
+        state = f"value={self.value}"
+        if self.is_tensor and self.meta.is_mean_reduction:
+            state += f", cumulated_batch_size={self.cumulated_batch_size}"
+        return f"{self.__class__.__name__}({state})"
+
+
+class ResultMetricCollection(dict):
+    """
+    Dict wrapper for easy access to metadata.
+
+    All of the leaf items should be instances of
+    :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric`
+    with the same metadata.
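+
+    Example (illustrative only; assumes ``meta`` is a shared ``_Metadata``)::
+
+        rmc = ResultMetricCollection(
+            {'a': ResultMetric(meta, is_tensor=True), 'b': ResultMetric(meta, is_tensor=True)},
+            metadata=meta,
+        )
+        assert rmc.meta is meta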
+ """ + + def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None: + super().__init__(*args) + self.meta = metadata + + +class ResultCollection(dict): + """ + Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` or + :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetricCollection` + + Example: + + # `device` needs to be provided before logging + result = ResultCollection(True, torch.device("cpu")) + + # you can log to a specific collection. + # arguments: fx, key, value, metadata + result.log('training_step', 'acc', torch.tensor(...), on_step=True, on_epoch=True) + result.log('validation_step', 'recall', torch.tensor(...), on_step=True, on_epoch=True) + """ + + DATALOADER_SUFFIX = "/dataloader_idx_{}" + + def __init__(self, training: bool, device: Optional[torch.device] = None) -> None: + super().__init__() + self.training = training + self._minimize = None + self._batch_size = torch.tensor(1, device=device) + self.device: Optional[Union[str, torch.device]] = device + self.fx_validator = FxValidator() + + @property + def batch_size(self) -> torch.Tensor: + # performance: cache the `batch_size` tensor instead of re-creating it + return self._batch_size + + @batch_size.setter + def batch_size(self, value: int) -> None: + self._batch_size = torch.tensor(value, device=self.device) + + @property + def minimize(self) -> Optional[torch.Tensor]: + """ + The :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` loss + will be saved as the ``minimize`` attribute. + """ + return self._minimize + + @minimize.setter + def minimize(self, loss: Optional[torch.Tensor]) -> None: + if loss is not None: + if not isinstance(loss, torch.Tensor): + raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}") + if loss.grad_fn is None: + raise RuntimeError("`Result.minimize` must have a `grad_fn`") + self._minimize = loss + + @property + def extra(self) -> Dict[str, Any]: + """ + Extras are any keys other than the loss returned by + :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` + """ + return self.get('_extra', {}) + + @extra.setter + def extra(self, extra: Mapping[str, Any]) -> None: + + def check_fn(v): + if v.grad_fn is not None: + raise MisconfigurationException(f'You returned a tensor with `grad_fn`. The extra values are {extra}') + + apply_to_collection(extra, torch.Tensor, check_fn) + self['_extra'] = extra + + def log( + self, + fx: str, + name: str, + value: _METRIC_COLLECTION, + prog_bar: bool = False, + logger: bool = True, + on_step: bool = False, + on_epoch: bool = True, + reduce_fx: Callable = torch.mean, + enable_graph: bool = False, + sync_dist: bool = False, + sync_dist_fn: Callable = _Sync.no_op, + sync_dist_op: Union[Any, str] = 'mean', + sync_dist_group: Optional[Any] = None, + dataloader_idx: Optional[int] = None, + batch_size: Optional[int] = None, + metric_attribute: Optional[str] = None, + ) -> None: + """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" + # no metrics should be logged with graphs + if not enable_graph and isinstance(value, torch.Tensor): + value = value.detach() + + # move metrics to cpu on TPU. 
+        if isinstance(value, torch.Tensor) and value.device.type == "xla":
+            value = value.cpu()
+
+        # storage key
+        key = f"{fx}.{name}"
+        # add dataloader_suffix to both key and fx
+        if dataloader_idx is not None:
+            key += f'.{dataloader_idx}'
+            fx += f'.{dataloader_idx}'
+
+        meta = _Metadata(
+            fx=fx,
+            name=name,
+            prog_bar=prog_bar,
+            logger=logger,
+            on_step=on_step,
+            on_epoch=on_epoch,
+            reduce_fx=reduce_fx,
+            enable_graph=enable_graph,
+            dataloader_idx=dataloader_idx,
+            metric_attribute=metric_attribute,
+            sync=_Sync(
+                should=sync_dist,
+                fn=sync_dist_fn,
+                op=sync_dist_op,
+                group=sync_dist_group,
+            )
+        )
+        if key not in self:
+            if meta.is_custom_reduction:
+                raise MisconfigurationException(
+                    'Only `self.log(..., reduce_fx={min,max,mean})` are currently supported.'
+                    ' Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`'
+                )
+            self.register_key(key, meta, value)
+        elif meta != self[key].meta:
+            raise MisconfigurationException(
+                f'You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed'
+            )
+
+        if batch_size is not None:
+            self.batch_size = batch_size
+
+        self.update_metrics(key, value)
+
+    def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None:
+        """Create one ResultMetric object per value. Value can be provided as a nested collection"""
+
+        def fn(v: _METRIC) -> ResultMetric:
+            metric = ResultMetric(meta, isinstance(v, torch.Tensor))
+            return metric.to(self.device)
+
+        value = apply_to_collection(value, (torch.Tensor, Metric), fn)
+        if isinstance(value, dict):
+            value = ResultMetricCollection(value, metadata=meta)
+        self[key] = value
+
+    def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None:
+
+        def fn(result_metric, v):
+            # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
+            result_metric.forward(v.to(self.device), self.batch_size)
+            result_metric.has_reset = False
+
+        apply_to_collections(self[key], value, ResultMetric, fn)
+
+    @staticmethod
+    def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Tensor]:
+        cache = None
+        if on_step and result_metric.meta.on_step:
+            cache = result_metric._forward_cache
+        elif not on_step and result_metric.meta.on_epoch:
+            if result_metric._computed is None:
+                result_metric.compute()
+            cache = result_metric._computed
+        if cache is not None and not result_metric.meta.enable_graph:
+            return cache.detach()
+        return cache
+
+    def valid_items(self) -> Generator:
+        """This function is used to iterate over current valid metrics."""
+        return ((k, v) for k, v in self.items()
+                if k != "_extra" and not (isinstance(v, ResultMetric) and v.has_reset))
+
+    def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]:
+        name = result_metric.meta.name
+        forked_name = result_metric.meta.forked_name(on_step)
+        dl_idx = result_metric.meta.dataloader_idx
+        if dl_idx is not None:
+            dataloader_suffix = self.DATALOADER_SUFFIX.format(dl_idx)
+            name += dataloader_suffix
+            forked_name += dataloader_suffix
+        return name, forked_name
+
+    def metrics(self, on_step: bool) -> Dict[MetricSource, Dict[str, _METRIC]]:
+        metrics = {k: {} for k in MetricSource}
+
+        for key, result_metric in self.valid_items():
+
+            # extract forward_cache or computed from the ResultMetric; ignore when the output is None
+            value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False)
+
+            # check if the collection is empty
+            has_tensor = False
+
+            def any_tensor(_):
+                nonlocal has_tensor
+                has_tensor = True
+
+            apply_to_collection(value, torch.Tensor, any_tensor)
+            if not has_tensor:
+                continue
+
+            name, forked_name = self._forked_name(result_metric, on_step)
+
+            # populate logging metrics
+            if result_metric.meta.logger:
+                metrics[MetricSource.LOG][forked_name] = value
+
+            # populate callback metrics. callback metrics don't take `_step` forked metrics
+            if self.training or result_metric.meta.on_epoch and not on_step:
+                metrics[MetricSource.CALLBACK][name] = value
+                metrics[MetricSource.CALLBACK][forked_name] = value
+
+            # populate progress_bar metrics. convert tensors to numbers
+            if result_metric.meta.prog_bar:
+                metrics[MetricSource.PBAR][forked_name] = metrics_to_scalars(value)
+
+        return metrics
+
+    def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> None:
+        """
+        Reset the result collection.
+
+        Args:
+            metrics: If True, only ``torchmetrics.Metric`` results are reset,
+                if False, only ``torch.Tensors`` are reset,
+                if ``None``, both are.
+            fx: Function to reset
+        """
+
+        def fn(item: ResultMetric) -> None:
+            requested_type = metrics is None or metrics ^ item.is_tensor
+            same_fx = fx is None or fx == item.meta.fx
+            if requested_type and same_fx:
+                item.reset()
+
+        apply_to_collection(self, ResultMetric, fn)
+
+    def extract_batch_size(self, batch: Any) -> None:
+        try:
+            self.batch_size = self._extract_batch_size(batch)
+        except RecursionError:
+            self.batch_size = 1
+
+    def _extract_batch_size(self, batch: Any) -> int:
+        """
+        Recursively unpack a batch to find a torch.Tensor.
+
+        Returns:
+            ``len(tensor)`` when found, or ``1`` when it hits an empty or non-iterable object.
+        """
+        if isinstance(batch, torch.Tensor):
+            size = batch.size(0)
+        elif isinstance(batch, str):
+            return len(batch)
+        elif isinstance(batch, dict):
+            sample = next(iter(batch.values()), 1)
+            size = self._extract_batch_size(sample)
+        elif isinstance(batch, Iterable):
+            sample = next(iter(batch), 1)
+            size = self._extract_batch_size(sample)
+        else:
+            size = 1
+        return size
+
+    def to(self, *args, **kwargs) -> 'ResultCollection':
+        """Move all data to the given device."""
+
+        def to_(item: Union[torch.Tensor, Metric], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Metric]:
+            return item.to(*args, **kwargs)
+
+        apply_to_collection(self, (torch.Tensor, Metric), to_, *args, **kwargs)
+
+        if self.minimize is not None:
+            self.minimize = self.minimize.to(*args, **kwargs)
+        self._batch_size = self._batch_size.to(*args, **kwargs)
+        if 'device' in kwargs:
+            self.device = kwargs['device']
+        return self
+
+    def cpu(self) -> 'ResultCollection':
+        """Move all data to CPU."""
+        return self.to(device="cpu")
+
+    def __str__(self) -> str:
+        return f'{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})'
+
+    def __getstate__(self) -> dict:
+        d = self.__dict__.copy()
+        # can't deepcopy tensors with grad_fn
+        minimize = d.get('_minimize')
+        if minimize is not None:
+            d['_minimize'] = minimize.detach()
+        return d
diff --git a/requirements.txt b/requirements.txt
index 47cc3d47542da..b564e13551a54 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,7 +7,7 @@ tqdm>=4.41.0
 PyYAML>=5.1,<=5.4.1
 fsspec[http]>=2021.05.0, !=2021.06.0
 tensorboard>=2.2.0, !=2.5.0  # 2.5.0 GPU CI error: 'Couldn't build proto file into descriptor pool!'
-torchmetrics>=0.2.0 +torchmetrics>=0.3.2 pyDeprecate==0.3.1 packaging typing-extensions # TypedDict support for python<3.8 diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index fd08890604807..8a636a0b15dd1 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -11,14 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import torch import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric import tests.helpers.utils as tutils -from pytorch_lightning.trainer.connectors.logger_connector.result import Result +from pytorch_lightning.trainer.connectors.logger_connector.result_new import MetricSource, ResultCollection from tests.helpers.runif import RunIf @@ -52,12 +51,14 @@ def _ddp_test_fn(rank, worldsize): metric_b = DummyMetric() metric_c = DummyMetric() - # dist_sync_on_step is False by default - result = Result() + metric_a = metric_a.to(f"cuda:{rank}") + metric_b = metric_b.to(f"cuda:{rank}") + metric_c = metric_c.to(f"cuda:{rank}") - for epoch in range(3): - cumulative_sum = 0 + result = ResultCollection(True, torch.device(f"cuda:{rank}")) + for _ in range(3): + cumulative_sum = 0 for i in range(5): metric_a(i) metric_b(i) @@ -65,32 +66,25 @@ def _ddp_test_fn(rank, worldsize): cumulative_sum += i - result.log('a', metric_a, on_step=True, on_epoch=True) - result.log('b', metric_b, on_step=False, on_epoch=True) - result.log('c', metric_c, on_step=True, on_epoch=False) + result.log('h', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a") + result.log('h', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b") + result.log('h', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c") - batch_log = result.get_batch_log_metrics() - batch_expected = {"a_step": i, "a": i, "c": i} - assert set(batch_log.keys()) == set(batch_expected.keys()) - for k in batch_expected.keys(): - assert batch_expected[k] == batch_log[k] + batch_log = result.metrics(True)[MetricSource.LOG] + assert batch_log == {"a_step": i, "c": i} - epoch_log = result.get_epoch_log_metrics() + epoch_log = result.metrics(False)[MetricSource.LOG] result.reset() # assert metric state reset to default values - assert metric_a.x == metric_a._defaults['x'] + assert metric_a.x == metric_a._defaults['x'], (metric_a.x, metric_a._defaults['x']) assert metric_b.x == metric_b._defaults['x'] assert metric_c.x == metric_c._defaults['x'] - epoch_expected = {"b": cumulative_sum * worldsize, "a_epoch": cumulative_sum * worldsize} + assert epoch_log == {"b": cumulative_sum * worldsize, "a_epoch": cumulative_sum * worldsize} - assert set(epoch_log.keys()) == set(epoch_expected.keys()) - for k in epoch_expected.keys(): - assert epoch_expected[k] == epoch_log[k] - -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, min_gpus=2) def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.set_random_master_port() @@ -104,11 +98,10 @@ def test_result_metric_integration(): metric_b = DummyMetric() metric_c = DummyMetric() - result = Result() + result = ResultCollection(True, torch.device("cpu")) - for epoch in range(3): + for _ in range(3): cumulative_sum = 0 - for i in range(5): metric_a(i) metric_b(i) @@ -116,17 +109,14 @@ def test_result_metric_integration(): cumulative_sum += i - 
result.log('a', metric_a, on_step=True, on_epoch=True) - result.log('b', metric_b, on_step=False, on_epoch=True) - result.log('c', metric_c, on_step=True, on_epoch=False) + result.log('h', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a") + result.log('h', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b") + result.log('h', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c") - batch_log = result.get_batch_log_metrics() - batch_expected = {"a_step": i, "a": i, "c": i} - assert set(batch_log.keys()) == set(batch_expected.keys()) - for k in batch_expected.keys(): - assert batch_expected[k] == batch_log[k] + batch_log = result.metrics(True)[MetricSource.LOG] + assert batch_log == {"a_step": i, "c": i} - epoch_log = result.get_epoch_log_metrics() + epoch_log = result.metrics(False)[MetricSource.LOG] result.reset() # assert metric state reset to default values @@ -134,8 +124,54 @@ def test_result_metric_integration(): assert metric_b.x == metric_b._defaults['x'] assert metric_c.x == metric_c._defaults['x'] - epoch_expected = {"b": cumulative_sum, "a_epoch": cumulative_sum} - - assert set(epoch_log.keys()) == set(epoch_expected.keys()) - for k in epoch_expected.keys(): - assert epoch_expected[k] == epoch_log[k] + assert epoch_log == {"b": cumulative_sum, "a_epoch": cumulative_sum} + + assert str(result) == ( + "ResultCollection(True, cpu, {" + "'h.a': ResultMetric(value=DummyMetric()), " + "'h.b': ResultMetric(value=DummyMetric()), " + "'h.c': ResultMetric(value=DummyMetric())" + "})" + ) + + +def test_result_collection_simple_loop(): + result = ResultCollection(True, torch.device("cpu")) + current_fx_name = None + batch_idx = None + + def lightning_log(fx, *args, **kwargs): + nonlocal current_fx_name + if current_fx_name != fx and batch_idx in (None, 0): + result.reset(metrics=False, fx=fx) + result.log(fx, *args, **kwargs) + current_fx_name = fx + + lightning_log('a0', 'a', torch.tensor(0.), on_step=True, on_epoch=True) + lightning_log('a1', 'a', torch.tensor(0.), on_step=True, on_epoch=True) + for epoch in range(2): + lightning_log('b0', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True) + lightning_log('b1', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True) + for batch_idx in range(2): + lightning_log('c0', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True) + lightning_log('c1', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True) + lightning_log('c2', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True) + batch_idx = None + lightning_log('d0', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True) + lightning_log('d1', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True) + + for k in ('a0.a', 'a1.a'): + assert result[k].value == torch.tensor(0.), k + assert result[k].cumulated_batch_size == torch.tensor(1.), k + + for k in ('b0.a', 'b1.a'): + assert result[k].value == torch.tensor(1.) + epoch, k + assert result[k].cumulated_batch_size == torch.tensor(1.), k + + for k in ('c0.a', 'c1.a', 'c2.a'): + assert result[k].value == torch.tensor(4.) + epoch * 2, k + assert result[k].cumulated_batch_size == torch.tensor(2.), k + + for k in ('d0.a', 'd1.a'): + assert result[k].value == torch.tensor(3.) 
+ epoch, k
+        assert result[k].cumulated_batch_size == torch.tensor(1.), k
diff --git a/tests/core/test_results.py b/tests/core/test_results.py
index 9fce99ffbecc9..59f01184ba33e 100644
--- a/tests/core/test_results.py
+++ b/tests/core/test_results.py
@@ -20,7 +20,9 @@
 import torch.multiprocessing as mp
 
 import tests.helpers.utils as tutils
-from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning import Trainer
+from pytorch_lightning.trainer.connectors.logger_connector.result_new import _Sync
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available
 from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.runif import RunIf
 
@@ -37,7 +39,8 @@ def _setup_ddp(rank, worldsize):
 def _ddp_test_fn(rank, worldsize):
     _setup_ddp(rank, worldsize)
     tensor = torch.tensor([1.0])
-    actual = LightningModule._LightningModule__sync(tensor, sync_dist=True, sync_dist_op=torch.distributed.ReduceOp.SUM)
+    sync = _Sync(sync_ddp_if_available, should=True, op=torch.distributed.ReduceOp.SUM)
+    actual = sync(tensor)
 
     assert actual.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index 8e3bab7350f7f..b178126c4f81b 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -21,10 +21,11 @@
 import tests.helpers.pipelines as tpipes
 import tests.helpers.utils as tutils
-from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators import TPUAccelerator
 from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.plugins import TPUSpawnPlugin
+from pytorch_lightning.trainer.connectors.logger_connector.result_new import _Sync
 from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -424,12 +425,9 @@ def test_tpu_sync_dist():
     """Test tpu spawn sync dist operation """
 
     def test_sync_dist(_):
-        value = LightningModule._LightningModule__sync(
-            torch.tensor([1.0]),
-            sync_fn=TPUSpawnPlugin().reduce,
-            sync_dist=True,
-            sync_dist_op=torch.distributed.ReduceOp.SUM
-        )
+        sync = _Sync(TPUSpawnPlugin().reduce, should=True, op=torch.distributed.ReduceOp.SUM)
+        value = torch.tensor([1.0])
+        value = sync(value)
         assert value.item() == 8
 
     xmp.spawn(test_sync_dist, nprocs=8, start_method='fork')
diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index adf26467e86d9..b7ee13fed6d8e 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -29,6 +29,7 @@
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator
 from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
 from pytorch_lightning.trainer.connectors.logger_connector.result import Result
+from pytorch_lightning.trainer.connectors.logger_connector.result_new import MetricSource, ResultCollection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
@@ -108,8 +109,8 @@ def training_step_end(self, *_):
     assert train_results.has_reduced is True
     generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, reduced=True)['train_loss_epoch'].item()
-    excepted = 
torch.stack(model.train_losses).mean().item() - assert generated == excepted + expected = torch.stack(model.train_losses).mean().item() + assert generated == expected def test__logger_connector__epoch_result_store__train__tbptt(tmpdir): @@ -453,7 +454,7 @@ def test_metrics_holder(to_float, tmpdir): def is_float(value: Any) -> bool: return isinstance(value, float) - excepted_function = is_float if to_float else torch.is_tensor + expected_function = is_float if to_float else torch.is_tensor targets = torch.tensor([1], device=device) acc = Accuracy().to(device) metric_holder = MetricsHolder(to_float=to_float) @@ -464,9 +465,9 @@ def is_float(value: Any) -> bool: }) metric_holder.convert(device) metrics = metric_holder.metrics - assert excepted_function(metrics["x"]) - assert excepted_function(metrics["y"]) - assert excepted_function(metrics["z"]) + assert expected_function(metrics["x"]) + assert expected_function(metrics["y"]) + assert expected_function(metrics["z"]) def test_metric_holder_raises(tmpdir): @@ -686,3 +687,97 @@ def _assert_called(model, stage): trainer.test(model) _assert_called(model, 'test') + + +def test_result_collection_on_tensor_with_mean_reduction(): + result_collection = ResultCollection(True, torch.device("cpu")) + product = [(True, True), (False, True), (True, False), (False, False)] + values = torch.arange(1, 10).float() # need to convert to float() due to precision issues using torch 1.4 + batches = values * values + + for i, v in enumerate(values): + for prog_bar in [False, True]: + for logger in [False, True]: + for on_step, on_epoch in product: + name = "loss" + if on_step: + name += "_on_step" + if on_epoch: + name += "_on_epoch" + if prog_bar: + name += "_prog_bar" + if logger: + name += "_logger" + result_collection.log( + "training_step", + name, + v, + on_step=on_step, + on_epoch=on_epoch, + batch_size=batches[i], + prog_bar=prog_bar, + logger=logger, + ) + + total_value = sum(values * batches) + total_batches = sum(batches) + assert result_collection["training_step.loss_on_step_on_epoch"].value == total_value + assert result_collection["training_step.loss_on_step_on_epoch"].cumulated_batch_size == total_batches + + batch_metrics = result_collection.metrics(True) + max_ = max(values) + assert batch_metrics[MetricSource.PBAR] == { + 'loss_on_step_on_epoch_prog_bar_step': max_, + 'loss_on_step_on_epoch_prog_bar_logger_step': max_, + 'loss_on_step_prog_bar': max_, + 'loss_on_step_prog_bar_logger': max_, + } + assert batch_metrics[MetricSource.LOG] == { + 'loss_on_step_on_epoch_logger_step': max_, + 'loss_on_step_logger': max_, + 'loss_on_step_on_epoch_prog_bar_logger_step': max_, + 'loss_on_step_prog_bar_logger': max_, + } + assert batch_metrics[MetricSource.CALLBACK] == { + 'loss_on_step': max_, + 'loss_on_step_logger': max_, + 'loss_on_step_on_epoch': max_, + 'loss_on_step_on_epoch_logger': max_, + 'loss_on_step_on_epoch_logger_step': max_, + 'loss_on_step_on_epoch_prog_bar': max_, + 'loss_on_step_on_epoch_prog_bar_logger': max_, + 'loss_on_step_on_epoch_prog_bar_logger_step': max_, + 'loss_on_step_on_epoch_prog_bar_step': max_, + 'loss_on_step_on_epoch_step': max_, + 'loss_on_step_prog_bar': max_, + 'loss_on_step_prog_bar_logger': max_, + } + + epoch_metrics = result_collection.metrics(False) + mean = total_value / total_batches + assert epoch_metrics[MetricSource.PBAR] == { + 'loss_on_epoch_prog_bar': mean, + 'loss_on_epoch_prog_bar_logger': mean, + 'loss_on_step_on_epoch_prog_bar_epoch': mean, + 'loss_on_step_on_epoch_prog_bar_logger_epoch': mean, + } + 
assert epoch_metrics[MetricSource.LOG] == { + 'loss_on_epoch_logger': mean, + 'loss_on_epoch_prog_bar_logger': mean, + 'loss_on_step_on_epoch_logger_epoch': mean, + 'loss_on_step_on_epoch_prog_bar_logger_epoch': mean + } + assert epoch_metrics[MetricSource.CALLBACK] == { + 'loss_on_epoch': mean, + 'loss_on_epoch_logger': mean, + 'loss_on_epoch_prog_bar': mean, + 'loss_on_epoch_prog_bar_logger': mean, + 'loss_on_step_on_epoch': mean, + 'loss_on_step_on_epoch_epoch': mean, + 'loss_on_step_on_epoch_logger': mean, + 'loss_on_step_on_epoch_logger_epoch': mean, + 'loss_on_step_on_epoch_prog_bar': mean, + 'loss_on_step_on_epoch_prog_bar_epoch': mean, + 'loss_on_step_on_epoch_prog_bar_logger': mean, + 'loss_on_step_on_epoch_prog_bar_logger_epoch': mean + }
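+
+    # hand check of the weighted mean above: with values v = 1..9 and batch
+    # sizes b = v**2, total_value = sum(v**3) = 45**2 = 2025 and
+    # total_batches = sum(v**2) = 285, so mean = 2025 / 285 ≈ 7.105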