From b225889b34b83272117b758cbc28772a5c2356d9 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Sat, 5 Feb 2022 09:53:33 +0100 Subject: [PATCH] Remove lightning legacy code and references (#788) * remove lightning * more removal * fix test * flake8 * Apply suggestions from code review Co-authored-by: Jirka Borovec Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: Aki Nitta --- .github/CONTRIBUTING.md | 2 +- .github/PULL_REQUEST_TEMPLATE.md | 2 +- CHANGELOG.md | 3 ++ docs/paper_JOSS/paper.md | 2 +- docs/source/governance.rst | 5 +- docs/source/pages/implement.rst | 2 +- docs/source/pages/lightning.rst | 4 ++ integrations/__init__.py | 5 -- integrations/test_lightning.py | 80 ++++++++++++++--------------- requirements/devel.txt | 3 -- requirements/docs.txt | 2 +- requirements/integrate.txt | 2 +- tests/bases/test_metric.py | 9 ++-- tests/helpers/__init__.py | 5 +- tests/helpers/testers.py | 18 +++---- tests/regression/test_mean_error.py | 4 -- tests/retrieval/helpers.py | 4 +- tests/text/helpers.py | 18 +++---- torchmetrics/__about__.py | 2 +- torchmetrics/metric.py | 7 +-- torchmetrics/utilities/enums.py | 2 +- torchmetrics/utilities/imports.py | 1 - 22 files changed, 83 insertions(+), 99 deletions(-) diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index cab299a3201..4bc9adb39c5 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -40,7 +40,7 @@ help you or finish it with you :\]_ 1. Add/update the relevant tests! -- [This PR](https://github.com/PyTorchLightning/pytorch-lightning/pull/5241) is a good example for adding a new metric +- [This PR](https://github.com/PyTorchLightning/metrics/pull/98) is a good example for adding a new metric ### Test cases: diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 8be4ac32441..71c290cf277 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -5,7 +5,7 @@ Fixes #\ ## Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements) -- [ ] Did you read the [contributor guideline](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/.github/CONTRIBUTING.md), Pull Request section? +- [ ] Did you read the [contributor guideline](https://github.com/PyTorchLightning/metrics/blob/master/.github/CONTRIBUTING.md), Pull Request section? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? diff --git a/CHANGELOG.md b/CHANGELOG.md index 93d2388bc2e..8ffd8e127a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed +- Removed support for versions of Lightning lower than v1.5 ([#788](https://github.com/PyTorchLightning/metrics/pull/788)) + + - Removed deprecated functions, and warnings in Text ([#773](https://github.com/PyTorchLightning/metrics/pull/773)) * `functional.wer` * `WER` diff --git a/docs/paper_JOSS/paper.md b/docs/paper_JOSS/paper.md index e3405c1e40f..11e65bf36d9 100644 --- a/docs/paper_JOSS/paper.md +++ b/docs/paper_JOSS/paper.md @@ -99,7 +99,7 @@ In addition to stateful metrics (called modular metrics in TorchMetrics), we als TorchMetrics exhibits high test coverage on the various configurations, including all three major OS platforms (Linux, macOS, and Windows), and various Python, CUDA, and PyTorch versions. 
We test both minimum and latest package requirements for all combinations of OS and Python versions and include additional tests for each PyTorch version from 1.3 up to future development versions. On every pull request and merge to master, we run a full test suite. All standard tests run on CPU. In addition, we run all tests on a multi-GPU setting which reflects realistic Deep Learning workloads. For usability, we have auto-generated HTML documentation (hosted at [readthedocs](https://torchmetrics.readthedocs.io/en/stable/)) from the source code which updates in real-time with new merged pull requests.
-TorchMetrics is released under the Apache 2.0 license. The source code is available at https://github.com/PytorchLightning/metrics.
+TorchMetrics is released under the Apache 2.0 license. The source code is available at https://github.com/PyTorchLightning/metrics.
# Acknowledgement
diff --git a/docs/source/governance.rst b/docs/source/governance.rst index de338d5f01a..0769dae0882 100644 --- a/docs/source/governance.rst +++ b/docs/source/governance.rst @@ -41,8 +41,9 @@ Project Management and Decision Making
**************************************
The decision what goes into a release is governed by the :ref:`staff contributors and leaders ` of
-Lightning development. Whenever possible, discussion happens publicly on GitHub and includes the whole community.
-When a consensus is reached, staff and core contributors assign milestones and labels to the issue and/or pull request and start tracking the development. It is possible that priorities change over time.
+TorchMetrics development. Whenever possible, discussion happens publicly on GitHub and includes the whole community.
+When a consensus is reached, staff and core contributors assign milestones and labels to the issue and/or pull request
+and start tracking the development. It is possible that priorities change over time.
Commits to the project are exclusively to be added by pull requests on GitHub and anyone in the community is welcome to review them. However, reviews submitted by
diff --git a/docs/source/pages/implement.rst b/docs/source/pages/implement.rst index ead4e3a1f7e..417b73edac6 100644 --- a/docs/source/pages/implement.rst +++ b/docs/source/pages/implement.rst @@ -44,7 +44,7 @@ Internal implementation details
-------------------------------
This section briefly describes how metrics work internally. We encourage looking at the source code for more info.
-Internally, Lightning wraps the user defined ``update()`` and ``compute()`` method. We do this to automatically
+Internally, TorchMetrics wraps the user defined ``update()`` and ``compute()`` method. We do this to automatically
synchronize and reduce metric states across multiple devices. More precisely, calling ``update()`` does the following internally:
diff --git a/docs/source/pages/lightning.rst b/docs/source/pages/lightning.rst index f9349fc9969..55633f4a415 100644 --- a/docs/source/pages/lightning.rst +++ b/docs/source/pages/lightning.rst @@ -11,6 +11,10 @@ TorchMetrics in PyTorch Lightning
TorchMetrics was originaly created as part of `PyTorch Lightning `_, a powerful deep learning research framework designed for scaling models without boilerplate.
+.. note::
+ TorchMetrics always offers compatibility with the last 2 major PyTorch Lightning versions, but we recommend always keeping both frameworks
+ up-to-date for the best experience.
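For orientation (an editorial illustration, not part of this patch), a minimal sketch of the integration that the note above and the benefits listed just below refer to; the class and variable names are made up:

```python
# Illustrative only: a TorchMetrics metric owned by a LightningModule.
# Passing the metric object itself to `self.log` lets Lightning handle
# epoch-level accumulation and resetting of the metric state.
import torch
from pytorch_lightning import LightningModule
from torchmetrics import Accuracy


class LitClassifier(LightningModule):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 10)
        self.train_acc = Accuracy()  # registered as a submodule, so it follows the model's device

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.layer(x)
        self.train_acc(logits, y)  # update the metric state with this batch
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=True)
        return torch.nn.functional.cross_entropy(logits, y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```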
+ While TorchMetrics was built to be used with native PyTorch, using TorchMetrics with Lightning offers additional benefits: * Module metrics are automatically placed on the correct device when properly defined inside a LightningModule. This means that your data will always be placed on the same device as your metrics. diff --git a/integrations/__init__.py b/integrations/__init__.py index c882b811c0f..382976fecb2 100644 --- a/integrations/__init__.py +++ b/integrations/__init__.py @@ -1,10 +1,5 @@ -import operator import os -from torchmetrics.utilities.imports import _compare_version - _INTEGRATION_ROOT = os.path.realpath(os.path.dirname(__file__)) _PACKAGE_ROOT = os.path.dirname(_INTEGRATION_ROOT) _PATH_DATASETS = os.path.join(_PACKAGE_ROOT, "datasets") - -_LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", operator.ge, "1.3.0") diff --git a/integrations/test_lightning.py b/integrations/test_lightning.py index 379153a5423..226faee7e72 100644 --- a/integrations/test_lightning.py +++ b/integrations/test_lightning.py @@ -13,14 +13,12 @@ # limitations under the License. from unittest import mock -import pytest import torch from pytorch_lightning import LightningModule, Trainer from torch import tensor from torch.utils.data import DataLoader from integrations.lightning.boring_model import BoringModel, RandomDataset -from tests.helpers import _LIGHTNING_GREATER_EQUAL_1_3 from torchmetrics import Accuracy, AveragePrecision, MetricCollection, SumMetric @@ -63,7 +61,6 @@ def training_epoch_end(self, outs): trainer.fit(model) -@pytest.mark.skipif(not _LIGHTNING_GREATER_EQUAL_1_3, reason="test requires lightning v1.3 or higher") def test_metrics_reset(tmpdir): """Tests that metrics are reset correctly after the end of the train/val/test epoch. 
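The reset behaviour this test exercises is simpler now that the Lightning-version special case is gone; a minimal sketch (not part of the patch) of what this test and the updated `test_reset_compute` further down assert:

```python
# Illustrative only: after reset(), the cached result is cleared and compute()
# reflects the default (empty) state again, regardless of the installed
# pytorch_lightning version.
import torch
from torchmetrics import SumMetric

metric = SumMetric()
metric.update(torch.tensor(5.0))
assert metric.compute() == 5.0  # state reflects the accumulated updates

metric.reset()                  # states and cached result return to defaults
assert metric.compute() == 0.0  # no version-dependent branch anymore
```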
@@ -222,6 +219,8 @@ def training_epoch_end(self, outs): def test_metric_collection_lightning_log(tmpdir): + """Test that MetricCollection works with Lightning modules.""" + class TestModel(BoringModel): def __init__(self): super().__init__() @@ -258,40 +257,41 @@ def training_epoch_end(self, outputs): assert torch.allclose(tensor(logged["DiffMetric_epoch"]), model.diff) -# todo: need to be fixed -# def test_scriptable(tmpdir): -# class TestModel(BoringModel): -# def __init__(self): -# super().__init__() -# # the metric is not used in the module's `forward` -# # so the module should be exportable to TorchScript -# self.metric = SumMetric() -# self.sum = 0.0 -# -# def training_step(self, batch, batch_idx): -# x = batch -# self.metric(x.sum()) -# self.sum += x.sum() -# self.log("sum", self.metric, on_epoch=True, on_step=False) -# return self.step(x) -# -# model = TestModel() -# trainer = Trainer( -# default_root_dir=tmpdir, -# limit_train_batches=2, -# limit_val_batches=2, -# max_epochs=1, -# log_every_n_steps=1, -# weights_summary=None, -# logger=False, -# checkpoint_callback=False, -# ) -# trainer.fit(model) -# rand_input = torch.randn(10, 32) -# -# script_model = model.to_torchscript() -# -# # test that we can still do inference -# output = model(rand_input) -# script_output = script_model(rand_input) -# assert torch.allclose(output, script_output) +def test_scriptable(tmpdir): + """Test that lightning modules can still be scripted even if metrics cannot.""" + + class TestModel(BoringModel): + def __init__(self): + super().__init__() + # the metric is not used in the module's `forward` + # so the module should be exportable to TorchScript + self.metric = SumMetric() + self.sum = 0.0 + + def training_step(self, batch, batch_idx): + x = batch + self.metric(x.sum()) + self.sum += x.sum() + self.log("sum", self.metric, on_epoch=True, on_step=False) + return self.step(x) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + logger=False, + checkpoint_callback=False, + ) + trainer.fit(model) + rand_input = torch.randn(10, 32) + + script_model = model.to_torchscript() + + # test that we can still do inference + output = model(rand_input) + script_output = script_model(rand_input) + assert torch.allclose(output, script_output) diff --git a/requirements/devel.txt b/requirements/devel.txt index fa8c352ded7..2982b7de3ce 100644 --- a/requirements/devel.txt +++ b/requirements/devel.txt @@ -14,6 +14,3 @@ -r image_test.txt -r text_test.txt -r audio_test.txt - -# add the integration dependencies -#-r integrate.txt diff --git a/requirements/docs.txt b/requirements/docs.txt index 6b507a0402b..eb521689b83 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -12,4 +12,4 @@ sphinx-togglebutton>=0.2 sphinx-copybutton>=0.3 # integrations -pytorch-lightning>=1.1 +-r integrate.txt diff --git a/requirements/integrate.txt b/requirements/integrate.txt index 3acf3f8e78a..3cc0367710e 100644 --- a/requirements/integrate.txt +++ b/requirements/integrate.txt @@ -1 +1 @@ -pytorch-lightning>=1.3 +pytorch-lightning>=1.5 diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py index 5d65f14f202..3b9f3635470 100644 --- a/tests/bases/test_metric.py +++ b/tests/bases/test_metric.py @@ -20,9 +20,9 @@ import torch from torch import Tensor, nn, tensor -from tests.helpers import _LIGHTNING_GREATER_EQUAL_1_3, seed_all +from tests.helpers import seed_all from tests.helpers.testers 
import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum -from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _TORCH_LOWER_1_6 +from torchmetrics.utilities.imports import _TORCH_LOWER_1_6 seed_all(42) @@ -100,10 +100,7 @@ def test_reset_compute(): a.update(tensor(5)) assert a.compute() == 5 a.reset() - if not _LIGHTNING_AVAILABLE or _LIGHTNING_GREATER_EQUAL_1_3: - assert a.compute() == 0 - else: - assert a.compute() == 5 + assert a.compute() == 0 def test_update(): diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index d7e86d30644..0e0b4862415 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -1,17 +1,14 @@ -import operator import random import numpy import torch -from torchmetrics.utilities.imports import _TORCH_LOWER_1_4, _TORCH_LOWER_1_5, _TORCH_LOWER_1_6, _compare_version +from torchmetrics.utilities.imports import _TORCH_LOWER_1_4, _TORCH_LOWER_1_5, _TORCH_LOWER_1_6 _MARK_TORCH_MIN_1_4 = dict(condition=_TORCH_LOWER_1_4, reason="required PT >= 1.4") _MARK_TORCH_MIN_1_5 = dict(condition=_TORCH_LOWER_1_5, reason="required PT >= 1.5") _MARK_TORCH_MIN_1_6 = dict(condition=_TORCH_LOWER_1_6, reason="required PT >= 1.6") -_LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", operator.ge, "1.3.0") - def seed_all(seed): random.seed(seed) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 9f0f7a28166..5aecc4c4eb7 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -123,14 +123,14 @@ def _class_test( check_scriptable: bool = True, **kwargs_update: Any, ): - """Utility function doing the actual comparison between lightning class metric and reference metric. + """Utility function doing the actual comparison between class metric and reference metric. Args: rank: rank of current process worldsize: number of processes preds: torch tensor with predictions target: torch tensor with targets - metric_class: lightning metric class that should be tested + metric_class: metric class that should be tested sk_metric: callable function that is used for comparison dist_sync_on_step: bool, if true will synchronize metric state across processes at each ``forward()`` @@ -150,7 +150,7 @@ def _class_test( if not metric_args: metric_args = {} - # Instantiate lightning metric + # Instantiate metric metric = metric_class( compute_on_step=check_dist_sync_on_step or check_batch, dist_sync_on_step=dist_sync_on_step, **metric_args ) @@ -255,12 +255,12 @@ def _functional_test( fragment_kwargs: bool = False, **kwargs_update, ): - """Utility function doing the actual comparison between lightning functional metric and reference metric. + """Utility function doing the actual comparison between functional metric and reference metric. 
Args: preds: torch tensor with predictions target: torch tensor with targets - metric_functional: lightning metric functional that should be tested + metric_functional: metric functional that should be tested sk_metric: callable function that is used for comparison metric_args: dict with additional arguments used for class initialization device: determine which device to run on, either 'cuda' or 'cpu' @@ -283,7 +283,7 @@ def _functional_test( for i in range(num_batches): extra_kwargs = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()} - lightning_result = metric(preds[i], target[i], **extra_kwargs) + tm_result = metric(preds[i], target[i], **extra_kwargs) extra_kwargs = { k: v.cpu() if isinstance(v, Tensor) else v for k, v in (extra_kwargs if fragment_kwargs else kwargs_update).items() @@ -291,7 +291,7 @@ def _functional_test( sk_result = sk_metric(preds[i].cpu(), target[i].cpu(), **extra_kwargs) # assert its the same - _assert_allclose(lightning_result, sk_result, atol=atol) + _assert_allclose(tm_result, sk_result, atol=atol) def _assert_half_support( @@ -366,7 +366,7 @@ def run_functional_metric_test( Args: preds: torch tensor with predictions target: torch tensor with targets - metric_functional: lightning metric class that should be tested + metric_functional: metric class that should be tested sk_metric: callable function that is used for comparison metric_args: dict with additional arguments used for class initialization fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `target` among processes @@ -408,7 +408,7 @@ def run_class_metric_test( ddp: bool, if running in ddp mode or not preds: torch tensor with predictions target: torch tensor with targets - metric_class: lightning metric class that should be tested + metric_class: metric class that should be tested sk_metric: callable function that is used for comparison dist_sync_on_step: bool, if true will synchronize metric state across processes at each ``forward()`` diff --git a/tests/regression/test_mean_error.py b/tests/regression/test_mean_error.py index 777c19165cd..4bf58c5a23d 100644 --- a/tests/regression/test_mean_error.py +++ b/tests/regression/test_mean_error.py @@ -64,8 +64,6 @@ def _single_target_sk_metric(preds, target, sk_fn, metric_args): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() - # `sk_target` and `sk_preds` switched to fix failing tests. - # For more info, check https://github.com/PyTorchLightning/metrics/pull/248#issuecomment-841232277 res = sk_fn(sk_target, sk_preds) return math.sqrt(res) if (metric_args and not metric_args["squared"]) else res @@ -75,8 +73,6 @@ def _multi_target_sk_metric(preds, target, sk_fn, metric_args): sk_preds = preds.view(-1, num_targets).numpy() sk_target = target.view(-1, num_targets).numpy() - # `sk_target` and `sk_preds` switched to fix failing tests. 
- # For more info, check https://github.com/PyTorchLightning/metrics/pull/248#issuecomment-841232277 res = sk_fn(sk_target, sk_preds) return math.sqrt(res) if (metric_args and not metric_args["squared"]) else res diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index 1cbf3078096..37bd83c5194 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -369,7 +369,7 @@ def _errors_test_class_metric( indexes: torch tensor with indexes preds: torch tensor with predictions target: torch tensor with targets - metric_class: lightning metric class that should be tested + metric_class: metric class that should be tested message: message that exception should return metric_args: arguments for class initialization exception_type: callable function that is used for comparison @@ -396,7 +396,7 @@ def _errors_test_functional_metric( Args: preds: torch tensor with predictions target: torch tensor with targets - metric_functional: lightning functional metric that should be tested + metric_functional: functional metric that should be tested message: message that exception should return exception_type: callable function that is used for comparison kwargs_update: Additional keyword arguments that will be passed with indexes, preds and diff --git a/tests/text/helpers.py b/tests/text/helpers.py index 2a3620cdf92..beec48d27c5 100644 --- a/tests/text/helpers.py +++ b/tests/text/helpers.py @@ -52,14 +52,14 @@ def _class_test( key: str = None, **kwargs_update: Any, ): - """Utility function doing the actual comparison between lightning class metric and reference metric. + """Utility function doing the actual comparison between class metric and reference metric. Args: rank: rank of current process worldsize: number of processes preds: Sequence of predicted tokens or predicted sentences targets: Sequence of target tokens or target sentences - metric_class: lightning metric class that should be tested + metric_class: metric class that should be tested sk_metric: callable function that is used for comparison dist_sync_on_step: bool, if true will synchronize metric state across processes at each ``forward()`` @@ -78,7 +78,7 @@ def _class_test( if not metric_args: metric_args = {} - # Instanciate lightning metric + # Instanciate metric metric = metric_class( compute_on_step=check_dist_sync_on_step or check_batch, dist_sync_on_step=dist_sync_on_step, **metric_args ) @@ -156,12 +156,12 @@ def _functional_test( key: str = None, **kwargs_update, ): - """Utility function doing the actual comparison between lightning functional metric and reference metric. + """Utility function doing the actual comparison between functional metric and reference metric. 
Args: preds: torch tensor with predictions targets: torch tensor with targets - metric_functional: lightning metric functional that should be tested + metric_functional: metric functional that should be tested sk_metric: callable function that is used for comparison metric_args: dict with additional arguments used for class initialization device: determine which device to run on, either 'cuda' or 'cpu' @@ -181,7 +181,7 @@ def _functional_test( for i in range(NUM_BATCHES): extra_kwargs = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()} - lightning_result = metric(preds[i], targets[i], **extra_kwargs) + tm_result = metric(preds[i], targets[i], **extra_kwargs) extra_kwargs = { k: v.cpu() if isinstance(v, Tensor) else v @@ -190,7 +190,7 @@ def _functional_test( sk_result = sk_metric(preds[i], targets[i], **extra_kwargs) # assert its the same - _assert_allclose(lightning_result, sk_result, atol=atol, key=key) + _assert_allclose(tm_result, sk_result, atol=atol, key=key) def _assert_half_support( @@ -247,7 +247,7 @@ def run_functional_metric_test( Args: preds: torch tensor with predictions targets: torch tensor with targets - metric_functional: lightning metric class that should be tested + metric_functional: metric class that should be tested sk_metric: callable function that is used for comparison metric_args: dict with additional arguments used for class initialization fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `targets` among processes @@ -293,7 +293,7 @@ def run_class_metric_test( ddp: bool, if running in ddp mode or not preds: torch tensor with predictions targets: torch tensor with targets - metric_class: lightning metric class that should be tested + metric_class: metric class that should be tested sk_metric: callable function that is used for comparison dist_sync_on_step: bool, if true will synchronize metric state across processes at each ``forward()`` diff --git a/torchmetrics/__about__.py b/torchmetrics/__about__.py index f80637b081b..58c412ad4f3 100644 --- a/torchmetrics/__about__.py +++ b/torchmetrics/__about__.py @@ -11,7 +11,7 @@ [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/). It was originally a part of Pytorch Lightning, but got split off so users could take advantage of the large collection of metrics implemented without having to install Pytorch Lightning (even though we would love for you to try it out). -We currently have around 25+ metrics implemented and we continuously is adding more metrics, both within +We currently have around 60+ metrics implemented and we continuously are adding more metrics, both within already covered domains (classification, regression ect.) but also new domains (object detection ect.). We make sure that all our metrics are rigorously tested such that you can trust them. """ diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 487ab27c156..f5620ff5294 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -13,7 +13,6 @@ # limitations under the License. 
import functools import inspect -import operator as op from abc import ABC, abstractmethod from contextlib import contextmanager from copy import deepcopy @@ -35,7 +34,6 @@ ) from torchmetrics.utilities.distributed import gather_all_tensors from torchmetrics.utilities.exceptions import TorchMetricsUserError -from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _compare_version def jit_distributed_available() -> bool: @@ -93,7 +91,6 @@ def __init__( # torch/nn/modules/module.py#L227) torch._C._log_api_usage_once(f"torchmetrics.metric.{self.__class__.__name__}") - self._LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", op.ge, "1.3.0") self._device = torch.device("cpu") self.dist_sync_on_step = dist_sync_on_step @@ -397,9 +394,7 @@ def reset(self) -> None: """This method automatically resets the metric state variables to their default value.""" self._update_called = False self._forward_cache = None - # lower lightning versions requires this implicitly to log metric objects correctly in self.log - if not _LIGHTNING_AVAILABLE or self._LIGHTNING_GREATER_EQUAL_1_3: - self._computed = None + self._computed = None for attr, default in self._defaults.items(): current_val = getattr(self, attr) diff --git a/torchmetrics/utilities/enums.py b/torchmetrics/utilities/enums.py index 7476c051d92..c9672626727 100644 --- a/torchmetrics/utilities/enums.py +++ b/torchmetrics/utilities/enums.py @@ -41,7 +41,7 @@ def __eq__(self, other: Union[str, "EnumStr", None]) -> bool: # type: ignore def __hash__(self) -> int: # re-enable hashtable so it can be used as a dict key or in a set - # example: set(LightningEnum) + # example: set(EnumStr) return hash(self.name) diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py index b8e37ab2b9a..3f7c2f7a8fc 100644 --- a/torchmetrics/utilities/imports.py +++ b/torchmetrics/utilities/imports.py @@ -98,7 +98,6 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool] _TORCH_GREATER_EQUAL_1_7: Optional[bool] = _compare_version("torch", operator.ge, "1.7.0") _TORCH_GREATER_EQUAL_1_8: Optional[bool] = _compare_version("torch", operator.ge, "1.8.0") -_LIGHTNING_AVAILABLE: bool = _package_available("pytorch_lightning") _JIWER_AVAILABLE: bool = _package_available("jiwer") _NLTK_AVAILABLE: bool = _package_available("nltk") _ROUGE_SCORE_AVAILABLE: bool = _package_available("rouge_score")
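For reference, the removed `_LIGHTNING_AVAILABLE` flag and the `_compare_version` gate in `Metric.__init__` followed the usual optional-dependency pattern. A rough sketch of that pattern (hypothetical helper names, assuming the `packaging` package is available; the real helpers live in `torchmetrics/utilities/imports.py`):

```python
# Rough sketch of package/version gating, similar in spirit to the helpers
# this PR no longer needs for pytorch_lightning.
import importlib
import operator

from packaging.version import Version


def package_available(name: str) -> bool:
    """Return True if `name` can be imported."""
    try:
        importlib.import_module(name)
    except ImportError:
        return False
    return True


def compare_version(name: str, op, version: str) -> bool:
    """Compare the installed version of `name` against `version` using `op` (e.g. operator.ge)."""
    try:
        pkg = importlib.import_module(name)
    except ImportError:
        return False
    pkg_version = getattr(pkg, "__version__", None)
    if pkg_version is None:
        return False
    return op(Version(pkg_version), Version(version))


# The gate removed by this PR looked roughly like:
# _LIGHTNING_GREATER_EQUAL_1_3 = compare_version("pytorch_lightning", operator.ge, "1.3.0")
```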