From 0c98c8bb2fdb53ad21f4b0ab6e3f465fbfdac959 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 13:23:15 +0200 Subject: [PATCH 01/24] New logger connector code --- .../logger_connector/logger_connector_new.py | 307 +++++++++++ .../connectors/logger_connector/result_new.py | 503 ++++++++++++++++++ 2 files changed, 810 insertions(+) create mode 100644 pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py create mode 100644 pytorch_lightning/trainer/connectors/logger_connector/result_new.py diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py new file mode 100644 index 0000000000000..0b7327183d77f --- /dev/null +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py @@ -0,0 +1,307 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from pprint import pprint +from typing import Any, Dict, Iterable, Mapping, Optional + +import torch + +import pytorch_lightning as pl +from pytorch_lightning.core import memory +from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger +from pytorch_lightning.trainer.connectors.logger_connector.result import _METRIC, MetricSource +from pytorch_lightning.trainer.states import RunningStage, TrainerFn +from pytorch_lightning.utilities import DeviceType +from pytorch_lightning.utilities.metrics import metrics_to_scalars +from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT + + +class LoggerConnector: + + def __init__(self, trainer: 'pl.Trainer', log_gpu_memory: Optional[str] = None) -> None: + self.trainer = trainer + self.log_gpu_memory = log_gpu_memory + self.eval_loop_results = [] + self._val_log_step: int = 0 + self._test_log_step: int = 0 + self._progress_bar_metrics: Dict[str, float] = {} + self._logged_metrics: Dict[str, _METRIC] = {} + self._callback_metrics: Dict[str, _METRIC] = {} + self._epoch_end_reached = False + self._current_fx: Optional[str] = None + self._batch_idx: Optional[int] = None + self._split_idx: Optional[int] = None + + def on_trainer_init( + self, logger: LightningLoggerBase, flush_logs_every_n_steps: int, log_every_n_steps: int, + move_metrics_to_cpu: bool + ) -> None: + self.configure_logger(logger) + self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps + self.trainer.log_every_n_steps = log_every_n_steps + self.trainer.move_metrics_to_cpu = move_metrics_to_cpu + + @property + def should_flush_logs(self) -> bool: + should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0 + return should_flush or self.trainer.should_stop + + @property + def should_update_logs(self) -> bool: + should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 + return should_log_every_n_steps or self.trainer.should_stop + + def configure_logger(self, logger: LightningLoggerBase) -> None: + if logger is 
True: + version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id) + + # default logger + self.trainer.logger = TensorBoardLogger( + save_dir=self.trainer.default_root_dir, version=version, name='lightning_logs' + ) + elif logger is False: + self.trainer.logger = None + else: + if isinstance(logger, Iterable): + self.trainer.logger = LoggerCollection(logger) + else: + self.trainer.logger = logger + + def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) -> None: + """Logs the metric dict passed in. + If `step` parameter is None and `step` key is presented is metrics, + uses metrics["step"] as a step + + Args: + metrics: Metric values + step: Step for which metrics should be logged. Default value is `self.global_step` during training or + the total validation / test log step count during validation and testing. + """ + # add gpu memory + if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory: + mem_map = memory.get_memory_profile(self.log_gpu_memory) + metrics.update(mem_map) + + # turn all tensors to scalars + scalar_metrics = metrics_to_scalars(metrics) + + if "step" in scalar_metrics and step is None: + step = scalar_metrics.pop("step") + + elif step is None: + # added metrics by Lightning for convenience + scalar_metrics['epoch'] = self.trainer.current_epoch + step = self.trainer.global_step + + # log actual metrics + if self.trainer.logger is not None: + if self.trainer.is_global_zero: + self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step) + self.trainer.logger.save() + + self._logged_metrics.update(scalar_metrics) + + """ + Evaluation metric updates + """ + + def prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None: + if self.trainer.sanity_checking: + return + + num_dataloaders = self.trainer.evaluation_loop.num_dataloaders + has_been_initialized = len(self.eval_loop_results) == num_dataloaders + for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders): + # remove callback metrics that don't belong to this dataloader + callback_metrics = { + k: v + for k, v in metrics.items() if "dataloader_idx" not in k or f"dataloader_idx_{dl_idx}" in k + } + if has_been_initialized: + self.eval_loop_results[dl_idx].update(callback_metrics) + else: + self.eval_loop_results.append(callback_metrics) + + def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: + assert self._epoch_end_reached + metrics = self.metrics + + if not self.trainer.sanity_checking: + # log all the metrics as a single dict + log_metrics = metrics[MetricSource.LOG] + if log_metrics: + self.log_metrics(log_metrics) + + self.prepare_eval_loop_results(metrics[MetricSource.CALLBACK]) + + # log results of evaluation + if ( + self.trainer.state.fn != TrainerFn.FITTING and self.trainer.evaluating and self.trainer.is_global_zero + and self.trainer.verbose_evaluate + ): + print('-' * 80) + for result_idx, results in enumerate(self.eval_loop_results): + print(f'DATALOADER:{result_idx} {self.trainer.state.stage.upper()} RESULTS') + pprint({ + k: (v.item() if v.numel() == 1 else v.tolist()) if isinstance(v, torch.Tensor) else v + for k, v in results.items() + }) + print('-' * 80) + + results = self.eval_loop_results + # clear mem + self.eval_loop_results = [] + return results + + @property + def evaluation_log_step(self) -> Optional[int]: + if self.trainer.state.stage is RunningStage.VALIDATING: + return self._val_log_step + elif self.trainer.state.stage is RunningStage.TESTING: + return self._test_log_step + else: + return None + + def 
increment_evaluation_log_step(self) -> None: + if self.trainer.state.stage is RunningStage.VALIDATING: + self._val_log_step += 1 + elif self.trainer.state.stage is RunningStage.TESTING: + self._test_log_step += 1 + + def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: + model = self.trainer.lightning_module + # set dataloader_idx only if multiple ones + model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None + + # track batch_size + self.trainer.result_collection.extract_batch_size(batch) + self._batch_idx = batch_idx + + def update_evaluation_step_metrics(self) -> None: + if self.trainer.sanity_checking: + return + + # logs user requested information to logger + assert not self._epoch_end_reached + metrics = self.metrics[MetricSource.LOG] + if metrics: + self.log_metrics(metrics, step=self.evaluation_log_step) + + # increment the step even if nothing was logged + self.increment_evaluation_log_step() + + """ + Train metric updates + """ + + def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None: + self.trainer.result_collection.extract_batch_size(split_batch) + self._batch_idx = batch_idx + self._split_idx = split_idx + + def update_train_step_metrics(self) -> None: + if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization: + return + + # when metrics should be logged + assert not self._epoch_end_reached + metrics = self.metrics[MetricSource.LOG] + if self.should_update_logs or self.trainer.fast_dev_run is True and metrics: + self.log_metrics(metrics) + + def update_train_epoch_metrics(self) -> None: + # add the metrics to the loggers + assert self._epoch_end_reached + metrics = self.metrics[MetricSource.LOG] + if metrics: + self.log_metrics(metrics) + + # reset result collection for next epoch + self.trainer.result_collection.reset(metrics=True) + + def teardown(self): + self.trainer.train_loop.train_results.cpu() + self.trainer.evaluation_loop.validation_results.cpu() + self.trainer.evaluation_loop.test_results.cpu() + + """ + Utilities and properties + """ + + def on_epoch_start(self) -> None: + self._epoch_end_reached = False + + def on_batch_start(self) -> None: + self._epoch_end_reached = False + + def epoch_end_reached(self): + self.trainer.logger_connector._epoch_end_reached = True + self.trainer.logger_connector._batch_idx = None + self.trainer.logger_connector._split_idx = None + + def on_epoch_end(self) -> None: + assert self._epoch_end_reached + metrics = self.metrics + self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) + self._callback_metrics.update(metrics[MetricSource.CALLBACK]) + self._logged_metrics.update(metrics[MetricSource.LOG]) + self._current_fx = None + + def on_batch_end(self) -> None: + assert not self._epoch_end_reached + metrics = self.metrics + self._progress_bar_metrics.update(metrics[MetricSource.PBAR]) + self._callback_metrics.update(metrics[MetricSource.CALLBACK]) + self._logged_metrics.update(metrics[MetricSource.LOG]) + + def should_reset_tensors(self, fx: str) -> bool: + is_different_fx = self._current_fx != fx + if self._split_idx is None: + is_first_batch = self._batch_idx in (None, 0) + else: + is_first_batch = self._batch_idx + self._split_idx == 0 + return is_different_fx and is_first_batch + + def reset(self, metrics: Optional[bool] = None) -> None: + self.trainer.result_collection.reset(metrics=metrics) + self._batch_idx = None + self._split_idx = None + 
self._current_fx = None + + @property + def metrics(self) -> Dict[MetricSource, Dict[str, _METRIC]]: + """This function returns either batch or epoch metrics depending on ``_epoch_end_reached``.""" + on_step = not self._epoch_end_reached + return self.trainer.result_collection.metrics(on_step) + + @property + def callback_metrics(self) -> Dict[str, _METRIC]: + if self.trainer.result_collection: + metrics = self.metrics[MetricSource.CALLBACK] + self._callback_metrics.update(metrics) + return self._callback_metrics + + @property + def logged_metrics(self) -> Dict[str, _METRIC]: + if self.trainer.result_collection: + metrics = self.metrics[MetricSource.LOG] + self._logged_metrics.update(metrics) + return self._logged_metrics + + @property + def progress_bar_metrics(self) -> Dict[str, float]: + if self.trainer.result_collection: + metrics = self.metrics[MetricSource.PBAR] + self._progress_bar_metrics.update(metrics) + return self._progress_bar_metrics diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result_new.py b/pytorch_lightning/trainer/connectors/logger_connector/result_new.py new file mode 100644 index 0000000000000..40481dd9afb68 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/logger_connector/result_new.py @@ -0,0 +1,503 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
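+#
+# This module implements the new result storage: every call to
+# ``LightningModule.log`` is recorded by a ``ResultMetric`` (keyed as
+# ``f"{fx}.{name}"`` inside a ``ResultCollection``), which handles batch
+# accumulation, reduction, and optional distributed synchronization.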
+from collections.abc import Generator +from dataclasses import dataclass, field +from functools import partial, wraps +from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union + +import torch +from torchmetrics import Metric + +from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator +from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections +from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin +from pytorch_lightning.utilities.enums import LightningEnum +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +# re-define the ones from pytorch_lightning.utilities.types without the `Number` type +_METRIC = Union[Metric, torch.Tensor] +_METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]] + + +class MetricSource(LightningEnum): + CALLBACK = "callback" + PBAR = "pbar" + LOG = "log" + + +@dataclass +class _Sync: + fn: Callable + should: bool = False + op: Union[Any, str] = 'mean' + group: Optional[Any] = None + + @property + def __call__(self) -> Any: + return partial(self.fn, reduce_op=self.op, group=self.group) if self.should else self.no_op + + @staticmethod + def no_op(value: Any, *_, **__) -> Any: + return value + + +@dataclass +class _Metadata: + fx: str + name: str + prog_bar: bool = False + logger: bool = True + on_step: bool = False + on_epoch: bool = True + reduce_fx: Callable = torch.mean + enable_graph: bool = False + dataloader_idx: Optional[int] = None + metric_attribute: Optional[str] = None + sync: _Sync = field(default_factory=_Sync) + + @property + def forked(self) -> bool: + return self.on_step and self.on_epoch + + def forked_name(self, on_step: bool) -> str: + if self.forked: + return f'{self.name}_{"step" if on_step else "epoch"}' + return self.name + + @property + def is_mean_reduction(self) -> bool: + return self.reduce_fx == torch.mean + + @property + def is_max_reduction(self) -> bool: + return self.reduce_fx in (torch.max, max) + + @property + def is_min_reduction(self) -> bool: + return self.reduce_fx in (torch.min, min) + + @property + def is_custom_reduction(self) -> bool: + return not (self.is_mean_reduction or self.is_max_reduction or self.is_min_reduction) + + +class ResultMetric(Metric, DeviceDtypeModuleMixin): + """Wraps the value provided to `:meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" + + def __init__(self, metadata: _Metadata, is_tensor: bool) -> None: + super().__init__() + self.is_tensor = is_tensor + self.meta = metadata + self.has_reset = False + if is_tensor: + self.add_state("value", torch.tensor(0, dtype=torch.float)) + if self.meta.is_mean_reduction: + self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float)) + + def update(self, value: _METRIC, batch_size: torch.Tensor) -> None: + if self.is_tensor: + value = value.float() + self._forward_cache = value + # performance: no need to accumulate on values only logged on_step + if self.meta.on_step and not self.meta.on_epoch: + self.value = self.meta.sync(value) + return + # perform accumulation with reduction + if self.meta.is_mean_reduction: + self.value += value.mean() * batch_size + self.cumulated_batch_size += batch_size + elif self.meta.is_max_reduction or self.meta.is_min_reduction: + self.value = self.meta.reduce_fx(self.value, value.mean()) + else: + self.value = value # noqa: attribute-defined-outside-init + self._forward_cache = value._forward_cache 
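+            # `value` is a `torchmetrics.Metric` here: the metric object keeps
+            # its own accumulated state, so only a reference is stored, plus
+            # the metric's per-batch `_forward_cache` for `on_step` logging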
+ + def compute(self) -> torch.Tensor: + if self.is_tensor: + value = self.meta.sync(self.value) + if self.meta.is_mean_reduction: + cumulated_batch_size = self.meta.sync(self.cumulated_batch_size) + return value / cumulated_batch_size + elif self.meta.is_max_reduction or self.meta.is_min_reduction: + return value + raise MisconfigurationException( + f"Only [min, max, mean] reductions are supported. Found {self.meta.reduce_fx}" + ) + return self.value.compute() + + def reset(self) -> None: + if self.is_tensor: + super().reset() + else: + self.value.reset() + self.has_reset = True + + def forward(self, value: _METRIC, batch_size: torch.Tensor) -> None: + if self.meta.enable_graph: + with torch.no_grad(): + self.update(value, batch_size) + else: + # performance: skip the `torch.no_grad` context manager by calling `update` directly + self.update(value, batch_size) + + def _wrap_compute(self, compute: Any) -> Any: + # Override to avoid syncing - we handle it ourselves. + @wraps(compute) + def wrapped_func(*args, **kwargs): + if not self._update_called: + rank_zero_warn( + f"The ``compute`` method of metric {self.__class__.__name__}" + " was called before the ``update`` method which may lead to errors," + " as metric states have not yet been updated.", UserWarning + ) + + # return cached value + if self._computed is not None: + return self._computed + self._computed = compute(*args, **kwargs) + return self._computed + + return wrapped_func + + def __setattr__(self, key: str, value: Any) -> None: + # performance: skip the `torch.nn.Module.__setattr__` checks + object.__setattr__(self, key, value) + + def __repr__(self) -> str: + state = f"value={self.value}" + if self.is_tensor and self.meta.is_mean_reduction: + state += f", cumulated_batch_size={self.cumulated_batch_size}" + return f"{self.__class__.__name__}({state})" + + +class ResultMetricCollection(dict): + """ + Dict wrapper for easy access to metadata. + + All of the leaf items should be instances of + :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` + with the same metadata. + """ + + def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None: + super().__init__(*args) + self.meta = metadata + + +class ResultCollection(dict): + """ + Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` or + :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetricCollection` + + Example: + + # `device` needs to be provided before logging + result = ResultCollection(True, torch.device("cpu")) + + # you can log to a specific collection. 
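+        # `fx` is the name of the hook doing the logging; entries are stored
+        # under the key ``f"{fx}.{name}"``, so the same metric name can be
+        # logged from different hooks without clashing.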
+ # arguments: fx, key, value, metadata + result.log('training_step', 'acc', torch.tensor(...), on_step=True, on_epoch=True) + result.log('validation_step', 'recall', torch.tensor(...), on_step=True, on_epoch=True) + """ + + DATALOADER_SUFFIX = "/dataloader_idx_{}" + + def __init__(self, training: bool, device: Optional[torch.device] = None) -> None: + super().__init__() + self.training = training + self._minimize = None + self._batch_size = torch.tensor(1, device=device) + self.device: Optional[Union[str, torch.device]] = device + self.fx_validator = FxValidator() + + @property + def batch_size(self) -> torch.Tensor: + # performance: cache the `batch_size` tensor instead of re-creating it + return self._batch_size + + @batch_size.setter + def batch_size(self, value: int) -> None: + self._batch_size = torch.tensor(value, device=self.device) + + @property + def minimize(self) -> Optional[torch.Tensor]: + """ + The :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` loss + will be saved as the ``minimize`` attribute. + """ + return self._minimize + + @minimize.setter + def minimize(self, loss: Optional[torch.Tensor]) -> None: + if loss is not None: + if not isinstance(loss, torch.Tensor): + raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}") + if loss.grad_fn is None: + raise RuntimeError("`Result.minimize` must have a `grad_fn`") + self._minimize = loss + + @property + def extra(self) -> Dict[str, Any]: + """ + Extras are any keys other than the loss returned by + :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` + """ + return self.get('_extra', {}) + + @extra.setter + def extra(self, extra: Mapping[str, Any]) -> None: + + def check_fn(v): + if v.grad_fn is not None: + raise MisconfigurationException(f'You returned a tensor with `grad_fn`. The extra values are {extra}') + + apply_to_collection(extra, torch.Tensor, check_fn) + self['_extra'] = extra + + def log( + self, + fx: str, + name: str, + value: _METRIC_COLLECTION, + prog_bar: bool = False, + logger: bool = True, + on_step: bool = False, + on_epoch: bool = True, + reduce_fx: Callable = torch.mean, + enable_graph: bool = False, + sync_dist: bool = False, + sync_dist_fn: Callable = _Sync.no_op, + sync_dist_op: Union[Any, str] = 'mean', + sync_dist_group: Optional[Any] = None, + dataloader_idx: Optional[int] = None, + batch_size: Optional[int] = None, + metric_attribute: Optional[str] = None, + ) -> None: + """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" + # no metrics should be logged with graphs + if not enable_graph and isinstance(value, torch.Tensor): + value = value.detach() + + # move metrics to cpu on TPU. + if isinstance(value, torch.Tensor) and value.device.type == "xla": + value = value.cpu() + + # storage key + key = f"{fx}.{name}" + # add dataloader_suffix to both key and fx + if dataloader_idx is not None: + key += f'.{dataloader_idx}' + fx += f'.{dataloader_idx}' + + meta = _Metadata( + fx=fx, + name=name, + prog_bar=prog_bar, + logger=logger, + on_step=on_step, + on_epoch=on_epoch, + reduce_fx=reduce_fx, + enable_graph=enable_graph, + dataloader_idx=dataloader_idx, + metric_attribute=metric_attribute, + sync=_Sync( + should=sync_dist, + fn=sync_dist_fn, + op=sync_dist_op, + group=sync_dist_group, + ) + ) + if key not in self: + if meta.is_custom_reduction: + raise MisconfigurationException( + 'Only `self.log(..., reduce_fx={min,max,mean})` are currently supported.' 
+ ' Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`' + ) + self.register_key(key, meta, value) + elif meta != self[key].meta: + raise MisconfigurationException( + f'You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed' + ) + + if batch_size is not None: + self.batch_size = batch_size + + self.update_metrics(key, value) + + def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None: + """Create one ResultMetric object per value. Value can be provided as a nested collection""" + + def fn(v: _METRIC) -> ResultMetric: + metric = ResultMetric(meta, isinstance(v, torch.Tensor)) + return metric.to(self.device) + + value = apply_to_collection(value, (torch.Tensor, Metric), fn) + if isinstance(value, dict): + value = ResultMetricCollection(value, metadata=meta) + self[key] = value + + def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None: + + def fn(result_metric, v): + # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl` + result_metric.forward(v.to(self.device), self.batch_size) + result_metric.has_reset = False + + apply_to_collections(self[key], value, ResultMetric, fn) + + @staticmethod + def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Tensor]: + cache = None + if on_step and result_metric.meta.on_step: + cache = result_metric._forward_cache + elif not on_step and result_metric.meta.on_epoch: + if not result_metric._computed: + result_metric.compute() + cache = result_metric._computed + if cache is not None and not result_metric.meta.enable_graph: + return cache.detach() + return cache + + @staticmethod + def __to_item(t: torch.Tensor) -> float: + return t.item() + + def valid_items(self) -> Generator: + """This function is used to iterate over current valid metrics.""" + return ((k, v) for k, v in self.items() + if not k == "_extra" and not (isinstance(v, ResultMetric) and v.has_reset)) + + def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]: + name = result_metric.meta.name + forked_name = result_metric.meta.forked_name(on_step) + dl_idx = result_metric.meta.dataloader_idx + if dl_idx is not None: + dataloader_suffix = self.DATALOADER_SUFFIX.format(dl_idx) + name += dataloader_suffix + forked_name += dataloader_suffix + return name, forked_name + + def metrics(self, on_step: bool) -> Dict[MetricSource, Dict[str, _METRIC]]: + metrics = {k: {} for k in MetricSource} + + for key, result_metric in self.valid_items(): + + # extract forward_cache or computed from the ResultMetric. ignore when the output is None + value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False) + + # check if the collection is empty + has_tensor = False + + def any_tensor(_): + nonlocal has_tensor + has_tensor = True + + apply_to_collection(value, torch.Tensor, any_tensor) + if not has_tensor: + continue + + name, forked_name = self._forked_name(result_metric, on_step) + + # populate logging metrics + if result_metric.meta.logger: + metrics[MetricSource.LOG][forked_name] = value + + # populate callback metrics. callback metrics don't take `_step` forked metrics + if self.training or result_metric.meta.on_epoch and not on_step: + metrics[MetricSource.CALLBACK][name] = value + metrics[MetricSource.CALLBACK][forked_name] = value + + # populate progress_bar metrics. 
convert tensors to numbers + if result_metric.meta.prog_bar: + value = apply_to_collection(value, torch.Tensor, self.__to_item, include_none=False) + metrics[MetricSource.PBAR][forked_name] = value + + return metrics + + def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> None: + """ + Reset the result collection + + Args: + metrics: If True, only ``torchmetrics.Metric`` results are reset, + if False, only ``torch.Tensors`` are reset, + if ``None``, both are. + fx: Function to reset + """ + + def fn(item: ResultMetric) -> None: + requested_type = metrics is None or metrics ^ item.is_tensor + same_fx = fx is None or fx == item.meta.fx + if requested_type and same_fx: + item.reset() + + apply_to_collection(self, ResultMetric, fn) + + def extract_batch_size(self, batch: Any) -> None: + try: + self.batch_size = self._extract_batch_size(batch) + except RecursionError: + self.batch_size = 1 + + def _extract_batch_size(self, batch: Any) -> int: + """ + Recursively unpack a batch to find a torch.Tensor. + + Returns: + ``len(tensor)`` when found, or ``1`` when it hits an empty or non iterable. + """ + if isinstance(batch, torch.Tensor): + size = batch.size(0) + elif isinstance(batch, str): + return len(batch) + elif isinstance(batch, dict): + sample = next(iter(batch.values()), 1) + size = self._extract_batch_size(sample) + elif isinstance(batch, Iterable): + sample = next(iter(batch), 1) + size = self._extract_batch_size(sample) + else: + size = 1 + return size + + def to(self, *args, **kwargs) -> 'ResultCollection': + """Move all data to the given device.""" + + def to_(item: Union[torch.Tensor, Metric], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Metric]: + return item.to(*args, **kwargs) + + apply_to_collection(self, (torch.Tensor, Metric), to_, *args, **kwargs) + + if self.minimize is not None: + self.minimize = self.minimize.to(*args, **kwargs) + self._batch_size = self._batch_size.to(*args, **kwargs) + if 'device' in kwargs: + self.device = kwargs['device'] + return self + + def cpu(self) -> 'ResultCollection': + """Move all data to CPU.""" + return self.to(device="cpu") + + def __str__(self) -> str: + return f'{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})' + + def __getstate__(self) -> dict: + d = self.__dict__.copy() + # can't deepcopy tensors with grad_fn + minimize = d.get('_minimize') + if minimize is not None: + d['_minimize'] = minimize.detach() + return d From ee78a9009e06b7f99cb246cf70d35151d9151a7e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 13:28:57 +0200 Subject: [PATCH 02/24] Update CHANGELOG --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b1f4e0693faa6..53cc451e9b02b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -104,7 +104,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Refactored logging * Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736)) - + * Dramatically simplify the `LoggerConnector` ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882)) + * `trainer.{logged,progress_bar,callback}_metrics` are now updated on-demand ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882)) + * Completely overhaul the `Result` object in favor of `ResultMetric` ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882)) + * Improve epoch-level reduction time and overall memory usage ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882)) - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/)) From 917d6448fc89d5b4861115708081d39f50f11e69 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 13:29:58 +0200 Subject: [PATCH 03/24] Update requirements --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 4ad6d15f158df..7d19ee8f6b558 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ tqdm>=4.41.0 PyYAML>=5.1,<=5.4.1 fsspec[http]>=2021.05.0, !=2021.06.0 tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file into descriptor pool!' -torchmetrics>=0.2.0 +torchmetrics>=0.3.2 pyDeprecate==0.3.0 packaging typing-extensions # TypedDict support for python<3.8 From 9d97d25fe4811f1dfe228a3555ca83d8812b8f51 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 13:37:13 +0200 Subject: [PATCH 04/24] Fix import path --- .../trainer/connectors/logger_connector/logger_connector_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py index 0b7327183d77f..9e8017e2be11b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py @@ -20,7 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.core import memory from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger -from pytorch_lightning.trainer.connectors.logger_connector.result import _METRIC, MetricSource +from pytorch_lightning.trainer.connectors.logger_connector.result_new import _METRIC, MetricSource from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import DeviceType from pytorch_lightning.utilities.metrics import metrics_to_scalars From cd72867ce48846261716a4ded71c04ff230eb836 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 13:38:24 +0200 Subject: [PATCH 05/24] Add new suffix --- .../connectors/logger_connector/logger_connector_new.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py index 9e8017e2be11b..e3ba3812a1ccf 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py @@ -27,7 +27,8 @@ from 
pytorch_lightning.utilities.types import _EVALUATE_OUTPUT -class LoggerConnector: +# TODO(@carmocca): Remove `New` suffix +class LoggerConnectorNew: def __init__(self, trainer: 'pl.Trainer', log_gpu_memory: Optional[str] = None) -> None: self.trainer = trainer From 51a31872caa98f1909f0cc261cfc81b9cf522428 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 13:58:00 +0200 Subject: [PATCH 06/24] Tests --- tests/core/test_metric_result_integration.py | 120 ++++++++++++------ tests/core/test_results.py | 7 +- tests/models/test_tpu.py | 12 +- .../trainer/logging_/test_logger_connector.py | 98 +++++++++++++- 4 files changed, 183 insertions(+), 54 deletions(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index fd08890604807..9aa1c4db21d73 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -11,14 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import torch import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric import tests.helpers.utils as tutils -from pytorch_lightning.trainer.connectors.logger_connector.result import Result +from pytorch_lightning.trainer.connectors.logger_connector.result_new import MetricSource, ResultCollection from tests.helpers.runif import RunIf @@ -52,12 +51,14 @@ def _ddp_test_fn(rank, worldsize): metric_b = DummyMetric() metric_c = DummyMetric() - # dist_sync_on_step is False by default - result = Result() + metric_a = metric_a.to(f"cuda:{rank}") + metric_b = metric_b.to(f"cuda:{rank}") + metric_c = metric_c.to(f"cuda:{rank}") - for epoch in range(3): - cumulative_sum = 0 + result = ResultCollection(True, torch.device(f"cuda:{rank}")) + for _ in range(3): + cumulative_sum = 0 for i in range(5): metric_a(i) metric_b(i) @@ -65,32 +66,25 @@ def _ddp_test_fn(rank, worldsize): cumulative_sum += i - result.log('a', metric_a, on_step=True, on_epoch=True) - result.log('b', metric_b, on_step=False, on_epoch=True) - result.log('c', metric_c, on_step=True, on_epoch=False) + result.log('h', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a") + result.log('h', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b") + result.log('h', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c") - batch_log = result.get_batch_log_metrics() - batch_expected = {"a_step": i, "a": i, "c": i} - assert set(batch_log.keys()) == set(batch_expected.keys()) - for k in batch_expected.keys(): - assert batch_expected[k] == batch_log[k] + batch_log = result.metrics(True)[MetricSource.LOG] + assert batch_log == {"a_step": i, "c": i} - epoch_log = result.get_epoch_log_metrics() + epoch_log = result.metrics(False)[MetricSource.LOG] result.reset() # assert metric state reset to default values - assert metric_a.x == metric_a._defaults['x'] + assert metric_a.x == metric_a._defaults['x'], (metric_a.x, metric_a._defaults['x']) assert metric_b.x == metric_b._defaults['x'] assert metric_c.x == metric_c._defaults['x'] - epoch_expected = {"b": cumulative_sum * worldsize, "a_epoch": cumulative_sum * worldsize} + assert epoch_log == {"b": cumulative_sum * worldsize, "a_epoch": cumulative_sum * worldsize} - assert set(epoch_log.keys()) == set(epoch_expected.keys()) - for k in epoch_expected.keys(): - assert epoch_expected[k] == epoch_log[k] - 
-@RunIf(skip_windows=True) +@RunIf(skip_windows=True, min_gpus=2) def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.set_random_master_port() @@ -104,11 +98,10 @@ def test_result_metric_integration(): metric_b = DummyMetric() metric_c = DummyMetric() - result = Result() + result = ResultCollection(True, torch.device("cpu")) - for epoch in range(3): + for _ in range(3): cumulative_sum = 0 - for i in range(5): metric_a(i) metric_b(i) @@ -116,17 +109,14 @@ def test_result_metric_integration(): cumulative_sum += i - result.log('a', metric_a, on_step=True, on_epoch=True) - result.log('b', metric_b, on_step=False, on_epoch=True) - result.log('c', metric_c, on_step=True, on_epoch=False) + result.log('h', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a") + result.log('h', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b") + result.log('h', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c") - batch_log = result.get_batch_log_metrics() - batch_expected = {"a_step": i, "a": i, "c": i} - assert set(batch_log.keys()) == set(batch_expected.keys()) - for k in batch_expected.keys(): - assert batch_expected[k] == batch_log[k] + batch_log = result.metrics(True)[MetricSource.LOG] + assert batch_log == {"a_step": i, "c": i} - epoch_log = result.get_epoch_log_metrics() + epoch_log = result.metrics(False)[MetricSource.LOG] result.reset() # assert metric state reset to default values @@ -134,8 +124,60 @@ def test_result_metric_integration(): assert metric_b.x == metric_b._defaults['x'] assert metric_c.x == metric_c._defaults['x'] - epoch_expected = {"b": cumulative_sum, "a_epoch": cumulative_sum} - - assert set(epoch_log.keys()) == set(epoch_expected.keys()) - for k in epoch_expected.keys(): - assert epoch_expected[k] == epoch_log[k] + assert epoch_log == {"b": cumulative_sum, "a_epoch": cumulative_sum} + + assert str(result) == ( + "ResultCollection(True, cpu, {" + "'h.a': ResultMetric(value=DummyMetric()), " + "'h.b': ResultMetric(value=DummyMetric()), " + "'h.c': ResultMetric(value=DummyMetric())" + "})" + ) + + +def test_result_collection_simple_loop(): + result = ResultCollection(True, torch.device("cpu")) + current_fx_name = None + batch_idx = None + + def lightning_log(fx, *args, **kwargs): + nonlocal current_fx_name + if current_fx_name != fx and batch_idx in (None, 0): + result.reset(metrics=False, fx=fx) + result.log(fx, *args, **kwargs) + current_fx_name = fx + + lightning_log('a0', 'a', torch.tensor(0.), on_step=True, on_epoch=True) + lightning_log('a1', 'a', torch.tensor(0.), on_step=True, on_epoch=True) + for epoch in range(2): + lightning_log('b0', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True) + lightning_log('b1', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True) + for batch_idx in range(2): + lightning_log('c0', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True) + lightning_log('c1', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True) + lightning_log('c2', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True) + batch_idx = None + lightning_log('d0', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True) + lightning_log('d1', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True) + + assert result['a0.a'].value == torch.tensor(0.) + assert result['a0.a'].cumulated_batch_size == torch.tensor(1.) + assert result['a1.a'].value == torch.tensor(0.) + assert result['a1.a'].cumulated_batch_size == torch.tensor(1.) 
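+    # each key is reset through `lightning_log` when its fx reappears on the
+    # first batch of an epoch, so the values below only reflect what has been
+    # accumulated since that reset: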
+ + assert result['b0.a'].value == torch.tensor(1.) + epoch + assert result['b0.a'].cumulated_batch_size == torch.tensor(1.) + assert result['b1.a'].value == torch.tensor(1.) + epoch + assert result['b1.a'].cumulated_batch_size == torch.tensor(1.) + + assert result['c0.a'].value == torch.tensor(4.) + epoch * 2 + assert result['c0.a'].cumulated_batch_size == torch.tensor(2.) + assert result['c1.a'].value == torch.tensor(4.) + epoch * 2 + assert result['c1.a'].cumulated_batch_size == torch.tensor(2.) + assert result['c2.a'].value == torch.tensor(4.) + epoch * 2 + assert result['c2.a'].cumulated_batch_size == torch.tensor(2.) + + assert result['d0.a'].value == torch.tensor(3.) + epoch + assert result['d0.a'].cumulated_batch_size == torch.tensor(1.) + assert result['d1.a'].value == torch.tensor(3.) + epoch + assert result['d1.a'].cumulated_batch_size == torch.tensor(1.) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 9fce99ffbecc9..59f01184ba33e 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -20,7 +20,9 @@ import torch.multiprocessing as mp import tests.helpers.utils as tutils -from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning import Trainer +from pytorch_lightning.trainer.connectors.logger_connector.result_new import _Sync +from pytorch_lightning.utilities.distributed import sync_ddp_if_available from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf @@ -37,7 +39,8 @@ def _setup_ddp(rank, worldsize): def _ddp_test_fn(rank, worldsize): _setup_ddp(rank, worldsize) tensor = torch.tensor([1.0]) - actual = LightningModule._LightningModule__sync(tensor, sync_dist=True, sync_dist_op=torch.distributed.ReduceOp.SUM) + sync = _Sync(sync_ddp_if_available, should=True, op=torch.distributed.ReduceOp.SUM) + actual = sync(tensor) assert actual.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors" diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 8e3bab7350f7f..2e7db175801b9 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -21,10 +21,11 @@ import tests.helpers.pipelines as tpipes import tests.helpers.utils as tutils -from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning import Trainer from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.plugins import TPUSpawnPlugin +from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -424,12 +425,9 @@ def test_tpu_sync_dist(): """Test tpu spawn sync dist operation """ def test_sync_dist(_): - value = LightningModule._LightningModule__sync( - torch.tensor([1.0]), - sync_fn=TPUSpawnPlugin().reduce, - sync_dist=True, - sync_dist_op=torch.distributed.ReduceOp.SUM - ) + sync = _Sync(TPUSpawnPlugin().reduce, should=True, op=torch.distributed.ReduceOp.SUM) + value = torch.tensor([1.0]) + value = sync(value), assert value.item() == 8 xmp.spawn(test_sync_dist, nprocs=8, start_method='fork') diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 2e6234bb98c76..52629161a9292 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -29,6 
+29,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder from pytorch_lightning.trainer.connectors.logger_connector.result import Result +from pytorch_lightning.trainer.connectors.logger_connector.result_new import MetricSource, ResultCollection from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -108,8 +109,8 @@ def training_step_end(self, *_): assert train_results.has_reduced is True generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, reduced=True)['train_loss_epoch'].item() - excepted = torch.stack(model.train_losses).mean().item() - assert generated == excepted + expected = torch.stack(model.train_losses).mean().item() + assert generated == expected def test__logger_connector__epoch_result_store__train__tbptt(tmpdir): @@ -453,7 +454,7 @@ def test_metrics_holder(to_float, tmpdir): def is_float(value: Any) -> bool: return isinstance(value, float) - excepted_function = is_float if to_float else torch.is_tensor + expected_function = is_float if to_float else torch.is_tensor targets = torch.tensor([1], device=device) acc = Accuracy().to(device) metric_holder = MetricsHolder(to_float=to_float) @@ -464,9 +465,9 @@ def is_float(value: Any) -> bool: }) metric_holder.convert(device) metrics = metric_holder.metrics - assert excepted_function(metrics["x"]) - assert excepted_function(metrics["y"]) - assert excepted_function(metrics["z"]) + assert expected_function(metrics["x"]) + assert expected_function(metrics["y"]) + assert expected_function(metrics["z"]) def test_metric_holder_raises(tmpdir): @@ -686,3 +687,88 @@ def _assert_called(model, stage): trainer.test(model) _assert_called(model, 'test') + + +def test_result_collection_on_tensor_with_mean_reduction(): + result_collection = ResultCollection(True, torch.device("cpu")) + + for i in range(1, 10): + value = torch.tensor(i, dtype=torch.float) + for prob_bar in [False, True]: + for logger in [False, True]: + for j, (on_step, on_epoch) in enumerate([(True, True), (False, True), (True, False), (False, False)], + 1): + result_collection.log( + "training_step", + f"loss_{j}_{int(prob_bar)}_{int(logger)}", + value, + on_step=on_step, + on_epoch=on_epoch, + batch_size=i**2, + prog_bar=prob_bar, + logger=logger + ) + + expected_values = [1, 2, 3, 4, 5, 6, 7, 8, 9] + expected_batches = [1, 4, 9, 16, 25, 36, 49, 64, 81] + total_value = sum(torch.tensor(expected_values) * torch.tensor(expected_batches)) + total_batches = sum(expected_batches) + assert result_collection["training_step.loss_1_0_0"].value == total_value + assert result_collection["training_step.loss_1_0_0"].cumulated_batch_size == total_batches + + batch_metrics = result_collection.metrics(True) + assert batch_metrics[MetricSource.PBAR] == { + 'loss_1_1_0_step': 9, + 'loss_3_1_0': 9, + 'loss_1_1_1_step': 9, + 'loss_3_1_1': 9 + } + assert batch_metrics[MetricSource.LOG] == { + 'loss_1_0_1_step': 9, + 'loss_3_0_1': 9, + 'loss_1_1_1_step': 9, + 'loss_3_1_1': 9, + } + assert batch_metrics[MetricSource.CALLBACK] == { + 'loss_1_0_0': 9, + 'loss_1_0_0_step': 9, + 'loss_3_0_0': 9, + 'loss_1_0_1': 9, + 'loss_1_0_1_step': 9, + 'loss_3_0_1': 9, + 'loss_1_1_0': 9, + 'loss_1_1_0_step': 9, + 'loss_3_1_0': 9, + 'loss_1_1_1': 9, + 'loss_1_1_1_step': 9, + 'loss_3_1_1': 9, + } + + epoch_metrics = 
result_collection.metrics(False) + mean = total_value / total_batches + assert epoch_metrics[MetricSource.PBAR] == { + 'loss_1_1_0_epoch': mean, + 'loss_2_1_0': mean, + 'loss_1_1_1_epoch': mean, + 'loss_2_1_1': mean + } + assert epoch_metrics[MetricSource.LOG] == { + 'loss_1_0_1_epoch': mean, + 'loss_2_0_1': mean, + 'loss_1_1_1_epoch': mean, + 'loss_2_1_1': mean + } + assert epoch_metrics[MetricSource.CALLBACK] == { + 'loss_1_0_0': mean, + 'loss_1_0_0_epoch': mean, + 'loss_2_0_0': mean, + 'loss_1_0_1': mean, + 'loss_1_0_1_epoch': mean, + 'loss_2_0_1': mean, + 'loss_1_1_0': mean, + 'loss_1_1_0_epoch': mean, + 'loss_2_1_0': mean, + 'loss_1_1_1': mean, + 'loss_1_1_1_epoch': mean, + 'loss_2_1_1': mean, + } From 792d0739a0ff5327b26f5ecf8c339c270d293acb Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 14:06:31 +0200 Subject: [PATCH 07/24] Minor changes --- tests/trainer/logging_/test_logger_connector.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 52629161a9292..6a0618c0ad1bc 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -691,21 +691,21 @@ def _assert_called(model, stage): def test_result_collection_on_tensor_with_mean_reduction(): result_collection = ResultCollection(True, torch.device("cpu")) + product = [(True, True), (False, True), (True, False), (False, False)] for i in range(1, 10): value = torch.tensor(i, dtype=torch.float) - for prob_bar in [False, True]: + for prog_bar in [False, True]: for logger in [False, True]: - for j, (on_step, on_epoch) in enumerate([(True, True), (False, True), (True, False), (False, False)], - 1): + for j, (on_step, on_epoch) in enumerate(product, 1): result_collection.log( "training_step", - f"loss_{j}_{int(prob_bar)}_{int(logger)}", + f"loss_{j}_{int(prog_bar)}_{int(logger)}", value, on_step=on_step, on_epoch=on_epoch, batch_size=i**2, - prog_bar=prob_bar, + prog_bar=prog_bar, logger=logger ) From cef98a711ffd0c3430e5d15338b81b7033b855f5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 14:15:30 +0200 Subject: [PATCH 08/24] Rename and reorder --- .../logger_connector/logger_connector_new.py | 80 +++++++++---------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py index e3ba3812a1ccf..5e63918c931ab 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py @@ -117,7 +117,44 @@ def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) - Evaluation metric updates """ - def prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None: + @property + def _eval_log_step(self) -> Optional[int]: + if self.trainer.state.stage is RunningStage.VALIDATING: + return self._val_log_step + elif self.trainer.state.stage is RunningStage.TESTING: + return self._test_log_step + else: + return None + + def _increment_eval_log_step(self) -> None: + if self.trainer.state.stage is RunningStage.VALIDATING: + self._val_log_step += 1 + elif self.trainer.state.stage is RunningStage.TESTING: + self._test_log_step += 1 + + def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> 
None: + model = self.trainer.lightning_module + # set dataloader_idx only if multiple ones + model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None + + # track batch_size + self.trainer.result_collection.extract_batch_size(batch) + self._batch_idx = batch_idx + + def update_eval_step_metrics(self) -> None: + if self.trainer.sanity_checking: + return + + # logs user requested information to logger + assert not self._epoch_end_reached + metrics = self.metrics[MetricSource.LOG] + if metrics: + self.log_metrics(metrics, step=self._eval_log_step) + + # increment the step even if nothing was logged + self._increment_eval_log_step() + + def _prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None: if self.trainer.sanity_checking: return @@ -134,7 +171,7 @@ def prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None: else: self.eval_loop_results.append(callback_metrics) - def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: + def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT: assert self._epoch_end_reached metrics = self.metrics @@ -144,7 +181,7 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: if log_metrics: self.log_metrics(log_metrics) - self.prepare_eval_loop_results(metrics[MetricSource.CALLBACK]) + self._prepare_eval_loop_results(metrics[MetricSource.CALLBACK]) # log results of evaluation if ( @@ -165,43 +202,6 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: self.eval_loop_results = [] return results - @property - def evaluation_log_step(self) -> Optional[int]: - if self.trainer.state.stage is RunningStage.VALIDATING: - return self._val_log_step - elif self.trainer.state.stage is RunningStage.TESTING: - return self._test_log_step - else: - return None - - def increment_evaluation_log_step(self) -> None: - if self.trainer.state.stage is RunningStage.VALIDATING: - self._val_log_step += 1 - elif self.trainer.state.stage is RunningStage.TESTING: - self._test_log_step += 1 - - def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: - model = self.trainer.lightning_module - # set dataloader_idx only if multiple ones - model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None - - # track batch_size - self.trainer.result_collection.extract_batch_size(batch) - self._batch_idx = batch_idx - - def update_evaluation_step_metrics(self) -> None: - if self.trainer.sanity_checking: - return - - # logs user requested information to logger - assert not self._epoch_end_reached - metrics = self.metrics[MetricSource.LOG] - if metrics: - self.log_metrics(metrics, step=self.evaluation_log_step) - - # increment the step even if nothing was logged - self.increment_evaluation_log_step() - """ Train metric updates """ From 860206d4361fa154c200a77ed6553504447fee25 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 14:17:56 +0200 Subject: [PATCH 09/24] Fix import --- tests/models/test_tpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 2e7db175801b9..b178126c4f81b 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -25,7 +25,7 @@ from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.plugins import TPUSpawnPlugin -from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync +from 
pytorch_lightning.trainer.connectors.logger_connector.result_new import _Sync from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException From 518d775c5533eb910029fcabcc9750fb6a281907 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 14:18:48 +0200 Subject: [PATCH 10/24] Formatting --- .../connectors/logger_connector/logger_connector_new.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py index 5e63918c931ab..31df06f6348e1 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py @@ -45,8 +45,11 @@ def __init__(self, trainer: 'pl.Trainer', log_gpu_memory: Optional[str] = None) self._split_idx: Optional[int] = None def on_trainer_init( - self, logger: LightningLoggerBase, flush_logs_every_n_steps: int, log_every_n_steps: int, - move_metrics_to_cpu: bool + self, + logger: LightningLoggerBase, + flush_logs_every_n_steps: int, + log_every_n_steps: int, + move_metrics_to_cpu: bool, ) -> None: self.configure_logger(logger) self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps From c5baa2cf288d02d7bc227eeefd7b5b6dd56923f5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 14:36:56 +0200 Subject: [PATCH 11/24] Fix with seed_everything? --- tests/trainer/logging_/test_logger_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 6a0618c0ad1bc..2505a48246bce 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -23,7 +23,7 @@ from torch.utils.data import DataLoader from torchmetrics import Accuracy, AveragePrecision -from pytorch_lightning import LightningModule +from pytorch_lightning import LightningModule, seed_everything from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator @@ -690,6 +690,7 @@ def _assert_called(model, stage): def test_result_collection_on_tensor_with_mean_reduction(): + seed_everything(42) result_collection = ResultCollection(True, torch.device("cpu")) product = [(True, True), (False, True), (True, False), (False, False)] From 51b8f356735ea0a35f08e554f0bc6c0f18483f64 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 15:14:51 +0200 Subject: [PATCH 12/24] Fix test? 
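Use self-describing key names in
`test_result_collection_on_tensor_with_mean_reduction`: each logged name
now spells out its `on_step` / `on_epoch` / `prog_bar` / `logger` flags,
and the epoch aggregates are checked against the batch-size weighted
mean (with `allclose` for the progress-bar values, which are converted
to plain floats).

As an illustrative sketch of the mean reduction under test (not part of
this commit, using only the public `ResultCollection` API added in this
branch):

    import torch

    from pytorch_lightning.trainer.connectors.logger_connector.result_new import (
        MetricSource,
        ResultCollection,
    )

    result = ResultCollection(True, torch.device("cpu"))
    # two values logged with the default mean reduction, different batch sizes
    result.log("training_step", "loss", torch.tensor(1.0), on_epoch=True, batch_size=1)
    result.log("training_step", "loss", torch.tensor(3.0), on_epoch=True, batch_size=3)
    # the epoch value is the batch-size weighted mean: (1*1 + 3*3) / (1 + 3)
    epoch_metrics = result.metrics(on_step=False)
    assert epoch_metrics[MetricSource.CALLBACK]["loss"] == 2.5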
--- .../trainer/logging_/test_logger_connector.py | 121 ++++++++++-------- 1 file changed, 65 insertions(+), 56 deletions(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 2505a48246bce..b9d8f498a2a7e 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -20,10 +20,11 @@ import pytest import torch +from numpy.core import allclose from torch.utils.data import DataLoader from torchmetrics import Accuracy, AveragePrecision -from pytorch_lightning import LightningModule, seed_everything +from pytorch_lightning import LightningModule from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator @@ -690,86 +691,94 @@ def _assert_called(model, stage): def test_result_collection_on_tensor_with_mean_reduction(): - seed_everything(42) result_collection = ResultCollection(True, torch.device("cpu")) product = [(True, True), (False, True), (True, False), (False, False)] + values = torch.arange(1, 10) + batches = values * values - for i in range(1, 10): - value = torch.tensor(i, dtype=torch.float) + for i, v in enumerate(values): for prog_bar in [False, True]: for logger in [False, True]: - for j, (on_step, on_epoch) in enumerate(product, 1): + for on_step, on_epoch in product: + name = "loss" + if on_step: + name += "_on_step" + if on_epoch: + name += "_on_epoch" + if prog_bar: + name += "_prog_bar" + if logger: + name += "_logger" result_collection.log( "training_step", - f"loss_{j}_{int(prog_bar)}_{int(logger)}", - value, + name, + v, on_step=on_step, on_epoch=on_epoch, - batch_size=i**2, + batch_size=batches[i], prog_bar=prog_bar, - logger=logger + logger=logger, ) - expected_values = [1, 2, 3, 4, 5, 6, 7, 8, 9] - expected_batches = [1, 4, 9, 16, 25, 36, 49, 64, 81] - total_value = sum(torch.tensor(expected_values) * torch.tensor(expected_batches)) - total_batches = sum(expected_batches) - assert result_collection["training_step.loss_1_0_0"].value == total_value - assert result_collection["training_step.loss_1_0_0"].cumulated_batch_size == total_batches + total_value = sum(values * batches) + total_batches = sum(batches) + assert result_collection["training_step.loss_on_step_on_epoch"].value == total_value + assert result_collection["training_step.loss_on_step_on_epoch"].cumulated_batch_size == total_batches batch_metrics = result_collection.metrics(True) + max_ = max(values) assert batch_metrics[MetricSource.PBAR] == { - 'loss_1_1_0_step': 9, - 'loss_3_1_0': 9, - 'loss_1_1_1_step': 9, - 'loss_3_1_1': 9 + 'loss_on_step_on_epoch_prog_bar_step': max_, + 'loss_on_step_on_epoch_prog_bar_logger_step': max_, + 'loss_on_step_prog_bar': max_, + 'loss_on_step_prog_bar_logger': max_, } assert batch_metrics[MetricSource.LOG] == { - 'loss_1_0_1_step': 9, - 'loss_3_0_1': 9, - 'loss_1_1_1_step': 9, - 'loss_3_1_1': 9, + 'loss_on_step_on_epoch_logger_step': max_, + 'loss_on_step_logger': max_, + 'loss_on_step_on_epoch_prog_bar_logger_step': max_, + 'loss_on_step_prog_bar_logger': max_, } assert batch_metrics[MetricSource.CALLBACK] == { - 'loss_1_0_0': 9, - 'loss_1_0_0_step': 9, - 'loss_3_0_0': 9, - 'loss_1_0_1': 9, - 'loss_1_0_1_step': 9, - 'loss_3_0_1': 9, - 'loss_1_1_0': 9, - 'loss_1_1_0_step': 9, - 'loss_3_1_0': 9, - 'loss_1_1_1': 9, - 'loss_1_1_1_step': 9, - 'loss_3_1_1': 9, + 'loss_on_step': max_, + 'loss_on_step_logger': max_, + 
'loss_on_step_on_epoch': max_, + 'loss_on_step_on_epoch_logger': max_, + 'loss_on_step_on_epoch_logger_step': max_, + 'loss_on_step_on_epoch_prog_bar': max_, + 'loss_on_step_on_epoch_prog_bar_logger': max_, + 'loss_on_step_on_epoch_prog_bar_logger_step': max_, + 'loss_on_step_on_epoch_prog_bar_step': max_, + 'loss_on_step_on_epoch_step': max_, + 'loss_on_step_prog_bar': max_, + 'loss_on_step_prog_bar_logger': max_, } epoch_metrics = result_collection.metrics(False) mean = total_value / total_batches - assert epoch_metrics[MetricSource.PBAR] == { - 'loss_1_1_0_epoch': mean, - 'loss_2_1_0': mean, - 'loss_1_1_1_epoch': mean, - 'loss_2_1_1': mean + pbar_metrics = epoch_metrics[MetricSource.PBAR] + assert set(pbar_metrics) == { + 'loss_on_epoch_prog_bar', 'loss_on_epoch_prog_bar_logger', 'loss_on_step_on_epoch_prog_bar_epoch', + 'loss_on_step_on_epoch_prog_bar_logger_epoch' } + assert all(allclose(m, mean) for m in pbar_metrics.values()) assert epoch_metrics[MetricSource.LOG] == { - 'loss_1_0_1_epoch': mean, - 'loss_2_0_1': mean, - 'loss_1_1_1_epoch': mean, - 'loss_2_1_1': mean + 'loss_on_epoch_logger': mean, + 'loss_on_epoch_prog_bar_logger': mean, + 'loss_on_step_on_epoch_logger_epoch': mean, + 'loss_on_step_on_epoch_prog_bar_logger_epoch': mean } assert epoch_metrics[MetricSource.CALLBACK] == { - 'loss_1_0_0': mean, - 'loss_1_0_0_epoch': mean, - 'loss_2_0_0': mean, - 'loss_1_0_1': mean, - 'loss_1_0_1_epoch': mean, - 'loss_2_0_1': mean, - 'loss_1_1_0': mean, - 'loss_1_1_0_epoch': mean, - 'loss_2_1_0': mean, - 'loss_1_1_1': mean, - 'loss_1_1_1_epoch': mean, - 'loss_2_1_1': mean, + 'loss_on_epoch': mean, + 'loss_on_epoch_logger': mean, + 'loss_on_epoch_prog_bar': mean, + 'loss_on_epoch_prog_bar_logger': mean, + 'loss_on_step_on_epoch': mean, + 'loss_on_step_on_epoch_epoch': mean, + 'loss_on_step_on_epoch_logger': mean, + 'loss_on_step_on_epoch_logger_epoch': mean, + 'loss_on_step_on_epoch_prog_bar': mean, + 'loss_on_step_on_epoch_prog_bar_epoch': mean, + 'loss_on_step_on_epoch_prog_bar_logger': mean, + 'loss_on_step_on_epoch_prog_bar_logger_epoch': mean } From 0475f42556d02fcf6ab60dc48d45e52f1a5f4680 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 15:15:02 +0200 Subject: [PATCH 13/24] Fix test? 
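Reflow the expected progress-bar keys to one per line and keep the tolerance
check: the progress-bar values are converted from float32 tensors to Python
floats, so exact equality with the float64 mean can fail. A small
self-contained illustration, using the same numbers as this test:

    import numpy as np
    import torch

    values = torch.arange(1, 10).float()
    batches = values * values
    mean = (values * batches).sum() / batches.sum()  # float32 tensor: 2025 / 285

    # .item() widens the float32 result to a Python float (float64), so exact
    # equality with the float64 quotient fails while a tolerance check passes
    assert mean.item() != 2025 / 285
    assert np.allclose(mean.item(), 2025 / 285)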
--- tests/trainer/logging_/test_logger_connector.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index b9d8f498a2a7e..f8f9218a2f189 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -758,8 +758,10 @@ def test_result_collection_on_tensor_with_mean_reduction(): mean = total_value / total_batches pbar_metrics = epoch_metrics[MetricSource.PBAR] assert set(pbar_metrics) == { - 'loss_on_epoch_prog_bar', 'loss_on_epoch_prog_bar_logger', 'loss_on_step_on_epoch_prog_bar_epoch', - 'loss_on_step_on_epoch_prog_bar_logger_epoch' + 'loss_on_epoch_prog_bar', + 'loss_on_epoch_prog_bar_logger', + 'loss_on_step_on_epoch_prog_bar_epoch', + 'loss_on_step_on_epoch_prog_bar_logger_epoch', } assert all(allclose(m, mean) for m in pbar_metrics.values()) assert epoch_metrics[MetricSource.LOG] == { From db61789af550fb0d94ce64b1be0a8bb233d7c837 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 15:16:25 +0200 Subject: [PATCH 14/24] Minor change --- tests/trainer/logging_/test_logger_connector.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index f8f9218a2f189..aea55bb82f847 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -18,9 +18,9 @@ from typing import Any, Callable from unittest import mock +import numpy as np import pytest import torch -from numpy.core import allclose from torch.utils.data import DataLoader from torchmetrics import Accuracy, AveragePrecision @@ -763,7 +763,8 @@ def test_result_collection_on_tensor_with_mean_reduction(): 'loss_on_step_on_epoch_prog_bar_epoch', 'loss_on_step_on_epoch_prog_bar_logger_epoch', } - assert all(allclose(m, mean) for m in pbar_metrics.values()) + # pbar metrics are converted to float, need to check with `allclose` + assert all(np.allclose(m, mean) for m in pbar_metrics.values()) assert epoch_metrics[MetricSource.LOG] == { 'loss_on_epoch_logger': mean, 'loss_on_epoch_prog_bar_logger': mean, From f0d81f92a0fd42e990d68412dac96b45e65e1be2 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 15:23:39 +0200 Subject: [PATCH 15/24] Minor changes --- tests/core/test_metric_result_integration.py | 39 +++++++++----------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 9aa1c4db21d73..85bc0c44f6d4c 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -160,24 +160,21 @@ def lightning_log(fx, *args, **kwargs): lightning_log('d0', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True) lightning_log('d1', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True) - assert result['a0.a'].value == torch.tensor(0.) - assert result['a0.a'].cumulated_batch_size == torch.tensor(1.) - assert result['a1.a'].value == torch.tensor(0.) - assert result['a1.a'].cumulated_batch_size == torch.tensor(1.) - - assert result['b0.a'].value == torch.tensor(1.) + epoch - assert result['b0.a'].cumulated_batch_size == torch.tensor(1.) - assert result['b1.a'].value == torch.tensor(1.) + epoch - assert result['b1.a'].cumulated_batch_size == torch.tensor(1.) - - assert result['c0.a'].value == torch.tensor(4.) 
+ epoch * 2 - assert result['c0.a'].cumulated_batch_size == torch.tensor(2.) - assert result['c1.a'].value == torch.tensor(4.) + epoch * 2 - assert result['c1.a'].cumulated_batch_size == torch.tensor(2.) - assert result['c2.a'].value == torch.tensor(4.) + epoch * 2 - assert result['c2.a'].cumulated_batch_size == torch.tensor(2.) - - assert result['d0.a'].value == torch.tensor(3.) + epoch - assert result['d0.a'].cumulated_batch_size == torch.tensor(1.) - assert result['d1.a'].value == torch.tensor(3.) + epoch - assert result['d1.a'].cumulated_batch_size == torch.tensor(1.) + for k in ( + 'a0.a', + 'a1.a', + ): + assert result[k].value == torch.tensor(0.), k + assert result[k].cumulated_batch_size == torch.tensor(1.), k + + for k in ('b0.a', 'b1.a'): + assert result[k].value == torch.tensor(1.) + epoch, k + assert result[k].cumulated_batch_size == torch.tensor(1.), k + + for k in ('c0.a', 'c1.a', 'c2.a'): + assert result[k].value == torch.tensor(4.) + epoch * 2, k + assert result[k].cumulated_batch_size == torch.tensor(2.), k + + for k in ('d0.a', 'd1.a'): + assert result[k].value == torch.tensor(3.) + epoch, k + assert result[k].cumulated_batch_size == torch.tensor(1.), k From c87aaec89d198191eefeec8f9e4b8ae4ff5a7a63 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 15:23:55 +0200 Subject: [PATCH 16/24] Minor changes --- tests/core/test_metric_result_integration.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 85bc0c44f6d4c..8a636a0b15dd1 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -160,10 +160,7 @@ def lightning_log(fx, *args, **kwargs): lightning_log('d0', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True) lightning_log('d1', 'a', torch.tensor(3.) 
+ epoch, on_step=False, on_epoch=True) - for k in ( - 'a0.a', - 'a1.a', - ): + for k in ('a0.a', 'a1.a'): assert result[k].value == torch.tensor(0.), k assert result[k].cumulated_batch_size == torch.tensor(1.), k From 284b247112868e541af1d61e23e02eacec986c31 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 16:33:22 +0200 Subject: [PATCH 17/24] Force float --- tests/trainer/logging_/test_logger_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index aea55bb82f847..1f6ecda82a72f 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -693,7 +693,7 @@ def _assert_called(model, stage): def test_result_collection_on_tensor_with_mean_reduction(): result_collection = ResultCollection(True, torch.device("cpu")) product = [(True, True), (False, True), (True, False), (False, False)] - values = torch.arange(1, 10) + values = torch.arange(1, 10).float() batches = values * values for i, v in enumerate(values): From dac213f0f2433f6ea81cfe53ce9883670f3cc45d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 16:41:16 +0200 Subject: [PATCH 18/24] Fix minimal bug --- tests/trainer/logging_/test_logger_connector.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 1f6ecda82a72f..2d5decc572a1e 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -18,7 +18,6 @@ from typing import Any, Callable from unittest import mock -import numpy as np import pytest import torch from torch.utils.data import DataLoader @@ -693,7 +692,7 @@ def _assert_called(model, stage): def test_result_collection_on_tensor_with_mean_reduction(): result_collection = ResultCollection(True, torch.device("cpu")) product = [(True, True), (False, True), (True, False), (False, False)] - values = torch.arange(1, 10).float() + values = torch.arange(1, 10).float() # need to convert to float() due to precision issues using torch 1.4 batches = values * values for i, v in enumerate(values): @@ -757,14 +756,12 @@ def test_result_collection_on_tensor_with_mean_reduction(): epoch_metrics = result_collection.metrics(False) mean = total_value / total_batches pbar_metrics = epoch_metrics[MetricSource.PBAR] - assert set(pbar_metrics) == { - 'loss_on_epoch_prog_bar', - 'loss_on_epoch_prog_bar_logger', - 'loss_on_step_on_epoch_prog_bar_epoch', - 'loss_on_step_on_epoch_prog_bar_logger_epoch', + assert pbar_metrics == { + 'loss_on_epoch_prog_bar': mean, + 'loss_on_epoch_prog_bar_logger': mean, + 'loss_on_step_on_epoch_prog_bar_epoch': mean, + 'loss_on_step_on_epoch_prog_bar_logger_epoch': mean, } - # pbar metrics are converted to float, need to check with `allclose` - assert all(np.allclose(m, mean) for m in pbar_metrics.values()) assert epoch_metrics[MetricSource.LOG] == { 'loss_on_epoch_logger': mean, 'loss_on_epoch_prog_bar_logger': mean, From f94969c8b3a628c33297c13e3e903c66b7aa47eb Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 16:42:03 +0200 Subject: [PATCH 19/24] Fix minimal bug --- tests/trainer/logging_/test_logger_connector.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 
2d5decc572a1e..dced8ebcddd61 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -755,8 +755,7 @@ def test_result_collection_on_tensor_with_mean_reduction(): epoch_metrics = result_collection.metrics(False) mean = total_value / total_batches - pbar_metrics = epoch_metrics[MetricSource.PBAR] - assert pbar_metrics == { + assert epoch_metrics[MetricSource.PBAR] == { 'loss_on_epoch_prog_bar': mean, 'loss_on_epoch_prog_bar_logger': mean, 'loss_on_step_on_epoch_prog_bar_epoch': mean, From 9ee57ca3b8f84475adb65c817947b9c021c3b8c5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 17:50:24 +0200 Subject: [PATCH 20/24] Update with latest changes --- .../logger_connector/logger_connector_new.py | 109 +++++++++--------- .../connectors/logger_connector/result_new.py | 8 +- pytorch_lightning/utilities/metrics.py | 26 ++--- .../trainer/logging_/test_logger_connector.py | 2 +- 4 files changed, 67 insertions(+), 78 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py index 31df06f6348e1..c1d85e4a84b27 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py @@ -20,15 +20,14 @@ import pytorch_lightning as pl from pytorch_lightning.core import memory from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger -from pytorch_lightning.trainer.connectors.logger_connector.result_new import _METRIC, MetricSource +from pytorch_lightning.trainer.connectors.logger_connector.result import _METRIC, MetricSource from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import DeviceType from pytorch_lightning.utilities.metrics import metrics_to_scalars from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT -# TODO(@carmocca): Remove `New` suffix -class LoggerConnectorNew: +class LoggerConnector: def __init__(self, trainer: 'pl.Trainer', log_gpu_memory: Optional[str] = None) -> None: self.trainer = trainer @@ -120,44 +119,7 @@ def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) - Evaluation metric updates """ - @property - def _eval_log_step(self) -> Optional[int]: - if self.trainer.state.stage is RunningStage.VALIDATING: - return self._val_log_step - elif self.trainer.state.stage is RunningStage.TESTING: - return self._test_log_step - else: - return None - - def _increment_eval_log_step(self) -> None: - if self.trainer.state.stage is RunningStage.VALIDATING: - self._val_log_step += 1 - elif self.trainer.state.stage is RunningStage.TESTING: - self._test_log_step += 1 - - def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: - model = self.trainer.lightning_module - # set dataloader_idx only if multiple ones - model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None - - # track batch_size - self.trainer.result_collection.extract_batch_size(batch) - self._batch_idx = batch_idx - - def update_eval_step_metrics(self) -> None: - if self.trainer.sanity_checking: - return - - # logs user requested information to logger - assert not self._epoch_end_reached - metrics = self.metrics[MetricSource.LOG] - if metrics: - self.log_metrics(metrics, step=self._eval_log_step) - - # increment the step even if nothing 
was logged - self._increment_eval_log_step() - - def _prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None: + def prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None: if self.trainer.sanity_checking: return @@ -174,7 +136,7 @@ def _prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None: else: self.eval_loop_results.append(callback_metrics) - def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT: + def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: assert self._epoch_end_reached metrics = self.metrics @@ -184,7 +146,7 @@ def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT: if log_metrics: self.log_metrics(log_metrics) - self._prepare_eval_loop_results(metrics[MetricSource.CALLBACK]) + self.prepare_eval_loop_results(metrics[MetricSource.CALLBACK]) # log results of evaluation if ( @@ -205,12 +167,49 @@ def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT: self.eval_loop_results = [] return results + @property + def evaluation_log_step(self) -> Optional[int]: + if self.trainer.state.stage is RunningStage.VALIDATING: + return self._val_log_step + elif self.trainer.state.stage is RunningStage.TESTING: + return self._test_log_step + else: + return None + + def increment_evaluation_log_step(self) -> None: + if self.trainer.state.stage is RunningStage.VALIDATING: + self._val_log_step += 1 + elif self.trainer.state.stage is RunningStage.TESTING: + self._test_log_step += 1 + + def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: + model = self.trainer.lightning_module + # set dataloader_idx only if multiple ones + model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None + + # track batch_size + self.trainer.results.extract_batch_size(batch) + self._batch_idx = batch_idx + + def update_evaluation_step_metrics(self) -> None: + if self.trainer.sanity_checking: + return + + # logs user requested information to logger + assert not self._epoch_end_reached + metrics = self.metrics[MetricSource.LOG] + if metrics: + self.log_metrics(metrics, step=self.evaluation_log_step) + + # increment the step even if nothing was logged + self.increment_evaluation_log_step() + """ Train metric updates """ def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None: - self.trainer.result_collection.extract_batch_size(split_batch) + self.trainer.results.extract_batch_size(split_batch) self._batch_idx = batch_idx self._split_idx = split_idx @@ -232,12 +231,7 @@ def update_train_epoch_metrics(self) -> None: self.log_metrics(metrics) # reset result collection for next epoch - self.trainer.result_collection.reset(metrics=True) - - def teardown(self): - self.trainer.train_loop.train_results.cpu() - self.trainer.evaluation_loop.validation_results.cpu() - self.trainer.evaluation_loop.test_results.cpu() + self.trainer.results.reset(metrics=True) """ Utilities and properties @@ -278,7 +272,7 @@ def should_reset_tensors(self, fx: str) -> bool: return is_different_fx and is_first_batch def reset(self, metrics: Optional[bool] = None) -> None: - self.trainer.result_collection.reset(metrics=metrics) + self.trainer.results.reset(metrics=metrics) self._batch_idx = None self._split_idx = None self._current_fx = None @@ -287,25 +281,30 @@ def reset(self, metrics: Optional[bool] = None) -> None: def metrics(self) -> Dict[MetricSource, Dict[str, _METRIC]]: """This function returns either batch or epoch metrics depending on ``_epoch_end_reached``.""" on_step = 
not self._epoch_end_reached - return self.trainer.result_collection.metrics(on_step) + return self.trainer.results.metrics(on_step) @property def callback_metrics(self) -> Dict[str, _METRIC]: - if self.trainer.result_collection: + if self.trainer.results: metrics = self.metrics[MetricSource.CALLBACK] self._callback_metrics.update(metrics) return self._callback_metrics @property def logged_metrics(self) -> Dict[str, _METRIC]: - if self.trainer.result_collection: + if self.trainer.results: metrics = self.metrics[MetricSource.LOG] self._logged_metrics.update(metrics) return self._logged_metrics @property def progress_bar_metrics(self) -> Dict[str, float]: - if self.trainer.result_collection: + if self.trainer.results: metrics = self.metrics[MetricSource.PBAR] self._progress_bar_metrics.update(metrics) return self._progress_bar_metrics + + def teardown(self): + self.trainer.train_loop.results.cpu() + self.trainer.evaluation_loop._val_results.cpu() + self.trainer.evaluation_loop._test_results.cpu() diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result_new.py b/pytorch_lightning/trainer/connectors/logger_connector/result_new.py index 40481dd9afb68..03a7c78e11175 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result_new.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result_new.py @@ -25,6 +25,7 @@ from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.metrics import metrics_to_scalars # re-define the ones from pytorch_lightning.utilities.types without the `Number` type _METRIC = Union[Metric, torch.Tensor] @@ -370,10 +371,6 @@ def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Ten return cache.detach() return cache - @staticmethod - def __to_item(t: torch.Tensor) -> float: - return t.item() - def valid_items(self) -> Generator: """This function is used to iterate over current valid metrics.""" return ((k, v) for k, v in self.items() @@ -421,8 +418,7 @@ def any_tensor(_): # populate progress_bar metrics. convert tensors to numbers if result_metric.meta.prog_bar: - value = apply_to_collection(value, torch.Tensor, self.__to_item, include_none=False) - metrics[MetricSource.PBAR][forked_name] = value + metrics[MetricSource.PBAR][forked_name] = metrics_to_scalars(value) return metrics diff --git a/pytorch_lightning/utilities/metrics.py b/pytorch_lightning/utilities/metrics.py index bd57470dc270e..8433e9e370640 100644 --- a/pytorch_lightning/utilities/metrics.py +++ b/pytorch_lightning/utilities/metrics.py @@ -12,29 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. """Helper functions to operate on metric values. """ +import numbers import torch +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException def metrics_to_scalars(metrics: dict) -> dict: """ Recursively walk through a dictionary of metrics and convert single-item tensors to scalar values. """ - # TODO: this is duplicated in MetricsHolder. should be unified - new_metrics = {} - for k, v in metrics.items(): - if isinstance(v, torch.Tensor): - if v.numel() != 1: - raise MisconfigurationException( - f"The metric `{k}` does not contain a single element" - f" thus it cannot be converted to float. 
Found `{v}`" - ) - v = v.item() + def to_item(value: torch.Tensor) -> numbers.Number: + if value.numel() != 1: + raise MisconfigurationException( + f"The metric `{value}` does not contain a single element" + f" thus it cannot be converted to float." + ) + return value.item() - if isinstance(v, dict): - v = metrics_to_scalars(v) - - new_metrics[k] = v - - return new_metrics + return apply_to_collection(metrics, torch.Tensor, to_item) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index dced8ebcddd61..b7ee13fed6d8e 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -488,7 +488,7 @@ def test_step(self, *args, **kwargs): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - match = "The metric `test` does not contain a single element" + match = "The metric `.*` does not contain a single element" with pytest.raises(MisconfigurationException, match=match): trainer.validate(model) with pytest.raises(MisconfigurationException, match=match): From a2eed6d6c17c78fc58b8cb87c9d544485417a7a5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 18:03:16 +0200 Subject: [PATCH 21/24] Fix import --- .../trainer/connectors/logger_connector/logger_connector_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py index c1d85e4a84b27..d834806cc1952 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py @@ -20,7 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.core import memory from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger -from pytorch_lightning.trainer.connectors.logger_connector.result import _METRIC, MetricSource +from pytorch_lightning.trainer.connectors.logger_connector.result_new import _METRIC, MetricSource from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import DeviceType from pytorch_lightning.utilities.metrics import metrics_to_scalars From 1638bba8b54d237426560a55abb075996548e8a4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Jun 2021 18:07:05 +0200 Subject: [PATCH 22/24] bad merge --- .../logger_connector/logger_connector_new.py | 83 ++++++++++--------- 1 file changed, 42 insertions(+), 41 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py index d834806cc1952..068703c210f84 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py @@ -27,7 +27,8 @@ from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT -class LoggerConnector: +# TODO(@carmocca): Remove `New` suffix +class LoggerConnectorNew: def __init__(self, trainer: 'pl.Trainer', log_gpu_memory: Optional[str] = None) -> None: self.trainer = trainer @@ -119,7 +120,44 @@ def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) - Evaluation metric updates """ - def prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None: + @property + def _eval_log_step(self) -> Optional[int]: + if 
self.trainer.state.stage is RunningStage.VALIDATING: + return self._val_log_step + elif self.trainer.state.stage is RunningStage.TESTING: + return self._test_log_step + else: + return None + + def _increment_eval_log_step(self) -> None: + if self.trainer.state.stage is RunningStage.VALIDATING: + self._val_log_step += 1 + elif self.trainer.state.stage is RunningStage.TESTING: + self._test_log_step += 1 + + def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: + model = self.trainer.lightning_module + # set dataloader_idx only if multiple ones + model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None + + # track batch_size + self.trainer.result_collection.extract_batch_size(batch) + self._batch_idx = batch_idx + + def update_eval_step_metrics(self) -> None: + if self.trainer.sanity_checking: + return + + # logs user requested information to logger + assert not self._epoch_end_reached + metrics = self.metrics[MetricSource.LOG] + if metrics: + self.log_metrics(metrics, step=self._eval_log_step) + + # increment the step even if nothing was logged + self._increment_eval_log_step() + + def _prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None: if self.trainer.sanity_checking: return @@ -136,7 +174,7 @@ def prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None: else: self.eval_loop_results.append(callback_metrics) - def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: + def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT: assert self._epoch_end_reached metrics = self.metrics @@ -146,7 +184,7 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: if log_metrics: self.log_metrics(log_metrics) - self.prepare_eval_loop_results(metrics[MetricSource.CALLBACK]) + self._prepare_eval_loop_results(metrics[MetricSource.CALLBACK]) # log results of evaluation if ( @@ -167,43 +205,6 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: self.eval_loop_results = [] return results - @property - def evaluation_log_step(self) -> Optional[int]: - if self.trainer.state.stage is RunningStage.VALIDATING: - return self._val_log_step - elif self.trainer.state.stage is RunningStage.TESTING: - return self._test_log_step - else: - return None - - def increment_evaluation_log_step(self) -> None: - if self.trainer.state.stage is RunningStage.VALIDATING: - self._val_log_step += 1 - elif self.trainer.state.stage is RunningStage.TESTING: - self._test_log_step += 1 - - def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: - model = self.trainer.lightning_module - # set dataloader_idx only if multiple ones - model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None - - # track batch_size - self.trainer.results.extract_batch_size(batch) - self._batch_idx = batch_idx - - def update_evaluation_step_metrics(self) -> None: - if self.trainer.sanity_checking: - return - - # logs user requested information to logger - assert not self._epoch_end_reached - metrics = self.metrics[MetricSource.LOG] - if metrics: - self.log_metrics(metrics, step=self.evaluation_log_step) - - # increment the step even if nothing was logged - self.increment_evaluation_log_step() - """ Train metric updates """ From 48249f1e38ee7a2f284e89a410e0e49642021372 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 8 Jun 2021 19:32:52 +0000 Subject: [PATCH 23/24] update typing --- .../trainer/connectors/logger_connector/logger_connector.py | 2 +- 
.../trainer/connectors/logger_connector/logger_connector_new.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index a05e273ece8aa..93bcc5527643e 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -141,7 +141,7 @@ def should_update_logs(self) -> bool: should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 return should_log_every_n_steps or self.trainer.should_stop - def configure_logger(self, logger: LightningLoggerBase) -> None: + def configure_logger(self, logger: Union[bool, Iterable, LightningLoggerBase]) -> None: if logger is True: version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py index 068703c210f84..c0fea406ae019 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py @@ -66,7 +66,7 @@ def should_update_logs(self) -> bool: should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 return should_log_every_n_steps or self.trainer.should_stop - def configure_logger(self, logger: LightningLoggerBase) -> None: + def configure_logger(self, logger: Union[bool, Iterable, LightningLoggerBase]) -> None: if logger is True: version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id) From e2c8c6472c32c28307044b0e67e304dc32c77fe7 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 8 Jun 2021 19:34:57 +0000 Subject: [PATCH 24/24] missing typing --- .../trainer/connectors/logger_connector/logger_connector_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py index c0fea406ae019..069a7f5183c70 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py @@ -13,7 +13,7 @@ # limitations under the License. import os from pprint import pprint -from typing import Any, Dict, Iterable, Mapping, Optional +from typing import Any, Dict, Iterable, Mapping, Optional, Union import torch
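Note on patches 23 and 24: the widened annotation documents what
`configure_logger` already accepts, namely a bool, a single logger, or an
iterable of loggers. A short usage sketch under that assumption; the chosen
loggers and save directory below are purely illustrative:

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

    # logger=True (the default) creates a TensorBoardLogger under default_root_dir
    trainer = Trainer(logger=True)

    # logger=False disables logging entirely; trainer.logger is None
    trainer = Trainer(logger=False)

    # an iterable of loggers is wrapped in a LoggerCollection
    trainer = Trainer(logger=[TensorBoardLogger("logs/"), CSVLogger("logs/")])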