From daf7581e7922d0f0a87b4e3cca1c5cf285d6c3e6 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 13 Jan 2022 23:32:19 +0100 Subject: [PATCH] collections for PL back-compatibility (#750) --- .github/workflows/docs-check.yml | 2 +- CHANGELOG.md | 1 - tests/bases/test_collections.py | 2 +- torchmetrics/__init__.py | 2 +- torchmetrics/{metric_collections.py => collections.py} | 4 +++- torchmetrics/metric.py | 3 +-- torchmetrics/utilities/imports.py | 1 + 7 files changed, 8 insertions(+), 7 deletions(-) rename torchmetrics/{metric_collections.py => collections.py} (98%) diff --git a/.github/workflows/docs-check.yml b/.github/workflows/docs-check.yml index 8a08019a3e2..ab81cbac4b1 100644 --- a/.github/workflows/docs-check.yml +++ b/.github/workflows/docs-check.yml @@ -76,7 +76,7 @@ jobs: working-directory: ./docs run: | make clean - make html --debug --jobs 2 SPHINXOPTS="-W --keep-going" -b linkcheck + make html --debug SPHINXOPTS="-W --keep-going" -b linkcheck - name: Upload built docs uses: actions/upload-artifact@v2 diff --git a/CHANGELOG.md b/CHANGELOG.md index 753de091ba8..780798b458a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,7 +31,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Metrics having third party dependencies removed from global import ([#463](https://github.com/PyTorchLightning/metrics/pull/463)) - Untokenized for `BLEUScore` input stay consistent with all the other text metrics ([#640](https://github.com/PyTorchLightning/metrics/pull/640)) - Arguments reordered for `TER`, `BLEUScore`, `SacreBLEUScore`, `CHRFScore` now expect input order as predictions first and target second ([#696](https://github.com/PyTorchLightning/metrics/pull/696)) -- Renamed `torchmetrics.collections` to `torchmetrics.metrics_collections` to avoid clashing with system's `collections` package ([#695](https://github.com/PyTorchLightning/metrics/pull/695)) - Changed dtype of metric state from `torch.float` to `torch.long` in `ConfusionMatrix` to accommodate larger values ([#708](https://github.com/PyTorchLightning/metrics/issues/708)) - Unify `preds`, `target` input argument's naming across all text metrics ([#723](https://github.com/PyTorchLightning/metrics/issues/723), [#727](https://github.com/PyTorchLightning/metrics/issues/727)) * `bert`, `bleu`, `chrf`, `sacre_bleu`, `wip`, `wil`, `cer`, `ter`, `wer`, `mer`, `rouge`, `squad` diff --git a/tests/bases/test_collections.py b/tests/bases/test_collections.py index 772299d0184..3f61effbb78 100644 --- a/tests/bases/test_collections.py +++ b/tests/bases/test_collections.py @@ -20,7 +20,7 @@ from tests.helpers.testers import DummyMetricDiff, DummyMetricSum from torchmetrics import Metric from torchmetrics.classification import Accuracy -from torchmetrics.metric_collections import MetricCollection +from torchmetrics.collections import MetricCollection seed_all(42) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index e4bd488841a..c3c1b43f5da 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -55,6 +55,7 @@ Specificity, StatScores, ) +from torchmetrics.collections import MetricCollection # noqa: E402 from torchmetrics.image import ( # noqa: E402 PSNR, SSIM, @@ -63,7 +64,6 @@ StructuralSimilarityIndexMeasure, ) from torchmetrics.metric import Metric # noqa: E402 -from torchmetrics.metric_collections import MetricCollection # noqa: E402 from torchmetrics.regression import ( # noqa: E402 CosineSimilarity, ExplainedVariance, diff --git a/torchmetrics/metric_collections.py b/torchmetrics/collections.py similarity index 98% rename from torchmetrics/metric_collections.py rename to torchmetrics/collections.py index 547a767e1ca..ad903fe906a 100644 --- a/torchmetrics/metric_collections.py +++ b/torchmetrics/collections.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import OrderedDict from copy import deepcopy from typing import Any, Dict, Hashable, Iterable, Optional, Sequence, Tuple, Union @@ -22,6 +21,9 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn +# this is just a bypass for this module name collision with build-in one +from torchmetrics.utilities.imports import OrderedDict + class MetricCollection(nn.ModuleDict): """MetricCollection class can be used to chain metrics that have the same call pattern into one single class. diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 3b1211297c8..487ab27c156 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -15,10 +15,9 @@ import inspect import operator as op from abc import ABC, abstractmethod -from collections.abc import Sequence from contextlib import contextmanager from copy import deepcopy -from typing import Any, Callable, Dict, Generator, List, Optional, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Union import torch from torch import Tensor diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py index 3cd27c9af78..fdd74f4394b 100644 --- a/torchmetrics/utilities/imports.py +++ b/torchmetrics/utilities/imports.py @@ -13,6 +13,7 @@ # limitations under the License. """Import utilities.""" import operator +from collections import OrderedDict # noqa: F401 from importlib import import_module from importlib.util import find_spec from typing import Callable, Optional