Skip to content

Commit

Permalink
collections for PL back-compatibility (#750)
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Jan 13, 2022
1 parent 4dbdc77 commit daf7581
Show file tree
Hide file tree
Showing 7 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ jobs:
working-directory: ./docs
run: |
make clean
make html --debug --jobs 2 SPHINXOPTS="-W --keep-going" -b linkcheck
make html --debug SPHINXOPTS="-W --keep-going" -b linkcheck
- name: Upload built docs
uses: actions/upload-artifact@v2
Expand Down
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Metrics having third party dependencies removed from global import ([#463](https://github.com/PyTorchLightning/metrics/pull/463))
- Untokenized for `BLEUScore` input stay consistent with all the other text metrics ([#640](https://github.com/PyTorchLightning/metrics/pull/640))
- Arguments reordered for `TER`, `BLEUScore`, `SacreBLEUScore`, `CHRFScore` now expect input order as predictions first and target second ([#696](https://github.com/PyTorchLightning/metrics/pull/696))
- Renamed `torchmetrics.collections` to `torchmetrics.metrics_collections` to avoid clashing with system's `collections` package ([#695](https://github.com/PyTorchLightning/metrics/pull/695))
- Changed dtype of metric state from `torch.float` to `torch.long` in `ConfusionMatrix` to accommodate larger values ([#708](https://github.com/PyTorchLightning/metrics/issues/708))
- Unify `preds`, `target` input argument's naming across all text metrics ([#723](https://github.com/PyTorchLightning/metrics/issues/723), [#727](https://github.com/PyTorchLightning/metrics/issues/727))
* `bert`, `bleu`, `chrf`, `sacre_bleu`, `wip`, `wil`, `cer`, `ter`, `wer`, `mer`, `rouge`, `squad`
Expand Down
2 changes: 1 addition & 1 deletion tests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from tests.helpers.testers import DummyMetricDiff, DummyMetricSum
from torchmetrics import Metric
from torchmetrics.classification import Accuracy
from torchmetrics.metric_collections import MetricCollection
from torchmetrics.collections import MetricCollection

seed_all(42)

Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
Specificity,
StatScores,
)
from torchmetrics.collections import MetricCollection # noqa: E402
from torchmetrics.image import ( # noqa: E402
PSNR,
SSIM,
Expand All @@ -63,7 +64,6 @@
StructuralSimilarityIndexMeasure,
)
from torchmetrics.metric import Metric # noqa: E402
from torchmetrics.metric_collections import MetricCollection # noqa: E402
from torchmetrics.regression import ( # noqa: E402
CosineSimilarity,
ExplainedVariance,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from copy import deepcopy
from typing import Any, Dict, Hashable, Iterable, Optional, Sequence, Tuple, Union

Expand All @@ -22,6 +21,9 @@
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn

# this is just a bypass for this module name collision with build-in one
from torchmetrics.utilities.imports import OrderedDict


class MetricCollection(nn.ModuleDict):
"""MetricCollection class can be used to chain metrics that have the same call pattern into one single class.
Expand Down
3 changes: 1 addition & 2 deletions torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
import inspect
import operator as op
from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, Callable, Dict, Generator, List, Optional, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Union

import torch
from torch import Tensor
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Import utilities."""
import operator
from collections import OrderedDict # noqa: F401
from importlib import import_module
from importlib.util import find_spec
from typing import Callable, Optional
Expand Down

0 comments on commit daf7581

Please sign in to comment.