diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 01760ac3394..b1d495ff064 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -44,7 +44,7 @@ help you or finish it with you :]_ Want to keep Torchmetrics healthy? Love seeing those green tests? So do we! How to we keep it that way? We write tests! We value tests contribution even more than new features. One of the core values of torchmetrics -is that our users can trust our metric implementation. We can only garantee this if our metrics are well tested. +is that our users can trust our metric implementation. We can only guarantee this if our metrics are well tested. --- diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 6d6a3d82cad..205c7650e42 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -2,7 +2,7 @@ name: Bug report about: Create a report to help us improve title: '' -labels: bug, help wanted +labels: bug / fix, help wanted assignees: '' --- diff --git a/.github/mergify.yml b/.github/mergify.yml new file mode 100644 index 00000000000..eb9a8764e99 --- /dev/null +++ b/.github/mergify.yml @@ -0,0 +1,59 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pull_request_rules: + + - name: warn on conflicts + conditions: + - conflict + - -draft # filter-out GH draft PRs + - -label="has conflicts" + actions: + # comment: + # message: This pull request is now in conflict... 
:( label: + add: [ "has conflicts" ] + + - name: resolved conflicts + conditions: + - -conflict + - label="has conflicts" + - -draft # filter-out GH draft PRs + - -merged # not merged yet + - -closed + actions: + label: + remove: [ "has conflicts" ] + + - name: update PR + conditions: + - -conflict + - -draft # filter-out GH draft PRs + - base=master # apply only on master + - -title~=(?i)wip # skip all PRs whose title contains “WIP” (ignoring case) + - "#approved-reviews-by>=1" # number of review approvals + actions: + update: {} + + - name: add core reviewer + conditions: + - -conflict # skip if conflict + - -draft # filter-out GH draft PRs + - label="0:] Ready-To-Go" + - "#approved-reviews-by<2" # number of review approvals + - "#review-requested<2" # number of requested reviews + actions: + request_reviews: + teams: + - "@PyTorchLightning/core-metrics" diff --git a/CHANGELOG.md b/CHANGELOG.md index 784174932a5..6692332fbcf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,12 +9,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Added prefix arg to metric collection ([#70](https://github.com/PyTorchLightning/metrics/pull/70)) +- Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70)) - Added `CohenKappa` metric ([#69](https://github.com/PyTorchLightning/metrics/pull/69)) +- Added `RetrievalMAP` metric for Information Retrieval ([#5032](https://github.com/PyTorchLightning/pytorch-lightning/pull/5032)) + + +- Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110)) + + +- Added `MatthewsCorrcoef` metric ([#98](https://github.com/PyTorchLightning/metrics/pull/98)) + + +- Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114)) + ### Changed - Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68)) @@ -28,6 +39,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed `_stable_1d_sort` to work when n >= N ([#6177](https://github.com/PyTorchLightning/pytorch-lightning/pull/6177)) + ## [0.2.0] - 2021-03-12 diff --git a/README.md b/README.md index 37097597652..7af3a2f53d7 100644 --- a/README.md +++ b/README.md @@ -123,48 +123,64 @@ Module metric usage remains the same when using multiple GPUs or multiple nodes.
``` python -os.environ['MASTER_ADDR'] = 'localhost' -os.environ['MASTER_PORT'] = '12355' +import os +import torch +from torch import nn +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP +import torchmetrics -# create default process group -dist.init_process_group("gloo", rank=rank, world_size=world_size) +def metric_ddp(rank, world_size): +    os.environ['MASTER_ADDR'] = 'localhost' +    os.environ['MASTER_PORT'] = '12355' -# initialize model -metric = torchmetrics.Accuracy() +    # create default process group +    dist.init_process_group("gloo", rank=rank, world_size=world_size) + +    # initialize model +    metric = torchmetrics.Accuracy() + +    # define a model and append your metric to it +    # this allows metric states to be placed on correct accelerators when +    # .to(device) is called on the model +    model = nn.Linear(10, 10) +    model.metric = metric +    model = model.to(rank) + +    # initialize DDP +    model = DDP(model, device_ids=[rank]) + +    n_epochs = 5 +    # this shows iteration over multiple training epochs +    for n in range(n_epochs): + +        # this will be replaced by a DataLoader with a DistributedSampler +        n_batches = 10 +        for i in range(n_batches): +            # simulate a classification problem +            preds = torch.randn(10, 5).softmax(dim=-1) +            target = torch.randint(5, (10,)) + +            # metric on current batch +            acc = metric(preds, target) +            if rank == 0:  # print only for rank 0 +                print(f"Accuracy on batch {i}: {acc}") + +        # metric on all batches and all accelerators using custom accumulation +        # accuracy is same across both accelerators +        acc = metric.compute() +        print(f"Accuracy on all data: {acc}, accelerator rank: {rank}") + +        # Resetting internal state such that the metric is ready for new data +        metric.reset() + +    # cleanup +    dist.destroy_process_group() + +world_size = 2  # number of gpus to parallelize over +mp.spawn(metric_ddp, args=(world_size,), nprocs=world_size, join=True) -# define a model and append your metric to it -# this allows metric states to be placed on correct accelerators when -# .to(device) is called on the model -model = nn.Linear(10, 10) -model.metric = metric -model = model.to(rank) - -# initialize DDP -model = DDP(model, device_ids=[rank]) - -n_epochs = 5 -# this shows iteration over multiple training epochs -for n in range(n_epochs): - - # this will be replaced by a DataLoader with a DistributedSampler - n_batches = 10 - for i in range(n_batches): - # simulate a classification problem - preds = torch.randn(10, 5).softmax(dim=-1) - target = torch.randint(5, (10,)) - - # metric on current batch - acc = metric(preds, target) - if rank == 0: # print only for rank 0 - print(f"Accuracy on batch {i}: {acc}") - - # metric on all batches and all accelerators using custom accumulation - # accuracy is same across both accelerators - acc = metric.compute() - print(f"Accuracy on all data: {acc}, accelerator rank: {rank}") - - # Reseting internal state such that metric ready for new data - metric.reset() ``` diff --git a/docs/source/_templates/theme_variables.jinja b/docs/source/_templates/theme_variables.jinja index 87538f5d3a6..c41e23e9f9f 100644 --- a/docs/source/_templates/theme_variables.jinja +++ b/docs/source/_templates/theme_variables.jinja @@ -1,18 +1,14 @@ {%- set external_urls = { - 'github': 'https://github.com/PytorchLightning/pytorch-torchmetrics', - 'github_issues': 'https://github.com/PytorchLightning/pytorch-torchmetrics/issues', - 'contributing':
'https://github.com/PytorchLightning/pytorch-lightning/blob/master/CONTRIBUTING.md', - 'governance': 'https://github.com/PytorchLightning/pytorch-lightning/blob/master/governance.md', - 'docs': 'https://pytorch-torchmetrics.rtfd.io/en/latest', + 'github': 'https://github.com/PytorchLightning/metrics', + 'github_issues': 'https://github.com/PytorchLightning/metrics/issues', + 'contributing': 'https://github.com/PyTorchLightning/metrics/blob/master/.github/CONTRIBUTING.md', + 'docs': 'https://torchmetrics.readthedocs.io/en/latest', 'twitter': 'https://twitter.com/PyTorchLightnin', 'discuss': 'https://pytorch-lightning.slack.com', - 'tutorials': 'https://pytorch-lightning.readthedocs.io/en/latest/#tutorials', 'previous_pytorch_versions': 'https://torchmetrics.rtfd.io/en/latest/', 'home': 'https://torchmetrics.rtfd.io/en/latest/', 'get_started': 'https://torchmetrics.readthedocs.io/en/latest/quickstart.html', - 'features': 'https://lightning-bolts.rtfd.io/en/latest/', 'blog': 'https://www.pytorchlightning.ai/blog', - 'resources': 'https://pytorch-lightning.readthedocs.io/en/latest/#community-examples', - 'support': 'https://github.com/PytorchLightning/pytorch-torchmetrics/issues', + 'support': 'https://github.com/PytorchLightning/metrics/issues', } -%} diff --git a/docs/source/conf.py b/docs/source/conf.py index fba0e170b09..6ef3913aa44 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,25 +13,29 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. # import m2r -import builtins import glob import inspect import os import shutil import sys +from importlib.util import module_from_spec, spec_from_file_location import pt_lightning_sphinx_theme -PATH_HERE = os.path.abspath(os.path.dirname(__file__)) -PATH_ROOT = os.path.join(PATH_HERE, "..", "..") -sys.path.insert(0, os.path.abspath(PATH_ROOT)) - -builtins.__LIGHTNING_BOLT_SETUP__ = True +_PATH_HERE = os.path.abspath(os.path.dirname(__file__)) +_PATH_ROOT = os.path.realpath(os.path.join(_PATH_HERE, "..", "..")) +sys.path.insert(0, os.path.abspath(_PATH_ROOT)) FOLDER_GENERATED = 'generated' SPHINX_MOCK_REQUIREMENTS = int(os.environ.get("SPHINX_MOCK_REQUIREMENTS", True)) -import torchmetrics # noqa: E402 +try: + from torchmetrics import info +except ImportError: + # alternative https://stackoverflow.com/a/67692/4521646 + spec = spec_from_file_location("torchmetrics/info.py", os.path.join(_PATH_ROOT, "torchmetrics", "info.py")) + info = module_from_spec(spec) + spec.loader.exec_module(info) html_favicon = '_static/images/icon.svg' @@ -39,13 +43,13 @@ # this name shall match the project name in Github as it is used for linking to code project = "PyTorch-Metrics" -copyright = torchmetrics.__copyright__ -author = torchmetrics.__author__ +copyright = info.__copyright__ +author = info.__author__ # The short X.Y version -version = torchmetrics.__version__ +version = info.__version__ # The full version, including alpha/beta/rc tags -release = torchmetrics.__version__ +release = info.__version__ # Options for the linkcode extension # ---------------------------------- @@ -70,14 +74,14 @@ def _transform_changelog(path_in: str, path_out: str) -> None: fp.writelines(chlog_lines) -os.makedirs(os.path.join(PATH_HERE, FOLDER_GENERATED), exist_ok=True) +os.makedirs(os.path.join(_PATH_HERE, FOLDER_GENERATED), exist_ok=True) # copy all documents from GH templates like contribution guide -for md in glob.glob(os.path.join(PATH_ROOT, '.github', '*.md')): - shutil.copy(md, os.path.join(PATH_HERE, FOLDER_GENERATED, 
os.path.basename(md))) +for md in glob.glob(os.path.join(_PATH_ROOT, '.github', '*.md')): + shutil.copy(md, os.path.join(_PATH_HERE, FOLDER_GENERATED, os.path.basename(md))) # copy also the changelog _transform_changelog( - os.path.join(PATH_ROOT, 'CHANGELOG.md'), - os.path.join(PATH_HERE, FOLDER_GENERATED, 'CHANGELOG.md'), + os.path.join(_PATH_ROOT, 'CHANGELOG.md'), + os.path.join(_PATH_HERE, FOLDER_GENERATED, 'CHANGELOG.md'), ) # -- General configuration --------------------------------------------------- @@ -166,8 +170,8 @@ def _transform_changelog(path_in: str, path_out: str) -> None: # documentation. html_theme_options = { - "pytorch_project": torchmetrics.__homepage__, - "canonical_url": torchmetrics.__homepage__, + "pytorch_project": info.__homepage__, + "canonical_url": info.__homepage__, "collapse_navigation": False, "display_version": True, "logo_only": False, @@ -233,7 +237,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None: project + " Documentation", author, project, - torchmetrics.__docs__, + info.__docs__, "Miscellaneous", ), ] @@ -280,11 +284,11 @@ def _transform_changelog(path_in: str, path_out: str) -> None: # packages for which sphinx-apidoc should generate the docs (.rst files) PACKAGES = [ - torchmetrics.__name__, + info.__name__, ] # def run_apidoc(_): -# apidoc_output_folder = os.path.join(PATH_HERE, "api") +# apidoc_output_folder = os.path.join(_PATH_HERE, "api") # sys.path.insert(0, apidoc_output_folder) # # # delete api-doc files before generating them @@ -294,7 +298,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None: # for pkg in PACKAGES: # argv = ['-e', # '-o', apidoc_output_folder, -# os.path.join(PATH_ROOT, pkg), +# os.path.join(_PATH_ROOT, pkg), # '**/test_*', # '--force', # '--private', @@ -311,10 +315,10 @@ def setup(app): # copy all notebooks to local folder -path_nbs = os.path.join(PATH_HERE, "notebooks") +path_nbs = os.path.join(_PATH_HERE, "notebooks") if not os.path.isdir(path_nbs): os.mkdir(path_nbs) -for path_ipynb in glob.glob(os.path.join(PATH_ROOT, "notebooks", "*.ipynb")): +for path_ipynb in glob.glob(os.path.join(_PATH_ROOT, "notebooks", "*.ipynb")): path_ipynb2 = os.path.join(path_nbs, os.path.basename(path_ipynb)) shutil.copy(path_ipynb, path_ipynb2) @@ -340,7 +344,7 @@ def package_list_from_file(file): MOCK_PACKAGES = [] if SPHINX_MOCK_REQUIREMENTS: # mock also base packages when we are on RTD since we don't install them there - MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, "requirements.txt")) + MOCK_PACKAGES += package_list_from_file(os.path.join(_PATH_ROOT, "requirements.txt")) MOCK_PACKAGES = [PACKAGE_MAPPING.get(pkg, pkg) for pkg in MOCK_PACKAGES] autodoc_mock_imports = MOCK_PACKAGES @@ -373,7 +377,7 @@ def find_source(): return None try: filename = "%s#L%d-L%d" % find_source() - except Exception: + except Exception: # todo: specify the exception filename = info["module"].replace(".", "/") + ".py" # import subprocess # tag = subprocess.Popen(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE, diff --git a/docs/source/index.rst b/docs/source/index.rst index 245296e1bb8..cdaffee857c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,5 +1,4 @@ -.. PyTorchtorchmetrics documentation master file, created by - sphinx-quickstart on Wed Mar 25 21:34:07 2020. +.. TorchMetrics documentation master file. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. 
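The `conf.py` and `setup.py` hunks above share one pattern: when `torchmetrics` itself cannot be imported, the package metadata is executed directly from `info.py` via `importlib.util`. A minimal standalone sketch of that fallback, assuming the target file defines attributes such as `__version__` (the helper name and relative paths below are illustrative, not part of the diff):

```python
import os
from importlib.util import module_from_spec, spec_from_file_location


def load_module_from_path(name: str, path: str):
    """Execute a single .py file as a module, without importing its parent package."""
    spec = spec_from_file_location(name, path)
    module = module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file's top-level code
    return module


# e.g. read the version string without triggering torchmetrics/__init__.py
info = load_module_from_path("info", os.path.join("torchmetrics", "info.py"))
print(info.__version__)
```

This avoids importing the full package (and its heavy dependencies) just to read a few metadata strings, which is why the docs build and `setup.py` above fall back to it when the package is not importable.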
diff --git a/docs/source/pages/brief_intro.rst b/docs/source/pages/brief_intro.rst index dfcccfe5e74..6f201a6ce41 100644 --- a/docs/source/pages/brief_intro.rst +++ b/docs/source/pages/brief_intro.rst @@ -1,9 +1,9 @@ -TorchMetrics is a collection of Machine learning metrics for distributed, scalable PyTorch models and an easy-to-use API to create custom metrics. It offers: +TorchMetrics is a collection of Machine learning metrics for distributed, scalable PyTorch models and an easy-to-use API to create custom metrics. It offers the following benefits: * Optimized for distributed-training * A standardized interface to increase reproducibility * Reduces Boilerplate -* Distrubuted-training compatible +* Distributed-training compatible * Rigorously tested * Automatic accumulation over batches * Automatic synchronization between multiple devices diff --git a/docs/source/pages/implement.rst b/docs/source/pages/implement.rst index aaff3880328..8354bd0943f 100644 --- a/docs/source/pages/implement.rst +++ b/docs/source/pages/implement.rst @@ -131,7 +131,7 @@ and tests gets formatted in the following way: 4. Remember to add binding to the different relevant ``__init__`` files. -5. Testing is key to keeping ``torchmetrics`` trustworty. This is why we have a very rigid testing protocol. This means +5. Testing is key to keeping ``torchmetrics`` trustworthy. This is why we have a very rigid testing protocol. This means that we in most cases require the metric to be tested against some other common framework (``sklearn``, ``scipy`` etc). 1. Create a testing file in ``tests/"domain"/test_"new_metric".py``. Only one file is needed as it is intended to test diff --git a/docs/source/pages/lightning.rst b/docs/source/pages/lightning.rst index 7d01ee765f6..a3dabb83ba2 100644 --- a/docs/source/pages/lightning.rst +++ b/docs/source/pages/lightning.rst @@ -15,7 +15,7 @@ While TorchMetrics was built to be used with native PyTorch, using TorchMetrics * Module metrics are automatically placed on the correct device when properly defined inside a LightningModule. This means that your data will always be placed on the same device as your metrics. * Native support for logging metrics in Lightning using `self.log `_ inside your LightningModule. -* The ``.reset()`` method of the metric will automatically be called and the end of an epoch. +* The ``.reset()`` method of the metric will automatically be called at the end of an epoch. The example below shows how to use a metric in your `LightningModule `_: diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst index f45aa1a0180..82d8a9fd7a0 100644 --- a/docs/source/pages/overview.rst +++ b/docs/source/pages/overview.rst @@ -109,6 +109,29 @@ the native `MetricCollection`_ module can also be used to wrap multiple metrics. val3 = self.metric3['accuracy'](preds, target) val4 = self.metric4(preds, target) +Metrics in Dataparallel (DP) mode +================================= + +When using metrics in `Dataparallel (DP) `_ +mode, one should be aware that DP will both create and clean up replicas of Metric objects during a single forward pass. +This has the consequence that the metric state of the replicas will, by default, be destroyed before we can sync +them. It is therefore recommended, when using metrics in DP mode, to initialize them with ``dist_sync_on_step=True`` +such that metric states are synchronized between the main process and the replicas before they are destroyed.
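To make the DP recommendation above concrete, here is a minimal sketch; the toy module, data shapes, and the assumption of at least two visible GPUs are illustrative and not taken from the torchmetrics docs:

```python
import torch
from torch import nn
import torchmetrics


class ModelWithMetric(nn.Module):
    """Toy module whose metric is updated inside forward, i.e. inside each DP replica."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 5)
        # dist_sync_on_step=True: sync metric state on every update, before the
        # DP replicas (and the state they accumulated) are discarded
        self.metric = torchmetrics.Accuracy(dist_sync_on_step=True)

    def forward(self, x, target):
        preds = self.layer(x).softmax(dim=-1)
        self.metric(preds, target)  # state update happens on each replica
        return preds


# assumes at least two GPUs are visible
model = nn.DataParallel(ModelWithMetric().to("cuda:0"), device_ids=[0, 1])
x = torch.randn(8, 10, device="cuda:0")
target = torch.randint(5, (8,), device="cuda:0")
preds = model(x, target)
print(model.module.metric.compute())  # compute on the main-process metric
```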
+ +Metrics in Distributed Data Parallel (DDP) mode +=============================================== + +When using metrics in `Distributed Data Parallel (DDP) `_ +mode, one should be aware that DDP will add additional samples to your dataset if the size of your dataset is +not evenly divisible by ``batch_size * num_processors``. The added samples will always be replicates of datapoints +already in your dataset. This is done to ensure an equal load for all processes. However, this has the consequence +that the calculated metric value will be slightly biased towards those replicated samples, leading to a wrong result. + +During training and/or validation this may not be important; however, it is highly recommended when evaluating +the test dataset to only run on a single GPU or to use a `join `_ +context in conjunction with DDP to prevent this behaviour. + + ****************** Metric Arithmetics ****************** @@ -148,7 +171,7 @@ This pattern is implemented for the following operators (with ``a`` being metric * Inequality (``a != b``) * Bitwise OR (``a | b``) * Power (``a ** b``) -* Substraction (``a - b``) +* Subtraction (``a - b``) * True Division (``a / b``) * Bitwise XOR (``a ^ b``) * Absolute Value (``abs(a)``) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 427551faad3..6b6651bfc49 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -83,6 +83,11 @@ iou [func] .. autofunction:: torchmetrics.functional.iou :noindex: +matthews_corrcoef [func] +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.matthews_corrcoef + :noindex: roc [func] ~~~~~~~~~~~~~~~~~~~~~ @@ -132,6 +137,13 @@ stat_scores [func] :noindex: +retrieval_average_precision [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.retrieval_average_precision + :noindex: + + to_categorical [func] ~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index ce4930c4931..cd7fe83ea57 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -168,6 +168,12 @@ IoU .. autoclass:: torchmetrics.IoU :noindex: +MatthewsCorrcoef +~~~~~~~~~~~~~~~~ + +.. autoclass:: torchmetrics.MatthewsCorrcoef + :noindex: + Hamming Distance ~~~~~~~~~~~~~~~~ @@ -206,6 +212,13 @@ StatScores :noindex: +RetrievalMAP +~~~~~~~~~~~~ + +.. autoclass:: torchmetrics.RetrievalMAP + :noindex: + + ****************** Regression Metrics ****************** diff --git a/requirements.txt b/requirements.txt index 12abcf40fc9..8846f8ef3cb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ +numpy torch>=1.3.1 \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index dacf40163cb..d05ba715b44 100644 --- a/setup.cfg +++ b/setup.cfg @@ -95,3 +95,7 @@ ignore_errors = True # todo: add proper typing to this module... [mypy-torchmetrics.regression.*] ignore_errors = True + +# todo: add proper typing to this module...
+[mypy-torchmetrics.retrieval.*] +ignore_errors = True diff --git a/setup.py b/setup.py index 23dcfd623fc..278b506f517 100755 --- a/setup.py +++ b/setup.py @@ -1,22 +1,24 @@ #!/usr/bin/env python - import os +import sys # Always prefer setuptools over distutils from setuptools import find_packages, setup +_PATH_ROOT = os.path.realpath(os.path.dirname(__file__)) try: - import builtins + from torchmetrics import info, setup_tools except ImportError: - import __builtin__ as builtins - -# https://packaging.python.org/guides/single-sourcing-package-version/ -# http://blog.ionelmc.ro/2014/05/25/python-packaging/ -PATH_ROOT = os.path.dirname(__file__) -builtins.__LIGHTNING_SETUP__ = True + # alternative https://stackoverflow.com/a/67692/4521646 + sys.path.append("torchmetrics") + import info + import setup_tools -import torchmetrics # noqa: E402 -from torchmetrics.setup_tools import _load_readme_description, _load_requirements # noqa: E402 +long_description = setup_tools._load_readme_description( + _PATH_ROOT, + homepage=info.__homepage__, + version=f'v{info.__version__}', +) # https://packaging.python.org/discussions/install-requires-vs-requirements / # keep the meta-data here for simplicity in reading this file... it's not obvious @@ -25,26 +27,26 @@ # engineer specific practices setup( name='torchmetrics', - version=torchmetrics.__version__, - description=torchmetrics.__docs__, - author=torchmetrics.__author__, - author_email=torchmetrics.__author_email__, - url=torchmetrics.__homepage__, - download_url='https://github.com/PyTorchLightning/metrics/archive/master.zip', - license=torchmetrics.__license__, + version=info.__version__, + description=info.__docs__, + author=info.__author__, + author_email=info.__author_email__, + url=info.__homepage__, + download_url=os.path.join(info.__homepage__, 'archive', 'master.zip'), + license=info.__license__, packages=find_packages(exclude=['tests', 'docs']), - long_description=_load_readme_description(PATH_ROOT, version=f'v{torchmetrics.__version__}'), + long_description=long_description, long_description_content_type='text/markdown', include_package_data=True, zip_safe=False, keywords=['deep learning', 'machine learning', 'pytorch', 'metrics', 'AI'], python_requires='>=3.6', setup_requires=[], - install_requires=_load_requirements(PATH_ROOT), + install_requires=setup_tools._load_requirements(_PATH_ROOT), project_urls={ - "Bug Tracker": "https://github.com/PyTorchLightning/torchmetrics/issues", + "Bug Tracker": os.path.join(info.__homepage__, 'issues'), "Documentation": "https://torchmetrics.rtfd.io/en/latest/", - "Source Code": "https://github.com/PyTorchLightning/torchmetrics", + "Source Code": info.__homepage__, }, classifiers=[ 'Environment :: Console', diff --git a/tests/bases/test_composition.py b/tests/bases/test_composition.py index 680a8152bd4..d7b82ca270a 100644 --- a/tests/bases/test_composition.py +++ b/tests/bases/test_composition.py @@ -15,6 +15,7 @@ import pytest import torch +from torch import tensor from tests.helpers import _MARK_TORCH_MIN_1_4, _MARK_TORCH_MIN_1_5, _MARK_TORCH_MIN_1_6 from torchmetrics.metric import CompositionalMetric, Metric @@ -31,7 +32,7 @@ def update(self, *args, **kwargs) -> None: self._num_updates += 1 def compute(self): - return torch.tensor(self._val_to_return) + return tensor(self._val_to_return) def reset(self): self._num_updates = 0 @@ -41,10 +42,10 @@ def reset(self): @pytest.mark.parametrize( ["second_operand", "expected_result"], [ - (DummyMetric(2), torch.tensor(4)), - (2, torch.tensor(4)), - (2.0, 
torch.tensor(4.0)), - pytest.param(torch.tensor(2), torch.tensor(4), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4)), + (DummyMetric(2), tensor(4)), + (2, tensor(4)), + (2.0, tensor(4.0)), + pytest.param(tensor(2), tensor(4), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4)), ], ) def test_metrics_add(second_operand, expected_result): @@ -62,7 +63,7 @@ def test_metrics_add(second_operand, expected_result): @pytest.mark.parametrize( ["second_operand", "expected_result"], - [(DummyMetric(3), torch.tensor(2)), (3, torch.tensor(2)), (3, torch.tensor(2)), (torch.tensor(3), torch.tensor(2))], + [(DummyMetric(3), tensor(2)), (3, tensor(2)), (3, tensor(2)), (tensor(3), tensor(2))], ) @pytest.mark.skipif(**_MARK_TORCH_MIN_1_5) def test_metrics_and(second_operand, expected_result): @@ -81,10 +82,10 @@ def test_metrics_and(second_operand, expected_result): @pytest.mark.parametrize( ["second_operand", "expected_result"], [ - (DummyMetric(2), torch.tensor(True)), - (2, torch.tensor(True)), - (2.0, torch.tensor(True)), - (torch.tensor(2), torch.tensor(True)), + (DummyMetric(2), tensor(True)), + (2, tensor(True)), + (2.0, tensor(True)), + (tensor(2), tensor(True)), ], ) def test_metrics_eq(second_operand, expected_result): @@ -101,10 +102,10 @@ def test_metrics_eq(second_operand, expected_result): @pytest.mark.parametrize( ["second_operand", "expected_result"], [ - (DummyMetric(2), torch.tensor(2)), - (2, torch.tensor(2)), - (2.0, torch.tensor(2.0)), - (torch.tensor(2), torch.tensor(2)), + (DummyMetric(2), tensor(2)), + (2, tensor(2)), + (2.0, tensor(2.0)), + (tensor(2), tensor(2)), ], ) @pytest.mark.skipif(**_MARK_TORCH_MIN_1_5) @@ -121,10 +122,10 @@ def test_metrics_floordiv(second_operand, expected_result): @pytest.mark.parametrize( ["second_operand", "expected_result"], [ - (DummyMetric(2), torch.tensor(True)), - (2, torch.tensor(True)), - (2.0, torch.tensor(True)), - (torch.tensor(2), torch.tensor(True)), + (DummyMetric(2), tensor(True)), + (2, tensor(True)), + (2.0, tensor(True)), + (tensor(2), tensor(True)), ], ) def test_metrics_ge(second_operand, expected_result): @@ -141,10 +142,10 @@ def test_metrics_ge(second_operand, expected_result): @pytest.mark.parametrize( ["second_operand", "expected_result"], [ - (DummyMetric(2), torch.tensor(True)), - (2, torch.tensor(True)), - (2.0, torch.tensor(True)), - (torch.tensor(2), torch.tensor(True)), + (DummyMetric(2), tensor(True)), + (2, tensor(True)), + (2.0, tensor(True)), + (tensor(2), tensor(True)), ], ) def test_metrics_gt(second_operand, expected_result): @@ -161,10 +162,10 @@ def test_metrics_gt(second_operand, expected_result): @pytest.mark.parametrize( ["second_operand", "expected_result"], [ - (DummyMetric(2), torch.tensor(False)), - (2, torch.tensor(False)), - (2.0, torch.tensor(False)), - (torch.tensor(2), torch.tensor(False)), + (DummyMetric(2), tensor(False)), + (2, tensor(False)), + (2.0, tensor(False)), + (tensor(2), tensor(False)), ], ) def test_metrics_le(second_operand, expected_result): @@ -181,10 +182,10 @@ def test_metrics_le(second_operand, expected_result): @pytest.mark.parametrize( ["second_operand", "expected_result"], [ - (DummyMetric(2), torch.tensor(False)), - (2, torch.tensor(False)), - (2.0, torch.tensor(False)), - (torch.tensor(2), torch.tensor(False)), + (DummyMetric(2), tensor(False)), + (2, tensor(False)), + (2.0, tensor(False)), + (tensor(2), tensor(False)), ], ) def test_metrics_lt(second_operand, expected_result): @@ -200,7 +201,7 @@ def test_metrics_lt(second_operand, expected_result): @pytest.mark.parametrize( 
["second_operand", "expected_result"], - [(DummyMetric([2, 2, 2]), torch.tensor(12)), (torch.tensor([2, 2, 2]), torch.tensor(12))], + [(DummyMetric([2, 2, 2]), tensor(12)), (tensor([2, 2, 2]), tensor(12))], ) def test_metrics_matmul(second_operand, expected_result): first_metric = DummyMetric([2, 2, 2]) @@ -215,10 +216,10 @@ def test_metrics_matmul(second_operand, expected_result): @pytest.mark.parametrize( ["second_operand", "expected_result"], [ - (DummyMetric(2), torch.tensor(1)), - (2, torch.tensor(1)), - (2.0, torch.tensor(1)), - (torch.tensor(2), torch.tensor(1)), + (DummyMetric(2), tensor(1)), + (2, tensor(1)), + (2.0, tensor(1)), + (tensor(2), tensor(1)), ], ) def test_metrics_mod(second_operand, expected_result): @@ -234,10 +235,10 @@ def test_metrics_mod(second_operand, expected_result): @pytest.mark.parametrize( ["second_operand", "expected_result"], [ - (DummyMetric(2), torch.tensor(4)), - (2, torch.tensor(4)), - (2.0, torch.tensor(4.0)), - pytest.param(torch.tensor(2), torch.tensor(4), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4)), + (DummyMetric(2), tensor(4)), + (2, tensor(4)), + (2.0, tensor(4.0)), + pytest.param(tensor(2), tensor(4), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4)), ], ) def test_metrics_mul(second_operand, expected_result): @@ -256,10 +257,10 @@ def test_metrics_mul(second_operand, expected_result): @pytest.mark.parametrize( ["second_operand", "expected_result"], [ - (DummyMetric(2), torch.tensor(False)), - (2, torch.tensor(False)), - (2.0, torch.tensor(False)), - (torch.tensor(2), torch.tensor(False)), + (DummyMetric(2), tensor(False)), + (2, tensor(False)), + (2.0, tensor(False)), + (tensor(2), tensor(False)), ], ) def test_metrics_ne(second_operand, expected_result): @@ -275,7 +276,7 @@ def test_metrics_ne(second_operand, expected_result): @pytest.mark.parametrize( ["second_operand", "expected_result"], - [(DummyMetric([1, 0, 3]), torch.tensor([-1, -2, 3])), (torch.tensor([1, 0, 3]), torch.tensor([-1, -2, 3]))], + [(DummyMetric([1, 0, 3]), tensor([-1, -2, 3])), (tensor([1, 0, 3]), tensor([-1, -2, 3]))], ) @pytest.mark.skipif(**_MARK_TORCH_MIN_1_5) def test_metrics_or(second_operand, expected_result): @@ -294,10 +295,10 @@ def test_metrics_or(second_operand, expected_result): @pytest.mark.parametrize( ["second_operand", "expected_result"], [ - pytest.param(DummyMetric(2), torch.tensor(4)), - pytest.param(2, torch.tensor(4)), - pytest.param(2.0, torch.tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_6)), - pytest.param(torch.tensor(2), torch.tensor(4)), + pytest.param(DummyMetric(2), tensor(4)), + pytest.param(2, tensor(4)), + pytest.param(2.0, tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_6)), + pytest.param(tensor(2), tensor(4)), ], ) def test_metrics_pow(second_operand, expected_result): @@ -312,7 +313,7 @@ def test_metrics_pow(second_operand, expected_result): @pytest.mark.parametrize( ["first_operand", "expected_result"], - [(5, torch.tensor(2)), (5.0, torch.tensor(2.0)), (torch.tensor(5), torch.tensor(2))], + [(5, tensor(2)), (5.0, tensor(2.0)), (tensor(5), tensor(2))], ) @pytest.mark.skipif(**_MARK_TORCH_MIN_1_5) def test_metrics_rfloordiv(first_operand, expected_result): @@ -324,10 +325,8 @@ def test_metrics_rfloordiv(first_operand, expected_result): assert torch.allclose(expected_result, final_rfloordiv.compute()) -@pytest.mark.parametrize( - ["first_operand", "expected_result"], - [pytest.param(torch.tensor([2, 2, 2]), torch.tensor(12), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4))] -) 
+@pytest.mark.parametrize(["first_operand", "expected_result"], + [pytest.param(tensor([2, 2, 2]), tensor(12), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4))]) def test_metrics_rmatmul(first_operand, expected_result): second_operand = DummyMetric([2, 2, 2]) @@ -338,10 +337,8 @@ def test_metrics_rmatmul(first_operand, expected_result): assert torch.allclose(expected_result, final_rmatmul.compute()) -@pytest.mark.parametrize( - ["first_operand", "expected_result"], - [pytest.param(torch.tensor(2), torch.tensor(2), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4))] -) +@pytest.mark.parametrize(["first_operand", "expected_result"], + [pytest.param(tensor(2), tensor(2), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4))]) def test_metrics_rmod(first_operand, expected_result): second_operand = DummyMetric(5) @@ -355,9 +352,9 @@ def test_metrics_rmod(first_operand, expected_result): @pytest.mark.parametrize( "first_operand,expected_result", [ - pytest.param(DummyMetric(2), torch.tensor(4)), - pytest.param(2, torch.tensor(4)), - pytest.param(2.0, torch.tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_6)), + pytest.param(DummyMetric(2), tensor(4)), + pytest.param(2, tensor(4)), + pytest.param(2.0, tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_6)), ], ) def test_metrics_rpow(first_operand, expected_result): @@ -373,10 +370,10 @@ def test_metrics_rpow(first_operand, expected_result): @pytest.mark.parametrize( ["first_operand", "expected_result"], [ - (DummyMetric(3), torch.tensor(1)), - (3, torch.tensor(1)), - (3.0, torch.tensor(1.0)), - pytest.param(torch.tensor(3), torch.tensor(1), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4)), + (DummyMetric(3), tensor(1)), + (3, tensor(1)), + (3.0, tensor(1.0)), + pytest.param(tensor(3), tensor(1), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4)), ], ) def test_metrics_rsub(first_operand, expected_result): @@ -392,10 +389,10 @@ def test_metrics_rsub(first_operand, expected_result): @pytest.mark.parametrize( ["first_operand", "expected_result"], [ - (DummyMetric(6), torch.tensor(2.0)), - (6, torch.tensor(2.0)), - (6.0, torch.tensor(2.0)), - (torch.tensor(6), torch.tensor(2.0)), + (DummyMetric(6), tensor(2.0)), + (6, tensor(2.0)), + (6.0, tensor(2.0)), + (tensor(6), tensor(2.0)), ], ) @pytest.mark.skipif(**_MARK_TORCH_MIN_1_5) @@ -412,10 +409,10 @@ def test_metrics_rtruediv(first_operand, expected_result): @pytest.mark.parametrize( ["second_operand", "expected_result"], [ - (DummyMetric(2), torch.tensor(1)), - (2, torch.tensor(1)), - (2.0, torch.tensor(1.0)), - (torch.tensor(2), torch.tensor(1)), + (DummyMetric(2), tensor(1)), + (2, tensor(1)), + (2.0, tensor(1.0)), + (tensor(2), tensor(1)), ], ) def test_metrics_sub(second_operand, expected_result): @@ -431,10 +428,10 @@ def test_metrics_sub(second_operand, expected_result): @pytest.mark.parametrize( ["second_operand", "expected_result"], [ - (DummyMetric(3), torch.tensor(2.0)), - (3, torch.tensor(2.0)), - (3.0, torch.tensor(2.0)), - (torch.tensor(3), torch.tensor(2.0)), + (DummyMetric(3), tensor(2.0)), + (3, tensor(2.0)), + (3.0, tensor(2.0)), + (tensor(3), tensor(2.0)), ], ) @pytest.mark.skipif(**_MARK_TORCH_MIN_1_5) @@ -450,7 +447,7 @@ def test_metrics_truediv(second_operand, expected_result): @pytest.mark.parametrize( ["second_operand", "expected_result"], - [(DummyMetric([1, 0, 3]), torch.tensor([-2, -2, 0])), (torch.tensor([1, 0, 3]), torch.tensor([-2, -2, 0]))], + [(DummyMetric([1, 0, 3]), tensor([-2, -2, 0])), (tensor([1, 0, 3]), tensor([-2, -2, 0]))], ) 
@pytest.mark.skipif(**_MARK_TORCH_MIN_1_5) def test_metrics_xor(second_operand, expected_result): @@ -473,7 +470,7 @@ def test_metrics_abs(): assert isinstance(final_abs, CompositionalMetric) - assert torch.allclose(torch.tensor(1), final_abs.compute()) + assert torch.allclose(tensor(1), final_abs.compute()) def test_metrics_invert(): @@ -481,7 +478,7 @@ def test_metrics_invert(): final_inverse = ~first_metric assert isinstance(final_inverse, CompositionalMetric) - assert torch.allclose(torch.tensor(-2), final_inverse.compute()) + assert torch.allclose(tensor(-2), final_inverse.compute()) def test_metrics_neg(): @@ -489,7 +486,7 @@ def test_metrics_neg(): final_neg = neg(first_metric) assert isinstance(final_neg, CompositionalMetric) - assert torch.allclose(torch.tensor(-1), final_neg.compute()) + assert torch.allclose(tensor(-1), final_neg.compute()) def test_metrics_pos(): @@ -497,7 +494,7 @@ def test_metrics_pos(): final_pos = pos(first_metric) assert isinstance(final_pos, CompositionalMetric) - assert torch.allclose(torch.tensor(1), final_pos.compute()) + assert torch.allclose(tensor(1), final_pos.compute()) def test_compositional_metrics_update(): diff --git a/tests/bases/test_ddp.py b/tests/bases/test_ddp.py index 36824631d80..2868d89701a 100644 --- a/tests/bases/test_ddp.py +++ b/tests/bases/test_ddp.py @@ -15,6 +15,7 @@ import pytest import torch +from torch import tensor from tests.helpers.testers import DummyMetric, setup_ddp from torchmetrics import Metric @@ -26,7 +27,7 @@ def _test_ddp_sum(rank, worldsize): setup_ddp(rank, worldsize) dummy = DummyMetric() dummy._reductions = {"foo": torch.sum} - dummy.foo = torch.tensor(1) + dummy.foo = tensor(1) dummy._sync_dist() assert dummy.foo == worldsize @@ -36,21 +37,21 @@ def _test_ddp_cat(rank, worldsize): setup_ddp(rank, worldsize) dummy = DummyMetric() dummy._reductions = {"foo": torch.cat} - dummy.foo = [torch.tensor([1])] + dummy.foo = [tensor([1])] dummy._sync_dist() - assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1]))) + assert torch.all(torch.eq(dummy.foo, tensor([1, 1]))) def _test_ddp_sum_cat(rank, worldsize): setup_ddp(rank, worldsize) dummy = DummyMetric() dummy._reductions = {"foo": torch.cat, "bar": torch.sum} - dummy.foo = [torch.tensor([1])] - dummy.bar = torch.tensor(1) + dummy.foo = [tensor([1])] + dummy.bar = tensor(1) dummy._sync_dist() - assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1]))) + assert torch.all(torch.eq(dummy.foo, tensor([1, 1]))) assert dummy.bar == worldsize diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py index dc63740f727..5a9580cbce1 100644 --- a/tests/bases/test_metric.py +++ b/tests/bases/test_metric.py @@ -13,15 +13,15 @@ # limitations under the License. 
import pickle from collections import OrderedDict -from distutils.version import LooseVersion import cloudpickle import numpy as np import pytest import torch -from torch import nn +from torch import nn, tensor from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum +from torchmetrics.utilities.imports import _TORCH_LOWER_1_6 torch.manual_seed(42) @@ -33,43 +33,43 @@ def test_inherit(): def test_add_state(): a = DummyMetric() - a.add_state("a", torch.tensor(0), "sum") - assert a._reductions["a"](torch.tensor([1, 1])) == 2 + a.add_state("a", tensor(0), "sum") + assert a._reductions["a"](tensor([1, 1])) == 2 - a.add_state("b", torch.tensor(0), "mean") - assert np.allclose(a._reductions["b"](torch.tensor([1.0, 2.0])).numpy(), 1.5) + a.add_state("b", tensor(0), "mean") + assert np.allclose(a._reductions["b"](tensor([1.0, 2.0])).numpy(), 1.5) - a.add_state("c", torch.tensor(0), "cat") - assert a._reductions["c"]([torch.tensor([1]), torch.tensor([1])]).shape == (2, ) + a.add_state("c", tensor(0), "cat") + assert a._reductions["c"]([tensor([1]), tensor([1])]).shape == (2, ) with pytest.raises(ValueError): - a.add_state("d1", torch.tensor(0), 'xyz') + a.add_state("d1", tensor(0), 'xyz') with pytest.raises(ValueError): - a.add_state("d2", torch.tensor(0), 42) + a.add_state("d2", tensor(0), 42) with pytest.raises(ValueError): - a.add_state("d3", [torch.tensor(0)], 'sum') + a.add_state("d3", [tensor(0)], 'sum') with pytest.raises(ValueError): a.add_state("d4", 42, 'sum') - def custom_fx(x): + def custom_fx(_): return -1 - a.add_state("e", torch.tensor(0), custom_fx) - assert a._reductions["e"](torch.tensor([1, 1])) == -1 + a.add_state("e", tensor(0), custom_fx) + assert a._reductions["e"](tensor([1, 1])) == -1 def test_add_state_persistent(): a = DummyMetric() - a.add_state("a", torch.tensor(0), "sum", persistent=True) + a.add_state("a", tensor(0), "sum", persistent=True) assert "a" in a.state_dict() - a.add_state("b", torch.tensor(0), "sum", persistent=False) + a.add_state("b", tensor(0), "sum", persistent=False) - if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + if _TORCH_LOWER_1_6: assert "b" not in a.state_dict() @@ -83,13 +83,13 @@ class B(DummyListMetric): a = A() assert a.x == 0 - a.x = torch.tensor(5) + a.x = tensor(5) a.reset() assert a.x == 0 b = B() assert isinstance(b.x, list) and len(b.x) == 0 - b.x = torch.tensor(5) + b.x = tensor(5) b.reset() assert isinstance(b.x, list) and len(b.x) == 0 @@ -155,10 +155,10 @@ class B(DummyListMetric): b2 = B() assert hash(b1) == hash(b2) assert isinstance(b1.x, list) and len(b1.x) == 0 - b1.x.append(torch.tensor(5)) + b1.x.append(tensor(5)) assert isinstance(hash(b1), int) # <- check that nothing crashes assert isinstance(b1.x, list) and len(b1.x) == 1 - b2.x.append(torch.tensor(5)) + b2.x.append(tensor(5)) # Sanity: assert isinstance(b2.x, list) and len(b2.x) == 1 # Now that they have tensor contents, they should have different hashes: @@ -222,15 +222,15 @@ class TestModule(nn.Module): def __init__(self): super().__init__() self.metric = DummyMetric() - self.metric.add_state('a', torch.tensor(0), persistent=True) + self.metric.add_state('a', tensor(0), persistent=True) self.metric.add_state('b', [], persistent=True) - self.metric.register_buffer('c', torch.tensor(0)) + self.metric.register_buffer('c', tensor(0)) module = TestModule() expected_state_dict = { - 'metric.a': torch.tensor(0), + 'metric.a': tensor(0), 'metric.b': [], - 'metric.c': torch.tensor(0), + 'metric.c': tensor(0), } assert module.state_dict() == 
expected_state_dict diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py index 1074d6ce37d..2456b22b441 100644 --- a/tests/classification/test_accuracy.py +++ b/tests/classification/test_accuracy.py @@ -17,6 +17,7 @@ import pytest import torch from sklearn.metrics import accuracy_score as sk_accuracy +from torch import tensor from tests.classification.inputs import _input_binary, _input_binary_prob from tests.classification.inputs import _input_multiclass as _input_mcls @@ -109,12 +110,12 @@ def test_accuracy_fn(self, preds, target, subset_accuracy): # The preds in these examples always put highest probability on class 3, second highest on class 2, # third highest on class 1, and lowest on class 0 -_topk_preds_mcls = torch.tensor([_l1to4t3, _l1to4t3]).float() -_topk_target_mcls = torch.tensor([[1, 2, 3], [2, 1, 0]]) +_topk_preds_mcls = tensor([_l1to4t3, _l1to4t3]).float() +_topk_target_mcls = tensor([[1, 2, 3], [2, 1, 0]]) # This is like for MC case, but one sample in each batch is sabotaged with 0 class prediction :) -_topk_preds_mdmc = torch.tensor([_l1to4t3_mcls, _l1to4t3_mcls]).float() -_topk_target_mdmc = torch.tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]]) +_topk_preds_mdmc = tensor([_l1to4t3_mcls, _l1to4t3_mcls]).float() +_topk_target_mdmc = tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]]) # Replace with a proper sk_metric test once sklearn 0.24 hits :) diff --git a/tests/classification/test_auc.py b/tests/classification/test_auc.py index 13cf3f38d58..59ebd8d1480 100644 --- a/tests/classification/test_auc.py +++ b/tests/classification/test_auc.py @@ -17,6 +17,7 @@ import pytest import torch from sklearn.metrics import auc as _sk_auc +from torch import tensor from tests.helpers.testers import NUM_BATCHES, MetricTester from torchmetrics.classification.auc import AUC @@ -43,7 +44,7 @@ def sk_auc(x, y): y = y[idx] if i % 2 == 0 else x[idx[::-1]] x = x.reshape(NUM_BATCHES, 8) y = y.reshape(NUM_BATCHES, 8) - _examples.append(Input(x=torch.tensor(x), y=torch.tensor(y))) + _examples.append(Input(x=tensor(x), y=tensor(y))) @pytest.mark.parametrize("x, y", _examples) @@ -74,4 +75,4 @@ def test_auc_functional(self, x, y): ]) def test_auc(x, y, expected): # Test Area Under Curve (AUC) computation - assert auc(torch.tensor(x), torch.tensor(y)) == expected + assert auc(tensor(x), tensor(y), reorder=True) == expected diff --git a/tests/classification/test_auroc.py b/tests/classification/test_auroc.py index 6b6bababd55..bff414cdab0 100644 --- a/tests/classification/test_auroc.py +++ b/tests/classification/test_auroc.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from distutils.version import LooseVersion from functools import partial import pytest @@ -26,11 +25,13 @@ from tests.helpers.testers import NUM_CLASSES, MetricTester from torchmetrics.classification.auroc import AUROC from torchmetrics.functional import auroc +from torchmetrics.utilities.imports import _TORCH_LOWER_1_6 torch.manual_seed(42) def _sk_auroc_binary_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'): + # todo: `multi_class` is unused sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() return sk_roc_auc_score(y_true=sk_target, y_score=sk_preds, average=average, max_fpr=max_fpr) @@ -92,7 +93,7 @@ def _sk_auroc_multilabel_multidim_prob(preds, target, num_classes, average='macr (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_auroc_multilabel_prob, NUM_CLASSES), (_input_mlmd_prob.preds, _input_mlmd_prob.target, _sk_auroc_multilabel_multidim_prob, NUM_CLASSES)] ) -@pytest.mark.parametrize("average", ['macro', 'weighted']) +@pytest.mark.parametrize("average", ['macro', 'weighted', 'micro']) @pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5]) class TestAUROC(MetricTester): @@ -104,9 +105,13 @@ def test_auroc(self, preds, target, sk_metric, num_classes, average, max_fpr, dd pytest.skip('max_fpr parameter not support for multi class or multi label') # max_fpr only supported for torch v1.6 or higher - if max_fpr is not None and LooseVersion(torch.__version__) < LooseVersion('1.6.0'): + if max_fpr is not None and _TORCH_LOWER_1_6: pytest.skip('requires torch v1.6 or higher to test max_fpr argument') + # average='micro' only supported for multilabel + if average == 'micro' and preds.ndim > 2 and preds.ndim == target.ndim + 1: + pytest.skip('micro argument only support for multilabel input') + self.run_class_metric_test( ddp=ddp, preds=preds, @@ -127,9 +132,13 @@ def test_auroc_functional(self, preds, target, sk_metric, num_classes, average, pytest.skip('max_fpr parameter not support for multi class or multi label') # max_fpr only supported for torch v1.6 or higher - if max_fpr is not None and LooseVersion(torch.__version__) < LooseVersion('1.6.0'): + if max_fpr is not None and _TORCH_LOWER_1_6: pytest.skip('requires torch v1.6 or higher to test max_fpr argument') + # average='micro' only supported for multilabel + if average == 'micro' and preds.ndim > 2 and preds.ndim == target.ndim + 1: + pytest.skip('micro argument only support for multilabel input') + self.run_functional_metric_test( preds, target, diff --git a/tests/classification/test_average_precision.py b/tests/classification/test_average_precision.py index b3a78f9741d..89d6c2c0086 100644 --- a/tests/classification/test_average_precision.py +++ b/tests/classification/test_average_precision.py @@ -17,6 +17,7 @@ import pytest import torch from sklearn.metrics import average_precision_score as sk_average_precision_score +from torch import tensor from tests.classification.inputs import _input_binary_prob from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob @@ -101,9 +102,9 @@ def test_average_precision_functional(self, preds, target, sk_metric, num_classe # And a constant score # The precision is then the fraction of positive whatever the recall # is, as there is only one threshold: - pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25), + pytest.param(tensor([1, 1, 1, 1]), tensor([0, 0, 0, 1]), .25), # With threshold 0.8 : 1 TP and 2 TN and one FN - pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75), + 
pytest.param(tensor([.6, .7, .8, 9]), tensor([1, 0, 0, 1]), .75), ] ) def test_average_precision(scores, target, expected_score): diff --git a/tests/classification/test_f_beta.py b/tests/classification/test_f_beta.py index 9925f951f54..f1fc17147e7 100644 --- a/tests/classification/test_f_beta.py +++ b/tests/classification/test_f_beta.py @@ -17,6 +17,7 @@ import pytest import torch from sklearn.metrics import fbeta_score +from torch import tensor from tests.classification.inputs import _input_binary, _input_binary_prob from tests.classification.inputs import _input_multiclass as _input_mcls @@ -34,6 +35,7 @@ def _sk_fbeta_binary_prob(preds, target, average='micro', beta=1.0): + # todo: `average` is unused sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) sk_target = target.view(-1).numpy() @@ -41,6 +43,7 @@ def _sk_fbeta_binary_prob(preds, target, average='micro', beta=1.0): def _sk_fbeta_binary(preds, target, average='micro', beta=1.0): + # todo: `average` is unused sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() @@ -152,8 +155,8 @@ def test_fbeta_functional(self, preds, target, sk_metric, num_classes, multilabe pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 2, [0.5, 0.5]), ]) def test_fbeta_score(pred, target, beta, exp_score): - score = fbeta(torch.tensor(pred), torch.tensor(target), num_classes=1, beta=beta, average='none') - assert torch.allclose(score, torch.tensor(exp_score)) + score = fbeta(tensor(pred), tensor(target), num_classes=1, beta=beta, average='none') + assert torch.allclose(score, tensor(exp_score)) @pytest.mark.parametrize(['pred', 'target', 'exp_score'], [ @@ -162,5 +165,5 @@ def test_fbeta_score(pred, target, beta, exp_score): pytest.param([1., 0., 1., 0.], [1., 0., 1., 0.], [1.0, 1.0]), ]) def test_f1_score(pred, target, exp_score): - score = f1(torch.tensor(pred), torch.tensor(target), num_classes=1, average='none') - assert torch.allclose(score, torch.tensor(exp_score)) + score = f1(tensor(pred), tensor(target), num_classes=1, average='none') + assert torch.allclose(score, tensor(exp_score)) diff --git a/tests/classification/test_inputs.py b/tests/classification/test_inputs.py index 6ac132ebc49..f7694a13888 100644 --- a/tests/classification/test_inputs.py +++ b/tests/classification/test_inputs.py @@ -13,7 +13,7 @@ # limitations under the License. 
import pytest import torch -from torch import rand, randint +from torch import Tensor, rand, randint, tensor from tests.classification.inputs import Input from tests.classification.inputs import _input_binary as _bin @@ -52,7 +52,7 @@ _mdmc_prob_2cls = Input(_mdmc_prob_2cls_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM))) # Some utils -T = torch.Tensor +T = Tensor def _idn(x): @@ -209,7 +209,7 @@ def test_threshold(): preds_probs_out, _, _ = _input_format_classification(preds_probs, target, threshold=0.5) - assert torch.equal(torch.tensor([0, 1, 1], dtype=torch.int), preds_probs_out.squeeze().int()) + assert torch.equal(tensor([0, 1, 1], dtype=torch.int), preds_probs_out.squeeze().int()) ######################################################################## diff --git a/tests/classification/test_iou.py b/tests/classification/test_iou.py index 2baed984ffb..2e21ab1b1a9 100644 --- a/tests/classification/test_iou.py +++ b/tests/classification/test_iou.py @@ -17,6 +17,7 @@ import pytest import torch from sklearn.metrics import jaccard_score as sk_jaccard_score +from torch import Tensor, tensor from tests.classification.inputs import _input_binary, _input_binary_prob from tests.classification.inputs import _input_multiclass as _input_mcls @@ -134,20 +135,20 @@ def test_confusion_matrix_functional(self, reduction, preds, target, sk_metric, @pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [ - pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])), - pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])), - pytest.param(False, 'none', 0, torch.Tensor([1, 1])), - pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])), - pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])), - pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])), + pytest.param(False, 'none', None, Tensor([1, 1, 1])), + pytest.param(False, 'elementwise_mean', None, Tensor([1])), + pytest.param(False, 'none', 0, Tensor([1, 1])), + pytest.param(True, 'none', None, Tensor([0.5, 0.5, 0.5])), + pytest.param(True, 'elementwise_mean', None, Tensor([0.5])), + pytest.param(True, 'none', 0, Tensor([0.5, 0.5])), ]) def test_iou(half_ones, reduction, ignore_index, expected): - pred = (torch.arange(120) % 3).view(-1, 1) + preds = (torch.arange(120) % 3).view(-1, 1) target = (torch.arange(120) % 3).view(-1, 1) if half_ones: - pred[:60] = 1 + preds[:60] = 1 iou_val = iou( - pred=pred, + preds=preds, target=target, ignore_index=ignore_index, reduction=reduction, @@ -190,14 +191,14 @@ def test_iou(half_ones, reduction, ignore_index, expected): ) def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected): iou_val = iou( - pred=torch.tensor(pred), - target=torch.tensor(target), + preds=tensor(pred), + target=tensor(target), ignore_index=ignore_index, absent_score=absent_score, num_classes=num_classes, reduction='none', ) - assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val)) + assert torch.allclose(iou_val, tensor(expected).to(iou_val)) # example data taken from @@ -220,10 +221,10 @@ def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, ) def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected): iou_val = iou( - pred=torch.tensor(pred), - target=torch.tensor(target), + preds=tensor(pred), + target=tensor(target), ignore_index=ignore_index, num_classes=num_classes, reduction=reduction, ) - assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val)) + 
assert torch.allclose(iou_val, tensor(expected).to(iou_val)) diff --git a/tests/classification/test_matthews_corrcoef.py b/tests/classification/test_matthews_corrcoef.py new file mode 100644 index 00000000000..8fcdc2f82dd --- /dev/null +++ b/tests/classification/test_matthews_corrcoef.py @@ -0,0 +1,127 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest +import torch +from sklearn.metrics import matthews_corrcoef as sk_matthews_corrcoef + +from tests.classification.inputs import _input_binary, _input_binary_prob +from tests.classification.inputs import _input_multiclass as _input_mcls +from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob +from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc +from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob +from tests.classification.inputs import _input_multilabel as _input_mlb +from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester +from torchmetrics.classification.matthews_corrcoef import MatthewsCorrcoef +from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef + +torch.manual_seed(42) + + +def _sk_matthews_corrcoef_binary_prob(preds, target): + sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) + sk_target = target.view(-1).numpy() + + return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) + + +def _sk_matthews_corrcoef_binary(preds, target): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) + + +def _sk_matthews_corrcoef_multilabel_prob(preds, target): + sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) + sk_target = target.view(-1).numpy() + + return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) + + +def _sk_matthews_corrcoef_multilabel(preds, target): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) + + +def _sk_matthews_corrcoef_multiclass_prob(preds, target): + sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) + + +def _sk_matthews_corrcoef_multiclass(preds, target): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) + + +def _sk_matthews_corrcoef_multidim_multiclass_prob(preds, target): + sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) + + +def _sk_matthews_corrcoef_multidim_multiclass(preds, target): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + 
return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) + + +@pytest.mark.parametrize( + "preds, target, sk_metric, num_classes", + [(_input_binary_prob.preds, _input_binary_prob.target, _sk_matthews_corrcoef_binary_prob, 2), + (_input_binary.preds, _input_binary.target, _sk_matthews_corrcoef_binary, 2), + (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_matthews_corrcoef_multilabel_prob, 2), + (_input_mlb.preds, _input_mlb.target, _sk_matthews_corrcoef_multilabel, 2), + (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_matthews_corrcoef_multiclass_prob, NUM_CLASSES), + (_input_mcls.preds, _input_mcls.target, _sk_matthews_corrcoef_multiclass, NUM_CLASSES), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_matthews_corrcoef_multidim_multiclass_prob, NUM_CLASSES), + (_input_mdmc.preds, _input_mdmc.target, _sk_matthews_corrcoef_multidim_multiclass, NUM_CLASSES)] +) +class TestMatthewsCorrCoef(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_matthews_corrcoef(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MatthewsCorrcoef, + sk_metric=sk_metric, + dist_sync_on_step=dist_sync_on_step, + metric_args={ + "num_classes": num_classes, + "threshold": THRESHOLD, + } + ) + + def test_matthews_corrcoef_functional(self, preds, target, sk_metric, num_classes): + self.run_functional_metric_test( + preds, + target, + metric_functional=matthews_corrcoef, + sk_metric=sk_metric, + metric_args={ + "num_classes": num_classes, + "threshold": THRESHOLD, + } + ) diff --git a/tests/classification/test_precision_recall.py b/tests/classification/test_precision_recall.py index 6f792aabe53..9d5fef1b53d 100644 --- a/tests/classification/test_precision_recall.py +++ b/tests/classification/test_precision_recall.py @@ -18,6 +18,7 @@ import pytest import torch from sklearn.metrics import precision_score, recall_score +from torch import Tensor, tensor from tests.classification.inputs import _input_binary, _input_binary_prob from tests.classification.inputs import _input_multiclass as _input_mcls @@ -35,6 +36,7 @@ def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average=None): + # todo: `mdmc_average` is unused if average == "none": average = None if num_classes == 1: @@ -128,8 +130,8 @@ def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ign def test_zero_division(metric_class, metric_fn): """ Test that zero_division works correctly (currently should just set to 0). """ - preds = torch.tensor([1, 2, 1, 1]) - target = torch.tensor([2, 1, 2, 1]) + preds = tensor([1, 2, 1, 1]) + target = tensor([2, 1, 2, 1]) cl_metric = metric_class(average="none", num_classes=3) cl_metric(preds, target) @@ -152,8 +154,8 @@ def test_no_support(metric_class, metric_fn): in this case (zero_division is for now not configurable and equals 0). 
""" - preds = torch.tensor([1, 1, 0, 0]) - target = torch.tensor([0, 0, 0, 0]) + preds = tensor([1, 1, 0, 0]) + target = tensor([0, 0, 0, 0]) cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=0) cl_metric(preds, target) @@ -198,8 +200,8 @@ def test_precision_recall_class( self, ddp: bool, dist_sync_on_step: bool, - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, sk_wrapper: Callable, metric_class: Metric, metric_fn: Callable, @@ -210,6 +212,7 @@ def test_precision_recall_class( mdmc_average: Optional[str], ignore_index: Optional[int], ): + # todo: `metric_fn` is unused if num_classes == 1 and average != "micro": pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") @@ -248,8 +251,8 @@ def test_precision_recall_class( def test_precision_recall_fn( self, - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, sk_wrapper: Callable, metric_class: Metric, metric_fn: Callable, @@ -260,6 +263,7 @@ def test_precision_recall_fn( mdmc_average: Optional[str], ignore_index: Optional[int], ): + # todo: `metric_class` is unused if num_classes == 1 and average != "micro": pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") @@ -316,36 +320,35 @@ def test_precision_recall_joint(average): assert torch.equal(recall_result, prec_recall_result[1]) -_mc_k_target = torch.tensor([0, 1, 2]) -_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) -_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) -_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) +_mc_k_target = tensor([0, 1, 2]) +_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) +_ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) +_ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) @pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) @pytest.mark.parametrize( "k, preds, target, average, expected_prec, expected_recall", [ - (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)), - (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1 / 2), torch.tensor(1.0)), - (1, _ml_k_preds, _ml_k_target, "micro", torch.tensor(0.0), torch.tensor(0.0)), - (2, _ml_k_preds, _ml_k_target, "micro", torch.tensor(1 / 6), torch.tensor(1 / 3)), + (1, _mc_k_preds, _mc_k_target, "micro", tensor(2 / 3), tensor(2 / 3)), + (2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2), tensor(1.0)), + (1, _ml_k_preds, _ml_k_target, "micro", tensor(0.0), tensor(0.0)), + (2, _ml_k_preds, _ml_k_target, "micro", tensor(1 / 6), tensor(1 / 3)), ], ) def test_top_k( metric_class, metric_fn, k: int, - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, average: str, - expected_prec: torch.Tensor, - expected_recall: torch.Tensor, + expected_prec: Tensor, + expected_recall: Tensor, ): """A simple test to check that top_k works as expected. - Just a sanity check, the tests in StatScores should already guarantee - the corectness of results. + Just a sanity check, the tests in StatScores should already guarantee the correctness of results. 
""" class_metric = metric_class(top_k=k, average=average, num_classes=3) diff --git a/tests/classification/test_precision_recall_curve.py b/tests/classification/test_precision_recall_curve.py index 658885571d5..3209f39a74a 100644 --- a/tests/classification/test_precision_recall_curve.py +++ b/tests/classification/test_precision_recall_curve.py @@ -17,6 +17,7 @@ import pytest import torch from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve +from torch import tensor from tests.classification.inputs import _input_binary_prob from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob @@ -101,10 +102,10 @@ def test_precision_recall_curve_functional(self, preds, target, sk_metric, num_c [pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4])] ) def test_pr_curve(pred, target, expected_p, expected_r, expected_t): - p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target)) + p, r, t = precision_recall_curve(tensor(pred), tensor(target)) assert p.size() == r.size() assert p.size(0) == t.size(0) + 1 - assert torch.allclose(p, torch.tensor(expected_p).to(p)) - assert torch.allclose(r, torch.tensor(expected_r).to(r)) - assert torch.allclose(t, torch.tensor(expected_t).to(t)) + assert torch.allclose(p, tensor(expected_p).to(p)) + assert torch.allclose(r, tensor(expected_r).to(r)) + assert torch.allclose(t, tensor(expected_t).to(t)) diff --git a/tests/classification/test_roc.py b/tests/classification/test_roc.py index ebded9d11bc..99895e6c855 100644 --- a/tests/classification/test_roc.py +++ b/tests/classification/test_roc.py @@ -17,6 +17,7 @@ import pytest import torch from sklearn.metrics import roc_curve as sk_roc_curve +from torch import tensor from tests.classification.inputs import _input_binary_prob from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob @@ -138,9 +139,9 @@ def test_roc_functional(self, preds, target, sk_metric, num_classes): pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]), ]) def test_roc_curve(pred, target, expected_tpr, expected_fpr): - fpr, tpr, thresh = roc(torch.tensor(pred), torch.tensor(target)) + fpr, tpr, thresh = roc(tensor(pred), tensor(target)) assert fpr.shape == tpr.shape assert fpr.size(0) == thresh.size(0) - assert torch.allclose(fpr, torch.tensor(expected_fpr).to(fpr)) - assert torch.allclose(tpr, torch.tensor(expected_tpr).to(tpr)) + assert torch.allclose(fpr, tensor(expected_fpr).to(fpr)) + assert torch.allclose(tpr, tensor(expected_tpr).to(tpr)) diff --git a/tests/classification/test_stat_scores.py b/tests/classification/test_stat_scores.py index 22a9e89c1bd..47daa099d17 100644 --- a/tests/classification/test_stat_scores.py +++ b/tests/classification/test_stat_scores.py @@ -18,6 +18,7 @@ import pytest import torch from sklearn.metrics import multilabel_confusion_matrix +from torch import Tensor, tensor from tests.classification.inputs import _input_binary, _input_binary_prob, _input_multiclass from tests.classification.inputs import _input_multiclass_prob as _input_mccls_prob @@ -34,6 +35,7 @@ def _sk_stat_scores(preds, target, reduce, num_classes, is_multiclass, ignore_index, top_k, mdmc_reduce=None): + # todo: `mdmc_reduce` is unused preds, target, _ = _input_format_classification( preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k ) @@ -159,8 +161,8 @@ def test_stat_scores_class( ddp: bool, dist_sync_on_step: bool, sk_fn: Callable, - preds: torch.Tensor, - 
target: torch.Tensor, + preds: Tensor, + target: Tensor, reduce: str, mdmc_reduce: Optional[str], num_classes: Optional[int], @@ -202,8 +204,8 @@ def test_stat_scores_class( def test_stat_scores_fn( self, sk_fn: Callable, - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, reduce: str, mdmc_reduce: Optional[str], num_classes: Optional[int], @@ -239,26 +241,26 @@ def test_stat_scores_fn( ) -_mc_k_target = torch.tensor([0, 1, 2]) -_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) -_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) -_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) +_mc_k_target = tensor([0, 1, 2]) +_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) +_ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) +_ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) @pytest.mark.parametrize( "k, preds, target, reduce, expected", [ - (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor([2, 1, 5, 1, 3])), - (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor([3, 3, 3, 0, 3])), - (1, _ml_k_preds, _ml_k_target, "micro", torch.tensor([0, 3, 3, 3, 3])), - (2, _ml_k_preds, _ml_k_target, "micro", torch.tensor([1, 5, 1, 2, 3])), - (1, _mc_k_preds, _mc_k_target, "macro", torch.tensor([[0, 1, 1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])), - (2, _mc_k_preds, _mc_k_target, "macro", torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])), - (1, _ml_k_preds, _ml_k_target, "macro", torch.tensor([[0, 0, 0], [1, 0, 2], [1, 1, 1], [1, 2, 0], [1, 2, 0]])), - (2, _ml_k_preds, _ml_k_target, "macro", torch.tensor([[0, 1, 0], [2, 0, 3], [0, 1, 0], [1, 1, 0], [1, 2, 0]])), + (1, _mc_k_preds, _mc_k_target, "micro", tensor([2, 1, 5, 1, 3])), + (2, _mc_k_preds, _mc_k_target, "micro", tensor([3, 3, 3, 0, 3])), + (1, _ml_k_preds, _ml_k_target, "micro", tensor([0, 3, 3, 3, 3])), + (2, _ml_k_preds, _ml_k_target, "micro", tensor([1, 5, 1, 2, 3])), + (1, _mc_k_preds, _mc_k_target, "macro", tensor([[0, 1, 1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])), + (2, _mc_k_preds, _mc_k_target, "macro", tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])), + (1, _ml_k_preds, _ml_k_target, "macro", tensor([[0, 0, 0], [1, 0, 2], [1, 1, 1], [1, 2, 0], [1, 2, 0]])), + (2, _ml_k_preds, _ml_k_target, "macro", tensor([[0, 1, 0], [2, 0, 3], [0, 1, 0], [1, 1, 0], [1, 2, 0]])), ], ) -def test_top_k(k: int, preds: torch.Tensor, target: torch.Tensor, reduce: str, expected: torch.Tensor): +def test_top_k(k: int, preds: Tensor, target: Tensor, reduce: str, expected: Tensor): """ A simple test to check that top_k works as expected """ class_metric = StatScores(top_k=k, reduce=reduce, num_classes=3) diff --git a/tests/functional/test_classification.py b/tests/functional/test_classification.py index da319a0056f..a6ad4e3a7bd 100644 --- a/tests/functional/test_classification.py +++ b/tests/functional/test_classification.py @@ -14,6 +14,7 @@ import pytest import torch from pytorch_lightning import seed_everything +from torch import Tensor, tensor from torchmetrics.functional import dice_score from torchmetrics.functional.classification.precision_recall_curve import _binary_clf_curve @@ -21,7 +22,7 @@ def test_onehot(): - test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + test_tensor = tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) expected = torch.stack([ torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]), torch.cat([torch.zeros((5, 
5), dtype=int), torch.eye(5, dtype=int)]) @@ -48,7 +49,7 @@ def test_to_categorical(): torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]) ]).to(torch.float) - expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + expected = tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) assert expected.shape == (2, 5) assert test_tensor.shape == (2, 10, 5) @@ -58,13 +59,13 @@ def test_to_categorical(): assert torch.allclose(result, expected.to(result.dtype)) -@pytest.mark.parametrize(['pred', 'target', 'num_classes', 'expected_num_classes'], [ +@pytest.mark.parametrize(['preds', 'target', 'num_classes', 'expected_num_classes'], [ pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10), pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), ]) -def test_get_num_classes(pred, target, num_classes, expected_num_classes): - assert get_num_classes(pred, target, num_classes) == expected_num_classes +def test_get_num_classes(preds, target, num_classes, expected_num_classes): + assert get_num_classes(preds, target, num_classes) == expected_num_classes @pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [ @@ -77,15 +78,15 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape): # because when the array changes, you also have to fix the shape seed_everything(0) pred = torch.randint(low=51, high=99, size=(100, ), dtype=torch.float) / 100 - target = torch.tensor([0, 1] * 50, dtype=torch.int) + target = tensor([0, 1] * 50, dtype=torch.int) if sample_weight is not None: sample_weight = torch.ones_like(pred) * sample_weight fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) - assert isinstance(tps, torch.Tensor) - assert isinstance(fps, torch.Tensor) - assert isinstance(thresh, torch.Tensor) + assert isinstance(tps, Tensor) + assert isinstance(fps, Tensor) + assert isinstance(thresh, Tensor) assert tps.shape == (exp_shape, ) assert fps.shape == (exp_shape, ) assert thresh.shape == (exp_shape, ) @@ -98,5 +99,5 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape): pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.), ]) def test_dice_score(pred, target, expected): - score = dice_score(torch.tensor(pred), torch.tensor(target)) + score = dice_score(tensor(pred), tensor(target)) assert score == expected diff --git a/tests/functional/test_image_gradients.py b/tests/functional/test_image_gradients.py index ebbb9cba873..908eefc32c2 100644 --- a/tests/functional/test_image_gradients.py +++ b/tests/functional/test_image_gradients.py @@ -13,6 +13,7 @@ # limitations under the License. 
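Editorial note on the one-hot round trip exercised by `test_onehot` / `test_to_categorical` above: the same property can be illustrated with plain PyTorch ops, independent of the library helpers under test (this sketch uses `torch.nn.functional.one_hot`, not torchmetrics internals).

import torch

labels = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])

# one_hot puts the class dimension last, so permute to (batch, classes, ...)
# to match the (2, 10, 5) layout asserted in the tests above
onehot = torch.nn.functional.one_hot(labels, num_classes=10).permute(0, 2, 1)
assert onehot.shape == (2, 10, 5)

# taking the argmax over the class dimension recovers the original labels
recovered = onehot.argmax(dim=1)
assert torch.equal(recovered, labels)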
import pytest import torch +from torch import Tensor from torchmetrics.functional import image_gradients @@ -65,16 +66,7 @@ def test_multi_batch_image_gradients(): [5., 5., 5., 5., 5.], [0., 0., 0., 0., 0.], ] - - true_dx = [ - [1., 1., 1., 1., 0.], - [1., 1., 1., 1., 0.], - [1., 1., 1., 1., 0.], - [1., 1., 1., 1., 0.], - [1., 1., 1., 1., 0.], - ] - true_dy = torch.Tensor(true_dy) - true_dx = torch.Tensor(true_dx) + true_dy = Tensor(true_dy) dy, dx = image_gradients(image) @@ -113,8 +105,8 @@ def test_image_gradients(): [1., 1., 1., 1., 0.], ] - true_dy = torch.Tensor(true_dy) - true_dx = torch.Tensor(true_dx) + true_dy = Tensor(true_dy) + true_dx = Tensor(true_dx) dy, dx = image_gradients(image) diff --git a/tests/functional/test_nlp.py b/tests/functional/test_nlp.py index 3c37ee9178e..76bf77041da 100644 --- a/tests/functional/test_nlp.py +++ b/tests/functional/test_nlp.py @@ -14,6 +14,7 @@ import pytest import torch from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu +from torch import tensor from torchmetrics.functional import bleu_score @@ -62,20 +63,20 @@ def test_bleu_score(weights, n_gram, smooth_func, smooth): smoothing_function=smooth_func, ) pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth) - assert torch.allclose(pl_output, torch.tensor(nltk_output)) + assert torch.allclose(pl_output, tensor(nltk_output)) nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func) pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth) - assert torch.allclose(pl_output, torch.tensor(nltk_output)) + assert torch.allclose(pl_output, tensor(nltk_output)) def test_bleu_empty(): hyp = [[]] ref = [[[]]] - assert bleu_score(hyp, ref) == torch.tensor(0.0) + assert bleu_score(hyp, ref) == tensor(0.0) def test_no_4_gram(): hyps = [["My", "full", "pytorch-lightning"]] refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]] - assert bleu_score(hyps, refs) == torch.tensor(0.0) + assert bleu_score(hyps, refs) == tensor(0.0) diff --git a/tests/functional/test_retrieval.py b/tests/functional/test_retrieval.py new file mode 100644 index 00000000000..3aaf60d607d --- /dev/null +++ b/tests/functional/test_retrieval.py @@ -0,0 +1,30 @@ +import math + +import numpy as np +import pytest +import torch +from sklearn.metrics import average_precision_score as sk_average_precision + +from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision + + +@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ + pytest.param(sk_average_precision, retrieval_average_precision), +]) +@pytest.mark.parametrize("size", [1, 4, 10, 100]) +def test_against_sklearn(sklearn_metric, torch_metric, size): + """Compare PL metrics to sklearn version. 
""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + a = np.random.randn(size) + b = np.random.randn(size) > 0 + + sk = torch.tensor(sklearn_metric(b, a), device=device) + pl = torch_metric(torch.tensor(a, device=device), torch.tensor(b, device=device)) + + # `torch_metric`s return 0 when no label is True + # while `sklearn.average_precision_score` returns NaN + if math.isnan(sk): + assert pl == 0 + else: + assert torch.allclose(sk.float(), pl.float()) diff --git a/tests/functional/test_self_supervised.py b/tests/functional/test_self_supervised.py index d90a3f81f42..06bd23c53b7 100644 --- a/tests/functional/test_self_supervised.py +++ b/tests/functional/test_self_supervised.py @@ -14,6 +14,7 @@ import pytest import torch from sklearn.metrics import pairwise +from torch import tensor from torchmetrics.functional import embedding_similarity @@ -40,6 +41,6 @@ def sklearn_embedding_distance(batch, similarity, reduction): return dist sk_dist = sklearn_embedding_distance(batch.cpu().detach().numpy(), similarity=similarity, reduction=reduction) - sk_dist = torch.tensor(sk_dist, dtype=torch.float, device=device) + sk_dist = tensor(sk_dist, dtype=torch.float, device=device) assert torch.allclose(sk_dist, pl_dist) diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index 3c5c856abb1..38c92a80c14 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -1,7 +1,5 @@ -from distutils.version import LooseVersion +from torchmetrics.utilities.imports import _TORCH_LOWER_1_4, _TORCH_LOWER_1_5, _TORCH_LOWER_1_6 -import torch - -_MARK_TORCH_MIN_1_4 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.4"), reason='required PT >= 1.4') -_MARK_TORCH_MIN_1_5 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.5"), reason='required PT >= 1.5') -_MARK_TORCH_MIN_1_6 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.6"), reason='required PT >= 1.6') +_MARK_TORCH_MIN_1_4 = dict(condition=_TORCH_LOWER_1_4, reason='required PT >= 1.4') +_MARK_TORCH_MIN_1_5 = dict(condition=_TORCH_LOWER_1_5, reason='required PT >= 1.5') +_MARK_TORCH_MIN_1_6 = dict(condition=_TORCH_LOWER_1_6, reason='required PT >= 1.6') diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 4834edd5448..149f300f0e0 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -20,6 +20,7 @@ import numpy as np import pytest import torch +from torch import Tensor, tensor from torch.multiprocessing import Pool, set_start_method from torchmetrics import Metric @@ -38,7 +39,7 @@ def setup_ddp(rank, world_size): - """ Setup ddp enviroment """ + """ Setup ddp environment """ os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "8088" @@ -51,7 +52,7 @@ def _assert_allclose(pl_result, sk_result, atol: float = 1e-8): a certain tolerance """ # single output compare - if isinstance(pl_result, torch.Tensor): + if isinstance(pl_result, Tensor): assert np.allclose(pl_result.numpy(), sk_result, atol=atol, equal_nan=True) # multi output compare elif isinstance(pl_result, (tuple, list)): @@ -69,18 +70,18 @@ def _assert_tensor(pl_result): for plr in pl_result: _assert_tensor(plr) else: - assert isinstance(pl_result, torch.Tensor) + assert isinstance(pl_result, Tensor) def _class_test( rank: int, worldsize: int, - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, metric_class: Metric, sk_metric: Callable, dist_sync_on_step: bool, - metric_args: dict = {}, + metric_args: dict = None, check_dist_sync_on_step: bool = 
True, check_batch: bool = True, atol: float = 1e-8, @@ -103,6 +104,8 @@ def _class_test( check_batch: bool, if true will check if the metric is also correctly calculated across devices for each batch (and not just at the end) """ + if not metric_args: + metric_args = {} # Instanciate lightning metric metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args) @@ -140,11 +143,11 @@ def _class_test( def _functional_test( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, metric_functional: Callable, sk_metric: Callable, - metric_args: dict = {}, + metric_args: dict = None, atol: float = 1e-8, ): """Utility function doing the actual comparison between lightning functional metric @@ -157,6 +160,8 @@ def _functional_test( sk_metric: callable function that is used for comparison metric_args: dict with additional arguments used for class initialization """ + if not metric_args: + metric_args = {} metric = partial(metric_functional, **metric_args) for i in range(NUM_BATCHES): @@ -195,11 +200,11 @@ def teardown_class(self): def run_functional_metric_test( self, - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, metric_functional: Callable, sk_metric: Callable, - metric_args: dict = {}, + metric_args: dict = None, ): """Main method that should be used for testing functions. Call this inside testing method @@ -223,12 +228,12 @@ def run_functional_metric_test( def run_class_metric_test( self, ddp: bool, - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, metric_class: Metric, sk_metric: Callable, dist_sync_on_step: bool, - metric_args: dict = {}, + metric_args: dict = None, check_dist_sync_on_step: bool = True, check_batch: bool = True, ): @@ -249,6 +254,8 @@ def run_class_metric_test( check_batch: bool, if true will check if the metric is also correctly calculated across devices for each batch (and not just at the end) """ + if not metric_args: + metric_args = {} if ddp: if sys.platform == "win32": pytest.skip("DDP not supported on windows") @@ -289,7 +296,7 @@ class DummyMetric(Metric): def __init__(self): super().__init__() - self.add_state("x", torch.tensor(0.0), dist_reduce_fx=None) + self.add_state("x", tensor(0.0), dist_reduce_fx=None) def update(self): pass diff --git a/tests/integrations/lightning_models.py b/tests/integrations/lightning_models.py index 8bc5b3c94f7..8482d76f77e 100644 --- a/tests/integrations/lightning_models.py +++ b/tests/integrations/lightning_models.py @@ -68,7 +68,8 @@ def training_step(...): def forward(self, x): return self.layer(x) - def loss(self, batch, prediction): + @staticmethod + def loss(_, prediction): # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) diff --git a/tests/integrations/test_metric_lightning.py b/tests/integrations/test_metric_lightning.py index fdda356a7f1..77db73700e1 100644 --- a/tests/integrations/test_metric_lightning.py +++ b/tests/integrations/test_metric_lightning.py @@ -13,6 +13,7 @@ # limitations under the License. 
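Editorial note on the `metric_args: dict = {}` to `metric_args: dict = None` change in `tests/helpers/testers.py` above: it avoids Python's shared-mutable-default pitfall. A minimal sketch of the difference, using hypothetical functions rather than library code:

def with_mutable_default(metric_args: dict = {}):
    # the default dict is created once and shared by every call
    metric_args.setdefault("calls", 0)
    metric_args["calls"] += 1
    return metric_args

def with_none_default(metric_args: dict = None):
    if not metric_args:
        metric_args = {}  # fresh dict per call, as in the testers above
    metric_args.setdefault("calls", 0)
    metric_args["calls"] += 1
    return metric_args

assert with_mutable_default()["calls"] == 1
assert with_mutable_default()["calls"] == 2  # state leaks between calls
assert with_none_default()["calls"] == 1
assert with_none_default()["calls"] == 1     # independent state per call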
import torch from pytorch_lightning import Trainer +from torch import tensor from tests.integrations.lightning_models import BoringModel from torchmetrics import Metric @@ -22,7 +23,7 @@ class SumMetric(Metric): def __init__(self): super().__init__() - self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("x", tensor(0.0), dist_reduce_fx="sum") def update(self, x): self.x += x @@ -35,7 +36,7 @@ class DiffMetric(Metric): def __init__(self): super().__init__() - self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("x", tensor(0.0), dist_reduce_fx="sum") def update(self, x): self.x -= x @@ -116,8 +117,8 @@ def training_epoch_end(self, outs): # trainer.fit(model) # # logged = trainer.logged_metrics -# assert torch.allclose(torch.tensor(logged["sum_step"]), model.sum) -# assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum) +# assert torch.allclose(tensor(logged["sum_step"]), model.sum) +# assert torch.allclose(tensor(logged["sum_epoch"]), model.sum) # todo: need to be fixed # def test_scriptable(tmpdir): @@ -193,5 +194,5 @@ def training_epoch_end(self, outs): # trainer.fit(model) # # logged = trainer.logged_metrics -# assert torch.allclose(torch.tensor(logged["SumMetric_epoch"]), model.sum) -# assert torch.allclose(torch.tensor(logged["DiffMetric_epoch"]), model.diff) +# assert torch.allclose(tensor(logged["SumMetric_epoch"]), model.sum) +# assert torch.allclose(tensor(logged["DiffMetric_epoch"]), model.diff) diff --git a/tests/regression/test_mean_error.py b/tests/regression/test_mean_error.py index 8aa54ce45b0..4a3207c6780 100644 --- a/tests/regression/test_mean_error.py +++ b/tests/regression/test_mean_error.py @@ -75,6 +75,7 @@ class TestMeanError(MetricTester): def test_mean_error_class( self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, ddp, dist_sync_on_step ): + # todo: `metric_functional` is unused self.run_class_metric_test( ddp=ddp, preds=preds, @@ -85,6 +86,7 @@ def test_mean_error_class( ) def test_mean_error_functional(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): + # todo: `metric_class` is unused self.run_functional_metric_test( preds=preds, target=target, diff --git a/tests/regression/test_r2score.py b/tests/regression/test_r2score.py index 9a3ba3c89bf..45f7b5d674a 100644 --- a/tests/regression/test_r2score.py +++ b/tests/regression/test_r2score.py @@ -82,6 +82,7 @@ def test_r2(self, adjusted, multioutput, preds, target, sk_metric, num_outputs, ) def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, num_outputs): + # todo: `num_outputs` is unused self.run_functional_metric_test( preds, target, @@ -102,14 +103,14 @@ def test_error_on_multidim_tensors(metric_class=R2Score): with pytest.raises( ValueError, match=r'Expected both prediction and target to be 1D or 2D tensors,' - r' but recevied tensors with dimension .' + r' but received tensors with dimension .' 
): metric(torch.randn(10, 20, 5), torch.randn(10, 20, 5)) def test_error_on_too_few_samples(metric_class=R2Score): metric = metric_class() - with pytest.raises(ValueError, match='Needs atleast two samples to calculate r2 score.'): + with pytest.raises(ValueError, match='Needs at least two samples to calculate r2 score.'): metric(torch.randn(1, ), torch.randn(1, )) @@ -118,7 +119,7 @@ def test_warning_on_too_large_adjusted(metric_class=R2Score): with pytest.warns( UserWarning, - match="More independent regressions than datapoints in" + match="More independent regressions than data points in" " adjusted r2 score. Falls back to standard r2 score." ): metric(torch.randn(10, ), torch.randn(10, )) diff --git a/tests/retrieval/test_map.py b/tests/retrieval/test_map.py new file mode 100644 index 00000000000..d4d6c212b59 --- /dev/null +++ b/tests/retrieval/test_map.py @@ -0,0 +1,120 @@ +import math +import random +from typing import Callable, List + +import numpy as np +import pytest +import torch +from pytorch_lightning import seed_everything +from sklearn.metrics import average_precision_score as sk_average_precision +from torch import Tensor + +from torchmetrics.metric import Metric +from torchmetrics.retrieval.mean_average_precision import RetrievalMAP + + +@pytest.mark.parametrize(['sklearn_metric', 'torch_class_metric'], [ + [sk_average_precision, RetrievalMAP], +]) +def test_against_sklearn(sklearn_metric: Callable, torch_class_metric: Metric) -> None: + """Compare PL metrics to sklearn version. """ + device = 'cuda' if torch.cuda.is_available() else 'cpu' + seed_everything(0) + + rounds = 20 + sizes = [1, 4, 10, 100] + batch_sizes = [1, 4, 10] + query_without_relevant_docs_options = ['skip', 'pos', 'neg'] + + def compute_sklearn_metric(target: List[np.ndarray], preds: List[np.ndarray], behaviour: str) -> Tensor: + """ Compute sk metric with multiple iterations using the base `sklearn_metric`. """ + sk_results = [] + kwargs = {'device': device, 'dtype': torch.float32} + + for b, a in zip(target, preds): + res = sklearn_metric(b, a) + + if math.isnan(res): + if behaviour == 'skip': + pass + elif behaviour == 'pos': + sk_results.append(torch.tensor(1.0, **kwargs)) + else: + sk_results.append(torch.tensor(0.0, **kwargs)) + else: + sk_results.append(torch.tensor(res, **kwargs)) + if len(sk_results) > 0: + sk_results = torch.stack(sk_results).mean() + else: + sk_results = torch.tensor(0.0, **kwargs) + + return sk_results + + def do_test(batch_size: int, size: int) -> None: + """ For each possible behaviour of the metric, check results are correct. 
""" + for behaviour in query_without_relevant_docs_options: + metric = torch_class_metric(query_without_relevant_docs=behaviour) + shape = (size, ) + + indexes = [] + preds = [] + target = [] + + for i in range(batch_size): + indexes.append(np.ones(shape, dtype=int) * i) + preds.append(np.random.randn(*shape)) + target.append(np.random.randn(*shape) > 0) + + sk_results = compute_sklearn_metric(target, preds, behaviour) + + indexes_tensor = torch.cat([torch.tensor(i) for i in indexes]) + preds_tensor = torch.cat([torch.tensor(p) for p in preds]) + target_tensor = torch.cat([torch.tensor(t) for t in target]) + + # lets assume data are not ordered + perm = torch.randperm(indexes_tensor.nelement()) + indexes_tensor = indexes_tensor.view(-1)[perm].view(indexes_tensor.size()) + preds_tensor = preds_tensor.view(-1)[perm].view(preds_tensor.size()) + target_tensor = target_tensor.view(-1)[perm].view(target_tensor.size()) + + # shuffle ids to require also sorting of documents ability from the lightning metric + pl_result = metric(indexes_tensor, preds_tensor, target_tensor) + + assert torch.allclose(sk_results.float(), pl_result.float(), equal_nan=True) + + for batch_size in batch_sizes: + for size in sizes: + for _ in range(rounds): + do_test(batch_size, size) + + +@pytest.mark.parametrize(['torch_class_metric'], [ + [RetrievalMAP], +]) +def test_input_data(torch_class_metric: Metric) -> None: + """Check PL metrics inputs are controlled correctly. """ + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + seed_everything(0) + + for _ in range(10): + + length = random.randint(0, 20) + + # check error when `query_without_relevant_docs='error'` is raised correctly + indexes = torch.tensor([0] * length, device=device, dtype=torch.int64) + preds = torch.rand(size=(length, ), device=device, dtype=torch.float32) + target = torch.tensor([False] * length, device=device, dtype=torch.bool) + + metric = torch_class_metric(query_without_relevant_docs='error') + + try: + metric(indexes, preds, target) + except Exception as e: + assert isinstance(e, ValueError) + + # check ValueError with non-accepted argument + try: + metric = torch_class_metric(query_without_relevant_docs='casual_argument') + except Exception as e: + assert isinstance(e, ValueError) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 125af29392c..2e385e5635e 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -2,22 +2,15 @@ import logging as __logging import os -__version__ = '0.2.1dev' -__author__ = 'PyTorchLightning et al.' -__author_email__ = 'name@pytorchlightning.ai' -__license__ = 'Apache-2.0' -__copyright__ = f'Copyright (c) 2020-2021, {__author__}.' -__homepage__ = 'https://github.com/PyTorchLightning/metrics' -__docs__ = "PyTorch native Metrics" -__long_doc__ = """ -Torchmetrics is a metrics API created for easy metric development and usage in both PyTorch and -[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/). It was originally a part of -Pytorch Lightning, but got split off so users could take advantage of the large collection of metrics -implemented without having to install Pytorch Lightning (even though we would love for you to try it out). -We currently have around 25+ metrics implemented and we continuously is adding more metrics, both within -already covered domains (classification, regression ect.) but also new domains (object detection ect.). -We make sure that all our metrics are rigorously tested such that you can trust them. 
-""" +from torchmetrics.info import ( # noqa: F401 + __author__, + __author_email__, + __copyright__, + __docs__, + __homepage__, + __license__, + __version__, +) _logger = __logging.getLogger("torchmetrics") _logger.addHandler(__logging.StreamHandler()) @@ -26,46 +19,33 @@ _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) -try: - # This variable is injected in the __builtins__ by the build - # process. It used to enable importing subpackages of skimage when - # the binaries are not built - _ = None if __LIGHTNING_SETUP__ else None -except NameError: - __LIGHTNING_SETUP__: bool = False - -if __LIGHTNING_SETUP__: - import sys # pragma: no-cover - - sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover - # We are not importing the rest of the lightning during the build process, as it may not be compiled yet -else: - - from torchmetrics.classification import ( # noqa: F401 - AUC, - AUROC, - F1, - ROC, - Accuracy, - AveragePrecision, - CohenKappa, - ConfusionMatrix, - FBeta, - HammingDistance, - IoU, - Precision, - PrecisionRecallCurve, - Recall, - StatScores, - ) - from torchmetrics.collections import MetricCollection # noqa: F401 - from torchmetrics.metric import Metric # noqa: F401 - from torchmetrics.regression import ( # noqa: F401 - PSNR, - SSIM, - ExplainedVariance, - MeanAbsoluteError, - MeanSquaredError, - MeanSquaredLogError, - R2Score, - ) +from torchmetrics.classification import ( # noqa: F401 E402 + AUC, + AUROC, + F1, + ROC, + Accuracy, + AveragePrecision, + CohenKappa, + ConfusionMatrix, + FBeta, + HammingDistance, + IoU, + MatthewsCorrcoef, + Precision, + PrecisionRecallCurve, + Recall, + StatScores, +) +from torchmetrics.collections import MetricCollection # noqa: F401 E402 +from torchmetrics.metric import Metric # noqa: F401 E402 +from torchmetrics.regression import ( # noqa: F401 E402 + PSNR, + SSIM, + ExplainedVariance, + MeanAbsoluteError, + MeanSquaredError, + MeanSquaredLogError, + R2Score, +) +from torchmetrics.retrieval import RetrievalMAP # noqa: F401 E402 diff --git a/torchmetrics/classification/__init__.py b/torchmetrics/classification/__init__.py index cc03fab62a5..9d1a1ba6fa2 100644 --- a/torchmetrics/classification/__init__.py +++ b/torchmetrics/classification/__init__.py @@ -20,6 +20,7 @@ from torchmetrics.classification.f_beta import F1, FBeta # noqa: F401 from torchmetrics.classification.hamming_distance import HammingDistance # noqa: F401 from torchmetrics.classification.iou import IoU # noqa: F401 +from torchmetrics.classification.matthews_corrcoef import MatthewsCorrcoef # noqa: F401 from torchmetrics.classification.precision_recall import Precision, Recall # noqa: F401 from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401 from torchmetrics.classification.roc import ROC # noqa: F401 diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py index 988263d35dc..e40db2f5619 100644 --- a/torchmetrics/classification/accuracy.py +++ b/torchmetrics/classification/accuracy.py @@ -14,6 +14,7 @@ from typing import Any, Callable, Optional import torch +from torch import Tensor, tensor from torchmetrics.functional.classification.accuracy import _accuracy_compute, _accuracy_update from torchmetrics.metric import Metric @@ -78,8 +79,13 @@ class Accuracy(Metric): Callback that performs the allgather operation on the metric state. 
When ``None``, DDP will be used to perform the allgather - Example: + Raises: + ValueError: + If ``threshold`` is not between ``0`` and ``1``. + ValueError: + If ``top_k`` is not an ``integer`` larger than ``0``. + Example: >>> from torchmetrics import Accuracy >>> target = torch.tensor([0, 1, 2, 3]) >>> preds = torch.tensor([0, 2, 1, 3]) @@ -112,8 +118,8 @@ def __init__( dist_sync_fn=dist_sync_fn, ) - self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("correct", default=tensor(0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") if not 0 < threshold < 1: raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}") @@ -125,7 +131,7 @@ def __init__( self.top_k = top_k self.subset_accuracy = subset_accuracy - def update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: Tensor, target: Tensor): """ Update state with predictions and targets. See :ref:`references/modules:input types` for more information on input types. @@ -142,7 +148,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): self.correct += correct self.total += total - def compute(self) -> torch.Tensor: + def compute(self) -> Tensor: """ Computes accuracy based on inputs passed in to ``update`` previously. """ diff --git a/torchmetrics/classification/auc.py b/torchmetrics/classification/auc.py index b028a007834..d33d4aa696e 100644 --- a/torchmetrics/classification/auc.py +++ b/torchmetrics/classification/auc.py @@ -14,6 +14,7 @@ from typing import Any, Callable, Optional import torch +from torch import Tensor from torchmetrics.functional.classification.auc import _auc_compute, _auc_update from torchmetrics.metric import Metric @@ -30,7 +31,7 @@ class AUC(Metric): Args: reorder: AUC expects its first input to be sorted. If this is not the case, setting this argument to ``True`` will use a stable sorting algorithm to - sort the input in decending order + sort the input in descending order compute_on_step: Forward only calls ``update()`` and return None if this is set to False. dist_sync_on_step: @@ -39,8 +40,8 @@ class AUC(Metric): process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) dist_sync_fn: - Callback that performs the allgather operation on the metric state. When ``None``, DDP - will be used to perform the allgather + Callback that performs the ``allgather`` operation on the metric state. When ``None``, DDP + will be used to perform the ``allgather``. """ def __init__( @@ -68,7 +69,7 @@ def __init__( ' For large datasets this may lead to large memory footprint.' ) - def update(self, x: torch.Tensor, y: torch.Tensor): + def update(self, x: Tensor, y: Tensor): """ Update state with predictions and targets. @@ -81,7 +82,7 @@ def update(self, x: torch.Tensor, y: torch.Tensor): self.x.append(x) self.y.append(y) - def compute(self) -> torch.Tensor: + def compute(self) -> Tensor: """ Computes AUC based on inputs passed in to ``update`` previously. """ diff --git a/torchmetrics/classification/auroc.py b/torchmetrics/classification/auroc.py index a78713bac71..79f19415857 100644 --- a/torchmetrics/classification/auroc.py +++ b/torchmetrics/classification/auroc.py @@ -11,18 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
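Editorial note on the `add_state(..., default=tensor(0), dist_reduce_fx="sum")` counters in the `Accuracy` diff above: this is the usual pattern for metrics that only need running totals. A minimal sketch with a hypothetical `ExactMatch` metric (illustrative only, not part of the library):

import torch
from torch import Tensor, tensor
from torchmetrics import Metric

class ExactMatch(Metric):
    def __init__(self):
        super().__init__()
        # integer counters that are summed across processes under DDP
        self.add_state("correct", default=tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor):
        self.correct += (preds == target).sum()
        self.total += target.numel()

    def compute(self) -> Tensor:
        return self.correct.float() / self.total

metric = ExactMatch()
metric(torch.tensor([0, 2, 1, 3]), torch.tensor([0, 1, 2, 3]))  # tensor(0.5000)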
# See the License for the specific language governing permissions and # limitations under the License. -from distutils.version import LooseVersion from typing import Any, Callable, Optional import torch +from torch import Tensor from torchmetrics.functional.classification.auroc import _auroc_compute, _auroc_update from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.imports import _TORCH_LOWER_1_6 class AUROC(Metric): - r"""Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) + r"""Compute `Area Under the Receiver Operating Characteristic Curve (ROC AUC) `_. Works for both binary, multilabel and multiclass problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach. @@ -39,42 +40,54 @@ class AUROC(Metric): dimension more than the ``target`` tensor the input will be interpretated as multiclass. - Args: - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translate to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1] - average: - - ``'macro'`` computes metric for each class and uniformly averages them - - ``'weighted'`` computes metric for each class and does a weighted-average, - where each class is weighted by their support (accounts for class imbalance) - - ``None`` computes and returns the metric per class - max_fpr: - If not ``None``, calculates standardized partial AUC over the - range [0, max_fpr]. Should be a float between 0 and 1. - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When ``None``, DDP - will be used to perform the allgather - - Example (binary case): - + Args: + num_classes: integer with number of classes. Not nessesary to provide + for binary problems. + pos_label: integer determining the positive class. Default is ``None`` + which for binary problem is translate to 1. For multiclass problems + this argument should not be set as we iteratively change it in the + range [0,num_classes-1] + average: + - ``'micro'`` computes metric globally. Only works for multilabel problems + - ``'macro'`` computes metric for each class and uniformly averages them + - ``'weighted'`` computes metric for each class and does a weighted-average, + where each class is weighted by their support (accounts for class imbalance) + - ``None`` computes and returns the metric per class + max_fpr: + If not ``None``, calculates standardized partial AUC over the + range [0, max_fpr]. Should be a float between 0 and 1. + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. 
When ``None``, DDP + will be used to perform the allgather + + Raises: + ValueError: + If ``average`` is none of ``None``, ``"macro"`` or ``"weighted"``. + ValueError: + If ``max_fpr`` is not a ``float`` in the range ``(0, 1]``. + RuntimeError: + If ``PyTorch version`` is ``below 1.6`` since max_fpr requires ``torch.bucketize`` + which is not available below 1.6. + ValueError: + If the mode of data (binary, multi-label, multi-class) changes between batches. + + Example: + >>> # binary case + >>> from torchmetrics import AUROC >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) >>> target = torch.tensor([0, 0, 1, 1, 1]) >>> auroc = AUROC(pos_label=1) >>> auroc(preds, target) tensor(0.5000) - Example (multiclass case): - + >>> # multiclass case >>> preds = torch.tensor([[0.90, 0.05, 0.05], ... [0.05, 0.90, 0.05], ... [0.05, 0.05, 0.90], @@ -110,17 +123,17 @@ def __init__( self.average = average self.max_fpr = max_fpr - allowed_average = (None, 'macro', 'weighted') + allowed_average = (None, 'macro', 'weighted', 'micro') if self.average not in allowed_average: raise ValueError( f'Argument `average` expected to be one of the following: {allowed_average} but got {average}' ) if self.max_fpr is not None: - if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1): + if not isinstance(max_fpr, float) or not 0 < max_fpr <= 1: raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}") - if LooseVersion(torch.__version__) < LooseVersion('1.6.0'): + if _TORCH_LOWER_1_6: raise RuntimeError( '`max_fpr` argument requires `torch.bucketize` which is not available below PyTorch version 1.6' ) @@ -134,7 +147,7 @@ def __init__( ' For large datasets this may lead to large memory footprint.' ) - def update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: Tensor, target: Tensor): """ Update state with predictions and targets. @@ -154,7 +167,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): ) self.mode = mode - def compute(self) -> torch.Tensor: + def compute(self) -> Tensor: """ Computes AUROC based on inputs passed in to ``update`` previously. """ diff --git a/torchmetrics/classification/average_precision.py b/torchmetrics/classification/average_precision.py index fcf66029138..0ecfeb3c864 100644 --- a/torchmetrics/classification/average_precision.py +++ b/torchmetrics/classification/average_precision.py @@ -14,6 +14,7 @@ from typing import Any, List, Optional, Union import torch +from torch import Tensor from torchmetrics.functional.classification.average_precision import ( _average_precision_compute, @@ -51,16 +52,16 @@ class AveragePrecision(Metric): process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) - Example (binary case): - + Example: + >>> # binary case + >>> from torchmetrics import AveragePrecision >>> pred = torch.tensor([0, 1, 2, 3]) >>> target = torch.tensor([0, 1, 1, 1]) >>> average_precision = AveragePrecision(pos_label=1) >>> average_precision(pred, target) tensor(1.) - Example (multiclass case): - + >>> # multiclass case >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], @@ -97,7 +98,7 @@ def __init__( ' For large datasets this may lead to large memory footprint.' ) - def update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: Tensor, target: Tensor): """ Update state with predictions and targets. 
@@ -113,7 +114,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): self.num_classes = num_classes self.pos_label = pos_label - def compute(self) -> Union[torch.Tensor, List[torch.Tensor]]: + def compute(self) -> Union[Tensor, List[Tensor]]: """ Compute the average precision score diff --git a/torchmetrics/classification/cohen_kappa.py b/torchmetrics/classification/cohen_kappa.py index 8f7c4ba6fdb..fd6107220cb 100644 --- a/torchmetrics/classification/cohen_kappa.py +++ b/torchmetrics/classification/cohen_kappa.py @@ -14,6 +14,7 @@ from typing import Any, Optional import torch +from torch import Tensor from torchmetrics.functional.classification.cohen_kappa import _cohen_kappa_compute, _cohen_kappa_update from torchmetrics.metric import Metric @@ -74,6 +75,7 @@ class labels. >>> cohenkappa(preds, target) tensor(0.5000) """ + def __init__( self, num_classes: int, @@ -99,7 +101,7 @@ def __init__( self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum") - def update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: Tensor, target: Tensor): """ Update state with predictions and targets. @@ -110,7 +112,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): confmat = _cohen_kappa_update(preds, target, self.num_classes, self.threshold) self.confmat += confmat - def compute(self) -> torch.Tensor: + def compute(self) -> Tensor: """ Computes cohen kappa score """ diff --git a/torchmetrics/classification/confusion_matrix.py b/torchmetrics/classification/confusion_matrix.py index 64336854c71..dfd65158da3 100644 --- a/torchmetrics/classification/confusion_matrix.py +++ b/torchmetrics/classification/confusion_matrix.py @@ -14,6 +14,7 @@ from typing import Any, Optional import torch +from torch import Tensor from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update from torchmetrics.metric import Metric @@ -60,7 +61,6 @@ class ConfusionMatrix(Metric): Specify the process group on which synchronization is called. default: None (which selects the entire world) Example: - >>> from torchmetrics import ConfusionMatrix >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) @@ -96,7 +96,7 @@ def __init__( self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum") - def update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: Tensor, target: Tensor): """ Update state with predictions and targets. @@ -107,7 +107,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold) self.confmat += confmat - def compute(self) -> torch.Tensor: + def compute(self) -> Tensor: """ Computes confusion matrix """ diff --git a/torchmetrics/classification/f_beta.py b/torchmetrics/classification/f_beta.py index 59c308c1d0a..3bbf1f1e063 100644 --- a/torchmetrics/classification/f_beta.py +++ b/torchmetrics/classification/f_beta.py @@ -14,6 +14,7 @@ from typing import Any, Optional import torch +from torch import Tensor from torchmetrics.functional.classification.f_beta import _fbeta_compute, _fbeta_update from torchmetrics.metric import Metric @@ -64,8 +65,11 @@ class FBeta(Metric): process_group: Specify the process group on which synchronization is called. 
default: None (which selects the entire world) - Example: + Raises: + ValueError: + If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"``, ``None``. + Example: >>> from torchmetrics import FBeta >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) @@ -109,7 +113,7 @@ def __init__( self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") - def update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: Tensor, target: Tensor): """ Update state with predictions and targets. @@ -125,7 +129,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): self.predicted_positives += predicted_positives self.actual_positives += actual_positives - def compute(self) -> torch.Tensor: + def compute(self) -> Tensor: """ Computes fbeta over state. """ diff --git a/torchmetrics/classification/hamming_distance.py b/torchmetrics/classification/hamming_distance.py index c0980c94ce2..497db607ad6 100644 --- a/torchmetrics/classification/hamming_distance.py +++ b/torchmetrics/classification/hamming_distance.py @@ -14,6 +14,7 @@ from typing import Any, Callable, Optional import torch +from torch import Tensor, tensor from torchmetrics.functional.classification.hamming_distance import _hamming_distance_compute, _hamming_distance_update from torchmetrics.metric import Metric @@ -53,8 +54,11 @@ class HammingDistance(Metric): Callback that performs the allgather operation on the metric state. When ``None``, DDP will be used to perform the all gather. - Example: + Raises: + ValueError: + If ``threshold`` is not between ``0`` and ``1``. + Example: >>> from torchmetrics import HammingDistance >>> target = torch.tensor([[0, 1], [1, 1]]) >>> preds = torch.tensor([[0, 1], [0, 1]]) @@ -79,14 +83,14 @@ def __init__( dist_sync_fn=dist_sync_fn, ) - self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("correct", default=tensor(0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") if not 0 < threshold < 1: raise ValueError("The `threshold` should lie in the (0,1) interval.") self.threshold = threshold - def update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: Tensor, target: Tensor): """ Update state with predictions and targets. See :ref:`references/modules:input types` for more information on input types. @@ -100,7 +104,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): self.correct += correct self.total += total - def compute(self) -> torch.Tensor: + def compute(self) -> Tensor: """ Computes hamming distance based on inputs passed in to ``update`` previously. 
""" diff --git a/torchmetrics/classification/iou.py b/torchmetrics/classification/iou.py index 5ce0549b2f3..326b43d24ad 100644 --- a/torchmetrics/classification/iou.py +++ b/torchmetrics/classification/iou.py @@ -14,6 +14,7 @@ from typing import Any, Optional import torch +from torch import Tensor from torchmetrics.classification.confusion_matrix import ConfusionMatrix from torchmetrics.functional.classification.iou import _iou_from_confmat @@ -100,7 +101,7 @@ def __init__( self.ignore_index = ignore_index self.absent_score = absent_score - def compute(self) -> torch.Tensor: + def compute(self) -> Tensor: """ Computes intersection over union (IoU) """ diff --git a/torchmetrics/classification/matthews_corrcoef.py b/torchmetrics/classification/matthews_corrcoef.py new file mode 100644 index 00000000000..f4e84b21841 --- /dev/null +++ b/torchmetrics/classification/matthews_corrcoef.py @@ -0,0 +1,114 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional + +import torch +from torch import Tensor + +from torchmetrics.functional.classification.matthews_corrcoef import ( + _matthews_corrcoef_compute, + _matthews_corrcoef_update, +) +from torchmetrics.metric import Metric + + +class MatthewsCorrcoef(Metric): + r""" + Calculates `Matthews correlation coefficient + `_ that measures + the general correlation or quality of a classification. In the binary case it + is defined as: + + .. math:: + MCC = \frac{TP*TN - FP*FN}{\sqrt{(TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)}} + + where TP, TN, FP and FN are respectively the true postitives, true negatives, + false positives and false negatives. Also works in the case of multi-label or + multi-class input. + + Note: + This metric produces a multi-dimensional output, so it can not be directly logged. + + Forward accepts + + - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes + - ``target`` (long tensor): ``(N, ...)`` + + If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument + to convert into integer labels. This is the case for binary and multi-label probabilities. + + If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + + Args: + num_classes: Number of classes in the dataset. + threshold: + Threshold value for binary or multi-label probabilites. default: 0.5 + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. 
When ``None``, DDP + will be used to perform the allgather + + Example: + + >>> from torchmetrics import MatthewsCorrcoef + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> matthews_corrcoef = MatthewsCorrcoef(num_classes=2) + >>> matthews_corrcoef(preds, target) + tensor(0.5774) + + """ + def __init__( + self, + num_classes: int, + threshold: float = 0.5, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + self.num_classes = num_classes + self.threshold = threshold + + self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + confmat = _matthews_corrcoef_update(preds, target, self.num_classes, self.threshold) + self.confmat += confmat + + def compute(self) -> Tensor: + """ + Computes matthews correlation coefficient + """ + return _matthews_corrcoef_compute(self.confmat) diff --git a/torchmetrics/classification/precision_recall.py b/torchmetrics/classification/precision_recall.py index 7bb8a7f23e2..49c15276a19 100644 --- a/torchmetrics/classification/precision_recall.py +++ b/torchmetrics/classification/precision_recall.py @@ -14,6 +14,7 @@ from typing import Any, Callable, Optional import torch +from torch import Tensor from torchmetrics.classification.stat_scores import StatScores from torchmetrics.functional.classification.precision_recall import _precision_compute, _recall_compute @@ -103,8 +104,11 @@ class Precision(StatScores): Callback that performs the allgather operation on the metric state. When ``None``, DDP will be used to perform the allgather. - Example: + Raises: + ValueError: + If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. + Example: >>> from torchmetrics import Precision >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) @@ -151,7 +155,7 @@ def __init__( self.average = average - def compute(self) -> torch.Tensor: + def compute(self) -> Tensor: """ Computes the precision score based on inputs passed in to ``update`` previously. @@ -251,8 +255,11 @@ class Recall(StatScores): Callback that performs the allgather operation on the metric state. When ``None``, DDP will be used to perform the allgather. - Example: + Raises: + ValueError: + If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. + Example: >>> from torchmetrics import Recall >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) @@ -299,7 +306,7 @@ def __init__( self.average = average - def compute(self) -> torch.Tensor: + def compute(self) -> Tensor: """ Computes the recall score based on inputs passed in to ``update`` previously. 
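The new ``MatthewsCorrcoef`` module metric introduced above keeps a running confusion matrix in ``update()`` and only turns it into a score in ``compute()``, so it can be fed batch by batch. A minimal sketch of that usage, assuming only the names shown in this patch (the two batches below concatenate to the doctest's preds/target, so the result should match its 0.5774 value):

    import torch
    from torchmetrics import MatthewsCorrcoef

    metric = MatthewsCorrcoef(num_classes=2)
    batches = [
        (torch.tensor([0, 1]), torch.tensor([1, 1])),
        (torch.tensor([0, 0]), torch.tensor([0, 0])),
    ]
    for preds, target in batches:
        metric.update(preds, target)  # adds this batch's confusion matrix to the `confmat` state
    print(metric.compute())           # tensor(0.5774), same as the single-call doctest above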
diff --git a/torchmetrics/classification/precision_recall_curve.py b/torchmetrics/classification/precision_recall_curve.py index 679978fdc67..de781cfff15 100644 --- a/torchmetrics/classification/precision_recall_curve.py +++ b/torchmetrics/classification/precision_recall_curve.py @@ -14,6 +14,7 @@ from typing import Any, List, Optional, Tuple, Union import torch +from torch import Tensor from torchmetrics.functional.classification.precision_recall_curve import ( _precision_recall_curve_compute, @@ -51,8 +52,9 @@ class PrecisionRecallCurve(Metric): process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) - Example (binary case): - + Example: + >>> # binary case + >>> from torchmetrics import PrecisionRecallCurve >>> pred = torch.tensor([0, 1, 2, 3]) >>> target = torch.tensor([0, 1, 1, 0]) >>> pr_curve = PrecisionRecallCurve(pos_label=1) @@ -64,8 +66,7 @@ class PrecisionRecallCurve(Metric): >>> thresholds tensor([1, 2, 3]) - Example (multiclass case): - + >>> # multiclass case >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], @@ -108,7 +109,7 @@ def __init__( ' For large datasets this may lead to large memory footprint.' ) - def update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: Tensor, target: Tensor): """ Update state with predictions and targets. @@ -124,14 +125,12 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): self.num_classes = num_classes self.pos_label = pos_label - def compute( - self - ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]]]: + def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: """ Compute the precision-recall curve - Returns: 3-element tuple containing + Returns: + 3-element tuple containing precision: tensor where element i is the precision of predictions with @@ -143,7 +142,6 @@ def compute( If multiclass, this is a list of such tensors, one for each class. thresholds: Thresholds used for computing precision/recall scores - """ preds = torch.cat(self.preds, dim=0) target = torch.cat(self.target, dim=0) diff --git a/torchmetrics/classification/roc.py b/torchmetrics/classification/roc.py index e3477832b21..6cab71d7534 100644 --- a/torchmetrics/classification/roc.py +++ b/torchmetrics/classification/roc.py @@ -14,6 +14,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union import torch +from torch import Tensor from torchmetrics.functional.classification.roc import _roc_compute, _roc_update from torchmetrics.metric import Metric @@ -99,8 +100,10 @@ class ROC(Metric): [tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]), tensor([0., 0., 0., 1., 1.]), tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])] - >>> tpr - [tensor([0., 0., 1., 1., 1.]), tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), tensor([0., 1., 1., 1., 1.])] + >>> tpr # doctest: +NORMALIZE_WHITESPACE + [tensor([0., 0., 1., 1., 1.]), + tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), + tensor([0., 1., 1., 1., 1.])] >>> thresholds # doctest: +NORMALIZE_WHITESPACE [tensor([1.8603, 0.8603, 0.8191, 0.3584, 0.2286]), tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]), @@ -134,7 +137,7 @@ def __init__( ' For large datasets this may lead to large memory footprint.' 
) - def update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: Tensor, target: Tensor): """ Update state with predictions and targets. @@ -148,14 +151,12 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): self.num_classes = num_classes self.pos_label = pos_label - def compute( - self - ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]]]: + def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: """ Compute the receiver operating characteristic - Returns: 3-element tuple containing + Returns: + 3-element tuple containing fpr: tensor with false positive rates. @@ -165,7 +166,6 @@ def compute( If multiclass, this is a list of such tensors, one for each class. thresholds: thresholds used for computing false- and true postive rates - """ preds = torch.cat(self.preds, dim=0) target = torch.cat(self.target, dim=0) diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py index 8312714af98..c96c75b5171 100644 --- a/torchmetrics/classification/stat_scores.py +++ b/torchmetrics/classification/stat_scores.py @@ -15,6 +15,7 @@ import numpy as np import torch +from torch import Tensor, tensor from torchmetrics.functional.classification.stat_scores import _stat_scores_compute, _stat_scores_update from torchmetrics.metric import Metric @@ -103,8 +104,20 @@ class StatScores(Metric): Callback that performs the allgather operation on the metric state. When ``None``, DDP will be used to perform the allgather. - Example: + Raises: + ValueError: + If ``threshold`` is not a ``float`` between ``0`` and ``1``. + ValueError: + If ``reduce`` is none of ``"micro"``, ``"macro"`` or ``"samples"``. + ValueError: + If ``mdmc_reduce`` is none of ``None``, ``"samplewise"``, ``"global"``. + ValueError: + If ``reduce`` is set to ``"macro"`` and ``num_classes`` is not provided. + ValueError: + If ``num_classes`` is set + and ``ignore_index`` is not in the range ``0`` <= ``ignore_index`` < ``num_classes``. + Example: >>> from torchmetrics.classification import StatScores >>> preds = torch.tensor([1, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) @@ -175,7 +188,7 @@ def __init__( for s in ("tp", "fp", "tn", "fn"): self.add_state(s, default=default(), dist_reduce_fx=reduce_fn) - def update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: Tensor, target: Tensor): """ Update state with predictions and targets. See :ref:`references/modules:input types` for more information on input types. @@ -209,7 +222,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): self.tn.append(tn) self.fn.append(fn) - def _get_final_stats(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def _get_final_stats(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Performs concatenation on the stat scores if neccesary, before passing them to a compute function. """ @@ -224,7 +237,7 @@ def _get_final_stats(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, to return tp, fp, tn, fn - def compute(self) -> torch.Tensor: + def compute(self) -> Tensor: """ Computes the stat scores based on inputs passed in to ``update`` previously. 
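To make the ``Raises`` entries documented above concrete, here is a short, illustrative sketch of configuring ``StatScores``; the constructor arguments mirror the functional doctest later in this patch (``reduce='macro'`` requires ``num_classes``, which is one of the documented ``ValueError`` cases), so treat it as a sketch rather than an additional doctest:

    import torch
    from torchmetrics.classification import StatScores

    preds = torch.tensor([1, 0, 2, 1])
    target = torch.tensor([1, 1, 2, 0])

    # `reduce='macro'` keeps one row per class; omitting `num_classes` here would raise a ValueError.
    stat_scores = StatScores(reduce='macro', num_classes=3)
    print(stat_scores(preds, target))  # shape (3, 5): one [tp, fp, tn, fn, support] row per class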
@@ -262,13 +275,13 @@ def compute(self) -> torch.Tensor: def _reduce_stat_scores( - numerator: torch.Tensor, - denominator: torch.Tensor, - weights: Optional[torch.Tensor], + numerator: Tensor, + denominator: Tensor, + weights: Optional[Tensor], average: str, mdmc_average: Optional[str], zero_division: int = 0, -) -> torch.Tensor: +) -> Tensor: """ Reduces scores of type ``numerator/denominator`` or ``weights * (numerator/denominator)``, if ``average='weighted'``. @@ -303,9 +316,9 @@ def _reduce_stat_scores( else: weights = weights.float() - numerator = torch.where(zero_div_mask, torch.tensor(float(zero_division), device=numerator.device), numerator) - denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator) - weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights) + numerator = torch.where(zero_div_mask, tensor(float(zero_division), device=numerator.device), numerator) + denominator = torch.where(zero_div_mask | ignore_mask, tensor(1.0, device=denominator.device), denominator) + weights = torch.where(ignore_mask, tensor(0.0, device=weights.device), weights) if average not in (AverageMethod.MICRO, AverageMethod.NONE, None): weights = weights / weights.sum(dim=-1, keepdim=True) @@ -313,14 +326,14 @@ def _reduce_stat_scores( scores = weights * (numerator / denominator) # This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted' - scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores) + scores = torch.where(torch.isnan(scores), tensor(float(zero_division), device=scores.device), scores) if mdmc_average == MDMCAverageMethod.SAMPLEWISE: scores = scores.mean(dim=0) ignore_mask = ignore_mask.sum(dim=0).bool() if average in (AverageMethod.NONE, None): - scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores) + scores = torch.where(ignore_mask, tensor(np.nan, device=scores.device), scores) else: scores = scores.sum() diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py index e60efca63dc..cda4dacaa4d 100644 --- a/torchmetrics/collections.py +++ b/torchmetrics/collections.py @@ -38,7 +38,16 @@ class MetricCollection(nn.ModuleDict): prefix: a string to append in front of the keys of the output dict - Example (input as list): + Raises: + ValueError: + If one of the elements of ``metrics`` is not an instance of ``pl.metrics.Metric``. + ValueError: + If two elements in ``metrics`` have the same ``name``. + ValueError: + If ``metrics`` is not a ``list``, ``tuple`` or a ``dict``. + + Example: + >>> # input as list >>> import torch >>> from torchmetrics import MetricCollection, Accuracy, Precision, Recall >>> target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) @@ -49,7 +58,7 @@ class MetricCollection(nn.ModuleDict): >>> metrics(preds, target) {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)} - Example (input as dict): + >>> # input as dict >>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'), ... 
'macro_recall': Recall(num_classes=3, average='macro')}) >>> same_metric = metrics.clone() @@ -60,10 +69,11 @@ class MetricCollection(nn.ModuleDict): >>> metrics.persistent() """ + def __init__( - self, - metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]], - prefix: Optional[str] = None + self, + metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]], + prefix: Optional[str] = None, ): super().__init__() if isinstance(metrics, dict): @@ -136,7 +146,8 @@ def persistent(self, mode: bool = True) -> None: def _set_prefix(self, k: str) -> str: return k if self.prefix is None else self.prefix + k - def _check_prefix_arg(self, prefix: str) -> Optional[str]: + @staticmethod + def _check_prefix_arg(prefix: str) -> Optional[str]: if prefix is not None: if isinstance(prefix, str): return prefix diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 9c294f328fa..a57dcbcc800 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -21,6 +21,7 @@ from torchmetrics.functional.classification.f_beta import f1, fbeta # noqa: F401 from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401 from torchmetrics.functional.classification.iou import iou # noqa: F401 +from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401 from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401 from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: F401 from torchmetrics.functional.classification.roc import roc # noqa: F401 @@ -29,9 +30,11 @@ from torchmetrics.functional.nlp import bleu_score # noqa: F401 from torchmetrics.functional.regression.explained_variance import explained_variance # noqa: F401 from torchmetrics.functional.regression.mean_absolute_error import mean_absolute_error # noqa: F401 +from torchmetrics.functional.regression.mean_relative_error import mean_relative_error # noqa: F401 from torchmetrics.functional.regression.mean_squared_error import mean_squared_error # noqa: F401 from torchmetrics.functional.regression.mean_squared_log_error import mean_squared_log_error # noqa: F401 from torchmetrics.functional.regression.psnr import psnr # noqa: F401 from torchmetrics.functional.regression.r2score import r2score # noqa: F401 from torchmetrics.functional.regression.ssim import ssim # noqa: F401 +from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401 from torchmetrics.functional.self_supervised import embedding_similarity # noqa: F401 diff --git a/torchmetrics/functional/classification/__init__.py b/torchmetrics/functional/classification/__init__.py index 90c5458091f..655081d298d 100644 --- a/torchmetrics/functional/classification/__init__.py +++ b/torchmetrics/functional/classification/__init__.py @@ -21,6 +21,7 @@ from torchmetrics.functional.classification.f_beta import f1, fbeta # noqa: F401 from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401 from torchmetrics.functional.classification.iou import iou # noqa: F401 +from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401 from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401 from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: F401 
from torchmetrics.functional.classification.roc import roc # noqa: F401 diff --git a/torchmetrics/functional/classification/accuracy.py b/torchmetrics/functional/classification/accuracy.py index 68810531ecb..9222202ba94 100644 --- a/torchmetrics/functional/classification/accuracy.py +++ b/torchmetrics/functional/classification/accuracy.py @@ -14,48 +14,54 @@ from typing import Optional, Tuple import torch +from torch import Tensor, tensor from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import DataType def _accuracy_update( - preds: torch.Tensor, target: torch.Tensor, threshold: float, top_k: Optional[int], subset_accuracy: bool -) -> Tuple[torch.Tensor, torch.Tensor]: + preds: Tensor, + target: Tensor, + threshold: float, + top_k: Optional[int], + subset_accuracy: bool, +) -> Tuple[Tensor, Tensor]: preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k) + correct, total = None, None if mode == DataType.MULTILABEL and top_k: raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.") if mode == DataType.BINARY or (mode == DataType.MULTILABEL and subset_accuracy): correct = (preds == target).all(dim=1).sum() - total = torch.tensor(target.shape[0], device=target.device) + total = tensor(target.shape[0], device=target.device) elif mode == DataType.MULTILABEL and not subset_accuracy: correct = (preds == target).sum() - total = torch.tensor(target.numel(), device=target.device) + total = tensor(target.numel(), device=target.device) elif mode == DataType.MULTICLASS or (mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy): correct = (preds * target).sum() total = target.sum() elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy: sample_correct = (preds * target).sum(dim=(1, 2)) correct = (sample_correct == target.shape[2]).sum() - total = torch.tensor(target.shape[0], device=target.device) + total = tensor(target.shape[0], device=target.device) return correct, total -def _accuracy_compute(correct: torch.Tensor, total: torch.Tensor) -> torch.Tensor: +def _accuracy_compute(correct: Tensor, total: Tensor) -> Tensor: return correct.float() / total def accuracy( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, threshold: float = 0.5, top_k: Optional[int] = None, subset_accuracy: bool = False, -) -> torch.Tensor: +) -> Tensor: r"""Computes `Accuracy `_: .. math:: @@ -103,8 +109,11 @@ def accuracy( ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter still applies in both cases, if set. - Example: + Raises: + ValueError: + If ``top_k`` parameter is set for ``multi-label`` inputs. 
+    Example:
        >>> from torchmetrics.functional import accuracy
        >>> target = torch.tensor([0, 1, 2, 3])
        >>> preds = torch.tensor([0, 2, 1, 3])
diff --git a/torchmetrics/functional/classification/auc.py b/torchmetrics/functional/classification/auc.py
index 467de147c16..75bad398464 100644
--- a/torchmetrics/functional/classification/auc.py
+++ b/torchmetrics/functional/classification/auc.py
@@ -14,11 +14,12 @@
 from typing import Tuple
 import torch
+from torch import Tensor
 from torchmetrics.utilities.data import _stable_1d_sort
-def _auc_update(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+def _auc_update(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
     if x.ndim > 1 or y.ndim > 1:
         raise ValueError(
             f'Expected both `x` and `y` tensor to be 1d, but got'
@@ -32,7 +33,7 @@ def _auc_update(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.T
     return x, y
-def _auc_compute(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor:
+def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
     if reorder:
         x, x_idx = _stable_1d_sort(x)
         y = y[x_idx]
@@ -51,7 +52,7 @@ def _auc_compute(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> tor
     return direction * torch.trapz(y, x)
-def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor:
+def auc(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
     """
     Computes Area Under the Curve (AUC) using the trapezoidal rule
@@ -63,8 +64,16 @@ def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor
     Return:
         Tensor containing AUC score (float)
-    Example:
+    Raises:
+        ValueError:
+            If both ``x`` and ``y`` tensors are not ``1d``.
+        ValueError:
+            If both ``x`` and ``y`` don't have the same number of elements.
+        ValueError:
+            If ``x`` tensor is neither increasing nor decreasing.
+    Example:
+        >>> from torchmetrics.functional import auc
        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> auc(x, y)
diff --git a/torchmetrics/functional/classification/auroc.py b/torchmetrics/functional/classification/auroc.py
index bb8fa0b5483..22093a02470 100644
--- a/torchmetrics/functional/classification/auroc.py
+++ b/torchmetrics/functional/classification/auroc.py
@@ -11,18 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from distutils.version import LooseVersion from typing import Optional, Sequence, Tuple import torch +from torch import Tensor, tensor from torchmetrics.functional.classification.auc import auc from torchmetrics.functional.classification.roc import roc from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import AverageMethod, DataType +from torchmetrics.utilities.imports import _TORCH_LOWER_1_6 -def _auroc_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, str]: +def _auroc_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, str]: # use _input_format_classification for validating the input and get the mode of data _, _, mode = _input_format_classification(preds, target) @@ -39,25 +40,25 @@ def _auroc_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tens def _auroc_compute( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, mode: str, num_classes: Optional[int] = None, pos_label: Optional[int] = None, average: Optional[str] = 'macro', max_fpr: Optional[float] = None, sample_weights: Optional[Sequence] = None, -) -> torch.Tensor: +) -> Tensor: # binary mode override num_classes if mode == 'binary': num_classes = 1 # check max_fpr parameter if max_fpr is not None: - if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1): + if not isinstance(max_fpr, float) and 0 < max_fpr <= 1: raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}") - if LooseVersion(torch.__version__) < LooseVersion('1.6.0'): + if _TORCH_LOWER_1_6: raise RuntimeError( "`max_fpr` argument requires `torch.bucketize` which" " is not available below PyTorch version 1.6" @@ -73,19 +74,24 @@ def _auroc_compute( # calculate fpr, tpr if mode == 'multi-label': - # for multilabel we iteratively evaluate roc in a binary fashion - output = [ - roc(preds[:, i], target[:, i], num_classes=1, pos_label=1, sample_weights=sample_weights) - for i in range(num_classes) - ] - fpr = [o[0] for o in output] - tpr = [o[1] for o in output] + if average == AverageMethod.MICRO: + fpr, tpr, _ = roc(preds.flatten(), target.flatten(), num_classes, pos_label, sample_weights) + else: + # for multilabel we iteratively evaluate roc in a binary fashion + output = [ + roc(preds[:, i], target[:, i], num_classes=1, pos_label=1, sample_weights=sample_weights) + for i in range(num_classes) + ] + fpr = [o[0] for o in output] + tpr = [o[1] for o in output] else: fpr, tpr, _ = roc(preds, target, num_classes, pos_label, sample_weights) # calculate standard roc auc score if max_fpr is None or max_fpr == 1: - if num_classes != 1: + if mode == 'multi-label' and average == AverageMethod.MICRO: + pass + elif num_classes != 1: # calculate auc scores per class auc_scores = [auc(x, y) for x, y in zip(fpr, tpr)] @@ -109,7 +115,7 @@ def _auroc_compute( return auc(fpr, tpr) - max_fpr = torch.tensor(max_fpr, device=fpr.device) + max_fpr = tensor(max_fpr, device=fpr.device) # Add a single point at max_fpr and interpolate its tpr value stop = torch.bucketize(max_fpr, fpr, out_int32=True, right=True) weight = (max_fpr - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1]) @@ -128,14 +134,14 @@ def _auroc_compute( def auroc( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, num_classes: Optional[int] = None, pos_label: Optional[int] = None, average: Optional[str] = 'macro', max_fpr: Optional[float] = None, sample_weights: Optional[Sequence] = None, -) -> torch.Tensor: +) -> Tensor: """ 
Compute `Area Under the Receiver Operating Characteristic Curve (ROC AUC) `_ @@ -149,6 +155,7 @@ def auroc( this argument should not be set as we iteratively change it in the range [0,num_classes-1] average: + - ``'micro'`` computes metric globally. Only works for multilabel problems - ``'macro'`` computes metric for each class and uniformly averages them - ``'weighted'`` computes metric for each class and does a weighted-average, where each class is weighted by their support (accounts for class imbalance) @@ -156,17 +163,29 @@ def auroc( max_fpr: If not ``None``, calculates standardized partial AUC over the range [0, max_fpr]. Should be a float between 0 and 1. - sample_weight: sample weights for each data point - - Example (binary case): - + sample_weights: sample weights for each data point + + Raises: + ValueError: + If ``max_fpr`` is not a ``float`` in the range ``(0, 1]``. + RuntimeError: + If ``PyTorch version`` is ``below 1.6`` since max_fpr requires `torch.bucketize` + which is not available below 1.6. + ValueError: + If ``max_fpr`` is not set to ``None`` and the mode is ``not binary`` + since partial AUC computation is not available in multilabel/multiclass. + ValueError: + If ``average`` is none of ``None``, ``"macro"`` or ``"weighted"``. + + Example: + >>> # binary case + >>> from torchmetrics.functional import auroc >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) >>> target = torch.tensor([0, 0, 1, 1, 1]) >>> auroc(preds, target, pos_label=1) tensor(0.5000) - Example (multiclass case): - + >>> # multiclass case >>> preds = torch.tensor([[0.90, 0.05, 0.05], ... [0.05, 0.90, 0.05], ... [0.05, 0.05, 0.90], diff --git a/torchmetrics/functional/classification/average_precision.py b/torchmetrics/functional/classification/average_precision.py index b4457681ce0..543bbfb943f 100644 --- a/torchmetrics/functional/classification/average_precision.py +++ b/torchmetrics/functional/classification/average_precision.py @@ -14,6 +14,7 @@ from typing import List, Optional, Sequence, Tuple, Union import torch +from torch import Tensor from torchmetrics.functional.classification.precision_recall_curve import ( _precision_recall_curve_compute, @@ -22,21 +23,22 @@ def _average_precision_update( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, num_classes: Optional[int] = None, pos_label: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor, int, int]: +) -> Tuple[Tensor, Tensor, int, int]: return _precision_recall_curve_update(preds, target, num_classes, pos_label) def _average_precision_compute( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, num_classes: int, pos_label: int, - sample_weights: Optional[Sequence] = None -) -> Union[List[torch.Tensor], torch.Tensor]: + sample_weights: Optional[Sequence] = None, +) -> Union[List[Tensor], Tensor]: + # todo: `sample_weights` is unused precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label) # Return the step function integral # The following works because the last entry of precision is @@ -51,12 +53,12 @@ def _average_precision_compute( def average_precision( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, num_classes: Optional[int] = None, pos_label: Optional[int] = None, sample_weights: Optional[Sequence] = None, -) -> Union[List[torch.Tensor], torch.Tensor]: +) -> Union[List[Tensor], Tensor]: """ Computes the average precision score. 
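The "step function integral" referred to in the comment inside ``_average_precision_compute`` above is the usual sum of precision values weighted by the recall increments. A self-contained illustration with made-up precision/recall values follows; the exact in-tree implementation may differ in detail:

    import torch

    # Hypothetical precision/recall pairs from a perfect ranking (illustrative only).
    precision = torch.tensor([1.0, 1.0, 1.0, 1.0])
    recall = torch.tensor([1.0, 0.6667, 0.3333, 0.0])

    # AP as the step-function integral: precision weighted by how much recall changes at each step.
    ap = -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])
    print(ap)  # approximately 1.0, the average precision of a perfect ranking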
@@ -75,15 +77,15 @@ def average_precision( tensor with average precision. If multiclass will return list of such tensors, one for each class - Example (binary case): - + Example: + >>> # binary case + >>> from torchmetrics.functional import average_precision >>> pred = torch.tensor([0, 1, 2, 3]) >>> target = torch.tensor([0, 1, 1, 1]) >>> average_precision(pred, target, pos_label=1) tensor(1.) - Example (multiclass case): - + >>> # multiclass case >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], diff --git a/torchmetrics/functional/classification/cohen_kappa.py b/torchmetrics/functional/classification/cohen_kappa.py index 448d1898cc0..f26fa22d45e 100644 --- a/torchmetrics/functional/classification/cohen_kappa.py +++ b/torchmetrics/functional/classification/cohen_kappa.py @@ -14,13 +14,14 @@ from typing import Optional import torch +from torch import Tensor from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update _cohen_kappa_update = _confusion_matrix_update -def _cohen_kappa_compute(confmat: torch.Tensor, weights: Optional[str] = None) -> torch.Tensor: +def _cohen_kappa_compute(confmat: Tensor, weights: Optional[str] = None) -> Tensor: confmat = _confusion_matrix_compute(confmat) n_classes = confmat.shape[0] sum0 = confmat.sum(dim=0, keepdim=True) @@ -39,20 +40,22 @@ def _cohen_kappa_compute(confmat: torch.Tensor, weights: Optional[str] = None) - else: w_mat = torch.pow(w_mat - w_mat.T, 2.0) else: - raise ValueError(f"Received {weights} for argument ``weights`` but should be either" - " None, 'linear' or 'quadratic'") + raise ValueError( + f"Received {weights} for argument ``weights`` but should be either" + " None, 'linear' or 'quadratic'" + ) k = torch.sum(w_mat * confmat) / torch.sum(w_mat * expected) return 1 - k def cohen_kappa( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: int, - weights: Optional[str] = None, - threshold: float = 0.5 -) -> torch.Tensor: + preds: Tensor, + target: Tensor, + num_classes: int, + weights: Optional[str] = None, + threshold: float = 0.5, +) -> Tensor: r""" Calculates `Cohen's kappa score `_ that measures inter-annotator agreement. 
It is defined as diff --git a/torchmetrics/functional/classification/confusion_matrix.py b/torchmetrics/functional/classification/confusion_matrix.py index 431cf6fe8a1..d468e0a78c1 100644 --- a/torchmetrics/functional/classification/confusion_matrix.py +++ b/torchmetrics/functional/classification/confusion_matrix.py @@ -14,15 +14,14 @@ from typing import Optional import torch +from torch import Tensor from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import DataType -def _confusion_matrix_update( - preds: torch.Tensor, target: torch.Tensor, num_classes: int, threshold: float = 0.5 -) -> torch.Tensor: +def _confusion_matrix_update(preds: Tensor, target: Tensor, num_classes: int, threshold: float = 0.5) -> Tensor: preds, target, mode = _input_format_classification(preds, target, threshold) if mode not in (DataType.BINARY, DataType.MULTILABEL): preds = preds.argmax(dim=1) @@ -33,12 +32,13 @@ def _confusion_matrix_update( return confmat -def _confusion_matrix_compute(confmat: torch.Tensor, normalize: Optional[str] = None) -> torch.Tensor: +def _confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor: allowed_normalize = ('true', 'pred', 'all', 'none', None) assert normalize in allowed_normalize, \ f"Argument average needs to one of the following: {allowed_normalize}" confmat = confmat.float() if normalize is not None and normalize != 'none': + cm = None if normalize == 'true': cm = confmat / confmat.sum(axis=1, keepdim=True) elif normalize == 'pred': @@ -54,12 +54,8 @@ def _confusion_matrix_compute(confmat: torch.Tensor, normalize: Optional[str] = def confusion_matrix( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: int, - normalize: Optional[str] = None, - threshold: float = 0.5 -) -> torch.Tensor: + preds: Tensor, target: Tensor, num_classes: int, normalize: Optional[str] = None, threshold: float = 0.5 +) -> Tensor: """ Computes the confusion matrix. Works with binary, multiclass, and multilabel data. Accepts probabilities from a model output or integer class values in prediction. @@ -86,7 +82,6 @@ def confusion_matrix( Threshold value for binary or multi-label probabilities. 
default: 0.5 Example: - >>> from torchmetrics.functional import confusion_matrix >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) diff --git a/torchmetrics/functional/classification/dice.py b/torchmetrics/functional/classification/dice.py index ef892488954..d3b79bab63f 100644 --- a/torchmetrics/functional/classification/dice.py +++ b/torchmetrics/functional/classification/dice.py @@ -14,17 +14,18 @@ from typing import Tuple import torch +from torch import Tensor from torchmetrics.utilities.data import to_categorical from torchmetrics.utilities.distributed import reduce def _stat_scores( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, class_index: int, argmax_dim: int = 1, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """ Calculates the number of true positive, false positive, true negative and false negative for a specific class @@ -40,7 +41,6 @@ def _stat_scores( True Positive, False Positive, True Negative, False Negative, Support Example: - >>> x = torch.tensor([1, 2, 3]) >>> y = torch.tensor([0, 2, 3]) >>> tp, fp, tn, fn, sup = _stat_scores(x, y, class_index=1) @@ -61,13 +61,13 @@ def _stat_scores( def dice_score( - pred: torch.Tensor, - target: torch.Tensor, + pred: Tensor, + target: Tensor, bg: bool = False, nan_score: float = 0.0, no_fg_score: float = 0.0, reduction: str = 'elementwise_mean', -) -> torch.Tensor: +) -> Tensor: """ Compute dice score from prediction scores @@ -87,7 +87,7 @@ def dice_score( Tensor containing dice score Example: - + >>> from torchmetrics.functional import dice_score >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], ... [0.05, 0.85, 0.05, 0.05], ... [0.05, 0.05, 0.85, 0.05], diff --git a/torchmetrics/functional/classification/f_beta.py b/torchmetrics/functional/classification/f_beta.py index 3bb60b1d09e..bf92d69c3b4 100644 --- a/torchmetrics/functional/classification/f_beta.py +++ b/torchmetrics/functional/classification/f_beta.py @@ -14,18 +14,19 @@ from typing import Tuple import torch +from torch import Tensor from torchmetrics.utilities.checks import _input_format_classification_one_hot from torchmetrics.utilities.distributed import class_reduce def _fbeta_update( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, num_classes: int, threshold: float = 0.5, - multilabel: bool = False -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + multilabel: bool = False, +) -> Tuple[Tensor, Tensor, Tensor]: preds, target = _input_format_classification_one_hot(num_classes, preds, target, threshold, multilabel) true_positives = torch.sum(preds * target, dim=1) predicted_positives = torch.sum(preds, dim=1) @@ -34,12 +35,12 @@ def _fbeta_update( def _fbeta_compute( - true_positives: torch.Tensor, - predicted_positives: torch.Tensor, - actual_positives: torch.Tensor, + true_positives: Tensor, + predicted_positives: Tensor, + actual_positives: Tensor, beta: float = 1.0, - average: str = "micro" -) -> torch.Tensor: + average: str = "micro", +) -> Tensor: if average == "micro": precision = true_positives.sum().float() / predicted_positives.sum() recall = true_positives.sum().float() / actual_positives.sum() @@ -53,14 +54,14 @@ def _fbeta_compute( def fbeta( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, num_classes: int, beta: float = 1.0, threshold: float = 0.5, average: str = "micro", multilabel: bool = False -) -> torch.Tensor: +) -> Tensor: """ 
Computes f_beta metric. @@ -91,7 +92,6 @@ def fbeta( multilabel: If predictions are from multilabel classification. Example: - >>> from torchmetrics.functional import fbeta >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) @@ -106,13 +106,13 @@ def fbeta( def f1( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, num_classes: int, threshold: float = 0.5, average: str = "micro", multilabel: bool = False -) -> torch.Tensor: +) -> Tensor: """ Computes F1 metric. F1 metrics correspond to a equally weighted average of the precision and recall scores. diff --git a/torchmetrics/functional/classification/hamming_distance.py b/torchmetrics/functional/classification/hamming_distance.py index 60a32e4377c..f6ab51f0a6b 100644 --- a/torchmetrics/functional/classification/hamming_distance.py +++ b/torchmetrics/functional/classification/hamming_distance.py @@ -14,15 +14,16 @@ from typing import Tuple, Union import torch +from torch import Tensor from torchmetrics.utilities.checks import _input_format_classification def _hamming_distance_update( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, threshold: float = 0.5, -) -> Tuple[torch.Tensor, int]: +) -> Tuple[Tensor, int]: preds, target, _ = _input_format_classification(preds, target, threshold=threshold) correct = (preds == target).sum() @@ -31,11 +32,11 @@ def _hamming_distance_update( return correct, total -def _hamming_distance_compute(correct: torch.Tensor, total: Union[int, torch.Tensor]) -> torch.Tensor: +def _hamming_distance_compute(correct: Tensor, total: Union[int, Tensor]) -> Tensor: return 1 - correct.float() / total -def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: +def hamming_distance(preds: Tensor, target: Tensor, threshold: float = 0.5) -> Tensor: r""" Computes the average `Hamming distance `_ (also known as Hamming loss) between targets and predictions: @@ -61,7 +62,6 @@ def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float (0 or 1) predictions, in the case of binary or multi-label inputs. Example: - >>> from torchmetrics.functional import hamming_distance >>> target = torch.tensor([[0, 1], [1, 1]]) >>> preds = torch.tensor([[0, 1], [0, 1]]) diff --git a/torchmetrics/functional/classification/iou.py b/torchmetrics/functional/classification/iou.py index 15502263853..46e60bcd470 100644 --- a/torchmetrics/functional/classification/iou.py +++ b/torchmetrics/functional/classification/iou.py @@ -14,6 +14,7 @@ from typing import Optional import torch +from torch import Tensor from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update from torchmetrics.utilities.data import get_num_classes @@ -21,7 +22,7 @@ def _iou_from_confmat( - confmat: torch.Tensor, + confmat: Tensor, num_classes: int, ignore_index: Optional[int] = None, absent_score: float = 0.0, @@ -35,7 +36,7 @@ def _iou_from_confmat( scores[union == 0] = absent_score # Remove the ignored class index from the scores. 
- if ignore_index is not None and ignore_index >= 0 and ignore_index < num_classes: + if ignore_index is not None and 0 <= ignore_index < num_classes: scores = torch.cat([ scores[:ignore_index], scores[ignore_index + 1:], @@ -44,14 +45,14 @@ def _iou_from_confmat( def iou( - pred: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, ignore_index: Optional[int] = None, absent_score: float = 0.0, threshold: float = 0.5, num_classes: Optional[int] = None, reduction: str = 'elementwise_mean', -) -> torch.Tensor: +) -> Tensor: r""" Computes `Intersection over union, or Jaccard index calculation `_: @@ -97,15 +98,14 @@ def iou( 'elementwise_mean', or number of classes if reduction is 'none' Example: - + >>> from torchmetrics.functional import iou >>> target = torch.randint(0, 2, (10, 25, 25)) >>> pred = torch.tensor(target) >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] >>> iou(pred, target) tensor(0.9660) - """ - num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) - confmat = _confusion_matrix_update(pred, target, num_classes, threshold) + num_classes = get_num_classes(preds=preds, target=target, num_classes=num_classes) + confmat = _confusion_matrix_update(preds, target, num_classes, threshold) return _iou_from_confmat(confmat, num_classes, ignore_index, absent_score, reduction) diff --git a/torchmetrics/functional/classification/matthews_corrcoef.py b/torchmetrics/functional/classification/matthews_corrcoef.py new file mode 100644 index 00000000000..91db05a7a40 --- /dev/null +++ b/torchmetrics/functional/classification/matthews_corrcoef.py @@ -0,0 +1,66 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import Tensor + +from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update + +_matthews_corrcoef_update = _confusion_matrix_update + + +def _matthews_corrcoef_compute(confmat: Tensor) -> Tensor: + tk = confmat.sum(dim=0).float() + pk = confmat.sum(dim=1).float() + c = torch.trace(confmat).float() + s = confmat.sum().float() + return (c * s - sum(tk * pk)) / (torch.sqrt(s ** 2 - sum(pk * pk)) * torch.sqrt(s ** 2 - sum(tk * tk))) + + +def matthews_corrcoef( + preds: Tensor, + target: Tensor, + num_classes: int, + threshold: float = 0.5 +) -> Tensor: + r""" + Calculates `Matthews correlation coefficient + `_ that measures + the general correlation or quality of a classification. In the binary case it + is defined as: + + .. math:: + MCC = \frac{TP*TN - FP*FN}{\sqrt{(TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)}} + + where TP, TN, FP and FN are respectively the true postitives, true negatives, + false positives and false negatives. Also works in the case of multi-label or + multi-class input. 
+ + Args: + preds: (float or long tensor), Either a ``(N, ...)`` tensor with labels or + ``(N, C, ...)`` where C is the number of classes, tensor with labels/probabilities + target: ``target`` (long tensor), tensor with shape ``(N, ...)`` with ground true labels + num_classes: Number of classes in the dataset. + threshold: + Threshold value for binary or multi-label probabilities. default: 0.5 + + Example: + >>> from torchmetrics.functional import matthews_corrcoef + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> matthews_corrcoef(preds, target, num_classes=2) + tensor(0.5774) + + """ + confmat = _matthews_corrcoef_update(preds, target, num_classes, threshold) + return _matthews_corrcoef_compute(confmat) diff --git a/torchmetrics/functional/classification/precision_recall.py b/torchmetrics/functional/classification/precision_recall.py index 62395454ce8..1b2680e00a6 100644 --- a/torchmetrics/functional/classification/precision_recall.py +++ b/torchmetrics/functional/classification/precision_recall.py @@ -11,22 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Tuple import torch +from torch import Tensor from torchmetrics.classification.stat_scores import _reduce_stat_scores from torchmetrics.functional.classification.stat_scores import _stat_scores_update def _precision_compute( - tp: torch.Tensor, - fp: torch.Tensor, - tn: torch.Tensor, - fn: torch.Tensor, + tp: Tensor, + fp: Tensor, + tn: Tensor, + fn: Tensor, average: str, mdmc_average: Optional[str], -) -> torch.Tensor: +) -> Tensor: + # todo: `tn` is unused return _reduce_stat_scores( numerator=tp, denominator=tp + fp, @@ -37,8 +39,8 @@ def _precision_compute( def precision( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, average: str = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, @@ -46,7 +48,7 @@ def precision( threshold: float = 0.5, top_k: Optional[int] = None, is_multiclass: Optional[bool] = None, -) -> torch.Tensor: +) -> Tensor: r""" Computes `Precision `_: @@ -128,8 +130,19 @@ def precision( - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number of classes - Example: + Raises: + ValueError: + If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, + ``"samples"``, ``"none"`` or ``None``. + ValueError: + If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. + ValueError: + If ``average`` is set but ``num_classes`` is not provided. + ValueError: + If ``num_classes`` is set + and ``ignore_index`` is not in the range ``[0, num_classes)``. 
+ Example: >>> from torchmetrics.functional import precision >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) @@ -170,13 +183,15 @@ def precision( def _recall_compute( - tp: torch.Tensor, - fp: torch.Tensor, - tn: torch.Tensor, - fn: torch.Tensor, + tp: Tensor, + fp: Tensor, + tn: Tensor, + fn: Tensor, average: str, mdmc_average: Optional[str], -) -> torch.Tensor: +) -> Tensor: + # todo: `tp` is unused + # todo: `tn` is unused return _reduce_stat_scores( numerator=tp, denominator=tp + fn, @@ -187,8 +202,8 @@ def _recall_compute( def recall( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, average: str = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, @@ -196,7 +211,7 @@ def recall( threshold: float = 0.5, top_k: Optional[int] = None, is_multiclass: Optional[bool] = None, -) -> torch.Tensor: +) -> Tensor: r""" Computes `Recall `_: @@ -278,8 +293,19 @@ def recall( - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number of classes - Example: + Raises: + ValueError: + If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, + ``"samples"``, ``"none"`` or ``None``. + ValueError: + If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. + ValueError: + If ``average`` is set but ``num_classes`` is not provided. + ValueError: + If ``num_classes`` is set + and ``ignore_index`` is not in the range ``[0, num_classes)``. + Example: >>> from torchmetrics.functional import recall >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) @@ -320,8 +346,8 @@ def recall( def precision_recall( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, average: str = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, @@ -329,7 +355,7 @@ def precision_recall( threshold: float = 0.5, top_k: Optional[int] = None, is_multiclass: Optional[bool] = None, -) -> torch.Tensor: +) -> Tuple[Tensor, Tensor]: r""" Computes `Precision and Recall `_: @@ -415,8 +441,19 @@ def precision_recall( - If ``average in ['none', None]``, they are a tensor of shape ``(C, )``, where ``C`` stands for the number of classes - Example: + Raises: + ValueError: + If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, + ``"samples"``, ``"none"`` or ``None``. + ValueError: + If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. + ValueError: + If ``average`` is set but ``num_classes`` is not provided. + ValueError: + If ``num_classes`` is set + and ``ignore_index`` is not in the range ``[0, num_classes)``. 
+ Example: >>> from torchmetrics.functional import precision_recall >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) diff --git a/torchmetrics/functional/classification/precision_recall_curve.py b/torchmetrics/functional/classification/precision_recall_curve.py index 064d294b5db..7d6881b23f3 100644 --- a/torchmetrics/functional/classification/precision_recall_curve.py +++ b/torchmetrics/functional/classification/precision_recall_curve.py @@ -14,22 +14,23 @@ from typing import List, Optional, Sequence, Tuple, Union import torch -import torch.nn.functional as F +from torch import Tensor, tensor +from torch.nn import functional as F from torchmetrics.utilities import rank_zero_warn def _binary_clf_curve( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, sample_weights: Optional[Sequence] = None, pos_label: int = 1., -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[Tensor, Tensor, Tensor]: """ adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py """ - if sample_weights is not None and not isinstance(sample_weights, torch.Tensor): - sample_weights = torch.tensor(sample_weights, device=preds.device, dtype=torch.float) + if sample_weights is not None and not isinstance(sample_weights, Tensor): + sample_weights = tensor(sample_weights, device=preds.device, dtype=torch.float) # remove class dimension if necessary if preds.ndim > target.ndim: @@ -48,7 +49,7 @@ def _binary_clf_curve( # the indices associated with the distinct values. We also # concatenate a value for the end of the curve. distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0] - threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1) + threshold_idxs = F.pad(distinct_value_indices, [0, 1], value=target.size(0) - 1) target = (target == pos_label).to(torch.long) tps = torch.cumsum(target * weight, dim=0)[threshold_idxs] @@ -63,11 +64,11 @@ def _binary_clf_curve( def _precision_recall_curve_update( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, num_classes: Optional[int] = None, pos_label: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor, int, int]: +) -> Tuple[Tensor, Tensor, int, int]: if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") @@ -111,13 +112,12 @@ def _precision_recall_curve_update( def _precision_recall_curve_compute( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, num_classes: int, pos_label: int, sample_weights: Optional[Sequence] = None, -) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]]]: +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: if num_classes == 1: fps, tps, thresholds = _binary_clf_curve( @@ -161,13 +161,12 @@ def _precision_recall_curve_compute( def precision_recall_curve( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, num_classes: Optional[int] = None, pos_label: Optional[int] = None, sample_weights: Optional[Sequence] = None, -) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]]]: +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: """ Computes 
precision-recall pairs for different thresholds. @@ -182,7 +181,8 @@ def precision_recall_curve( range [0,num_classes-1] sample_weights: sample weights for each data point - Returns: 3-element tuple containing + Returns: + 3-element tuple containing precision: tensor where element i is the precision of predictions with @@ -195,8 +195,17 @@ def precision_recall_curve( thresholds: Thresholds used for computing precision/recall scores - Example (binary case): - + Raises: + ValueError: + If ``preds`` and ``target`` don't have the same number of dimensions, + or one additional dimension for ``preds``. + ValueError: + If the number of classes deduced from ``preds`` is not the same as the + ``num_classes`` provided. + + Example: + >>> # binary case + >>> from torchmetrics.functional import precision_recall_curve >>> pred = torch.tensor([0, 1, 2, 3]) >>> target = torch.tensor([0, 1, 1, 0]) >>> precision, recall, thresholds = precision_recall_curve(pred, target, pos_label=1) @@ -207,8 +216,7 @@ def precision_recall_curve( >>> thresholds tensor([1, 2, 3]) - Example (multiclass case): - + >>> # multiclass case >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], @@ -222,7 +230,6 @@ def precision_recall_curve( [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])] >>> thresholds [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] - """ preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label) return _precision_recall_curve_compute(preds, target, num_classes, pos_label, sample_weights) diff --git a/torchmetrics/functional/classification/roc.py b/torchmetrics/functional/classification/roc.py index de40bc8b166..b7fee4ba620 100644 --- a/torchmetrics/functional/classification/roc.py +++ b/torchmetrics/functional/classification/roc.py @@ -14,6 +14,7 @@ from typing import List, Optional, Sequence, Tuple, Union import torch +from torch import Tensor from torchmetrics.functional.classification.precision_recall_curve import ( _binary_clf_curve, @@ -22,23 +23,22 @@ def _roc_update( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, num_classes: Optional[int] = None, pos_label: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor, int, int, str]: +) -> Tuple[Tensor, Tensor, int, int, str]: preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label) return preds, target, num_classes, pos_label def _roc_compute( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, num_classes: int, pos_label: int, sample_weights: Optional[Sequence] = None, -) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]]]: +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: if num_classes == 1 and preds.ndim == 1: # binary fps, tps, thresholds = _binary_clf_curve( @@ -86,13 +86,12 @@ def _roc_compute( def roc( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, num_classes: Optional[int] = None, pos_label: Optional[int] = None, sample_weights: Optional[Sequence] = None, -) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]]]: +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], 
List[Tensor], List[Tensor]]]: """ Computes the Receiver Operating Characteristic (ROC). Works with both binary, multiclass and multilabel input. @@ -108,7 +107,8 @@ def roc( range [0,num_classes-1] sample_weights: sample weights for each data point - Returns: 3-element tuple containing + Returns: + 3-element tuple containing fpr: tensor with false positive rates. diff --git a/torchmetrics/functional/classification/stat_scores.py b/torchmetrics/functional/classification/stat_scores.py index a5682eb8bc6..985f84dd9b2 100644 --- a/torchmetrics/functional/classification/stat_scores.py +++ b/torchmetrics/functional/classification/stat_scores.py @@ -14,21 +14,22 @@ from typing import Optional, Tuple import torch +from torch import Tensor, tensor from torchmetrics.utilities.checks import _input_format_classification -def _del_column(tensor: torch.Tensor, index: int): +def _del_column(tensor: Tensor, index: int): """ Delete the column at index.""" return torch.cat([tensor[:, :index], tensor[:, (index + 1):]], 1) def _stat_scores( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, reduce: str = "micro", -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Calculate the number of tp, fp, tn, fn. Args: @@ -74,8 +75,8 @@ def _stat_scores( def _stat_scores_update( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, reduce: str = "micro", mdmc_reduce: Optional[str] = None, num_classes: Optional[int] = None, @@ -83,7 +84,7 @@ def _stat_scores_update( threshold: float = 0.5, is_multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: preds, target, _ = _input_format_classification( preds, target, threshold=threshold, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k @@ -121,7 +122,7 @@ def _stat_scores_update( return tp, fp, tn, fn -def _stat_scores_compute(tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, fn: torch.Tensor) -> torch.Tensor: +def _stat_scores_compute(tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> Tensor: outputs = [ tp.unsqueeze(-1), @@ -131,14 +132,14 @@ def _stat_scores_compute(tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, f tp.unsqueeze(-1) + fn.unsqueeze(-1), # support ] outputs = torch.cat(outputs, -1) - outputs = torch.where(outputs < 0, torch.tensor(-1, device=outputs.device), outputs) + outputs = torch.where(outputs < 0, tensor(-1, device=outputs.device), outputs) return outputs def stat_scores( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, reduce: str = "micro", mdmc_reduce: Optional[str] = None, num_classes: Optional[int] = None, @@ -146,7 +147,7 @@ def stat_scores( threshold: float = 0.5, is_multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, -) -> torch.Tensor: +) -> Tensor: """Computes the number of true positives, false positives, true negatives, false negatives. Related to `Type I and Type II errors `__ and the `confusion matrix `__. @@ -244,8 +245,22 @@ def stat_scores( - If ``reduce='macro'``, the shape will be ``(N, C, 5)`` - If ``reduce='samples'``, the shape will be ``(N, X, 5)`` - Example: + Raises: + ValueError: + If ``reduce`` is none of ``"micro"``, ``"macro"`` or ``"samples"``. + ValueError: + If ``mdmc_reduce`` is none of ``None``, ``"samplewise"``, ``"global"``. 
+ ValueError: + If ``reduce`` is set to ``"macro"`` and ``num_classes`` is not provided. + ValueError: + If ``num_classes`` is set + and ``ignore_index`` is not in the range ``[0, num_classes)``. + ValueError: + If ``ignore_index`` is used with ``binary data``. + ValueError: + If inputs are ``multi-dimensional multi-class`` and ``mdmc_reduce`` is not provided. + Example: >>> from torchmetrics.functional import stat_scores >>> preds = torch.tensor([1, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) @@ -255,7 +270,6 @@ def stat_scores( [1, 0, 3, 0, 1]]) >>> stat_scores(preds, target, reduce='micro') tensor([2, 2, 6, 2, 4]) - """ if reduce not in ["micro", "macro", "samples"]: diff --git a/torchmetrics/functional/image_gradients.py b/torchmetrics/functional/image_gradients.py index 3b7e8ffce42..d48c84d4836 100644 --- a/torchmetrics/functional/image_gradients.py +++ b/torchmetrics/functional/image_gradients.py @@ -14,18 +14,19 @@ from typing import Tuple import torch +from torch import Tensor -def _image_gradients_validate(img: torch.Tensor) -> torch.Tensor: +def _image_gradients_validate(img: Tensor) -> Tensor: """ Validates whether img is a 4D torch Tensor """ - if not isinstance(img, torch.Tensor): - raise TypeError(f"The `img` expects a value of type but got {type(img)}") + if not isinstance(img, Tensor): + raise TypeError(f"The `img` expects a value of type but got {type(img)}") if img.ndim != 4: raise RuntimeError(f"The `img` expects a 4D tensor but got {img.ndim}D tensor") -def _compute_image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def _compute_image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]: """ Computes image gradients (dy/dx) for a given image """ batch_size, channels, height, width = img.shape @@ -44,7 +45,7 @@ def _compute_image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Ten return dy, dx -def image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]: """ Computes the `gradients `_ of a given image using finite difference @@ -54,7 +55,14 @@ def image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: Return: Tuple of (dy, dx) with each gradient of shape ``[N, C, H, W]`` + Raises: + TypeError: + If ``img`` is not of the type . + RuntimeError: + If ``img`` is not a 4D tensor. 
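As a companion to the doctest that follows, here is a standalone sketch of the forward-difference idea behind `image_gradients`. It is an illustration only, not the library's `_compute_image_gradients` implementation; the helper name `finite_difference_gradients` and the explicit zero padding are my own choices for keeping the output at the input's `[N, C, H, W]` shape.

```python
import torch
import torch.nn.functional as F


def finite_difference_gradients(img: torch.Tensor):
    """Forward differences along height (dy) and width (dx) of a [N, C, H, W] image."""
    if img.ndim != 4:
        raise RuntimeError(f"Expected a 4D tensor but got a {img.ndim}D tensor")
    dy = img[..., 1:, :] - img[..., :-1, :]  # difference between neighbouring rows
    dx = img[..., :, 1:] - img[..., :, :-1]  # difference between neighbouring columns
    # Zero-pad the last row of dy and the last column of dx so both keep the input shape.
    dy = F.pad(dy, (0, 0, 0, 1))
    dx = F.pad(dx, (0, 1, 0, 0))
    return dy, dx


image = torch.arange(0, 25, dtype=torch.float32).reshape(1, 1, 5, 5)
dy, dx = finite_difference_gradients(image)
print(dy[0, 0])  # rows of 5s, last row all zeros
print(dx[0, 0])  # columns of 1s, last column all zeros
```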
+ Example: + >>> from torchmetrics.functional import image_gradients >>> image = torch.arange(0, 1*1*5*5, dtype=torch.float32) >>> image = torch.reshape(image, (1, 1, 5, 5)) >>> dy, dx = image_gradients(image) diff --git a/torchmetrics/functional/nlp.py b/torchmetrics/functional/nlp.py index 57c6fc1ece4..53f5e47e40c 100644 --- a/torchmetrics/functional/nlp.py +++ b/torchmetrics/functional/nlp.py @@ -20,6 +20,7 @@ from typing import List, Sequence import torch +from torch import Tensor, tensor def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter: @@ -45,11 +46,8 @@ def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter: def bleu_score( - translate_corpus: Sequence[str], - reference_corpus: Sequence[str], - n_gram: int = 4, - smooth: bool = False -) -> torch.Tensor: + translate_corpus: Sequence[str], reference_corpus: Sequence[str], n_gram: int = 4, smooth: bool = False +) -> Tensor: """ Calculate BLEU score of machine translated text with one or more references @@ -63,12 +61,11 @@ def bleu_score( Tensor with BLEU Score Example: - + >>> from torchmetrics.functional import bleu_score >>> translate_corpus = ['the cat is on the mat'.split()] >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] >>> bleu_score(translate_corpus, reference_corpus) tensor(0.7598) - """ assert len(translate_corpus) == len(reference_corpus) @@ -96,20 +93,20 @@ def bleu_score( for counter in translation_counter: denominator[len(counter) - 1] += translation_counter[counter] - trans_len = torch.tensor(c) - ref_len = torch.tensor(r) + trans_len = tensor(c) + ref_len = tensor(r) if min(numerator) == 0.0: - return torch.tensor(0.0) + return tensor(0.0) if smooth: precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram)) else: precision_scores = numerator / denominator - log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores) + log_precision_scores = tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores) geometric_mean = torch.exp(torch.sum(log_precision_scores)) - brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len)) + brevity_penalty = tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len)) bleu = brevity_penalty * geometric_mean return bleu diff --git a/torchmetrics/functional/regression/explained_variance.py b/torchmetrics/functional/regression/explained_variance.py index b2791cf9de0..0f38b2259ab 100644 --- a/torchmetrics/functional/regression/explained_variance.py +++ b/torchmetrics/functional/regression/explained_variance.py @@ -14,38 +14,37 @@ from typing import Sequence, Tuple, Union import torch +from torch import Tensor from torchmetrics.utilities.checks import _check_same_shape -def _explained_variance_update( - preds: torch.Tensor, target: torch.Tensor -) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +def _explained_variance_update(preds: Tensor, target: Tensor) -> Tuple[int, Tensor, Tensor, Tensor, Tensor]: _check_same_shape(preds, target) n_obs = preds.size(0) sum_error = torch.sum(target - preds, dim=0) - sum_squared_error = torch.sum((target - preds) ** 2, dim=0) + sum_squared_error = torch.sum((target - preds)**2, dim=0) sum_target = torch.sum(target, dim=0) - sum_squared_target = torch.sum(target ** 2, dim=0) + sum_squared_target = torch.sum(target**2, dim=0) return n_obs, sum_error, sum_squared_error, sum_target, sum_squared_target def _explained_variance_compute( - n_obs: 
torch.Tensor, - sum_error: torch.Tensor, - sum_squared_error: torch.Tensor, - sum_target: torch.Tensor, - sum_squared_target: torch.Tensor, + n_obs: Tensor, + sum_error: Tensor, + sum_squared_error: Tensor, + sum_target: Tensor, + sum_squared_target: Tensor, multioutput: str = "uniform_average", -) -> Union[torch.Tensor, Sequence[torch.Tensor]]: +) -> Union[Tensor, Sequence[Tensor]]: diff_avg = sum_error / n_obs - numerator = sum_squared_error / n_obs - diff_avg ** 2 + numerator = sum_squared_error / n_obs - diff_avg**2 target_avg = sum_target / n_obs - denominator = sum_squared_target / n_obs - target_avg ** 2 + denominator = sum_squared_target / n_obs - target_avg**2 # Take care of division by zero nonzero_numerator = numerator != 0 @@ -67,10 +66,10 @@ def _explained_variance_compute( def explained_variance( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, multioutput: str = "uniform_average", -) -> Union[torch.Tensor, Sequence[torch.Tensor]]: +) -> Union[Tensor, Sequence[Tensor]]: """ Computes explained variance. @@ -85,7 +84,6 @@ def explained_variance( * `'variance_weighted'` scores are weighted by their individual variances Example: - >>> from torchmetrics.functional import explained_variance >>> target = torch.tensor([3, -0.5, 2, 7]) >>> preds = torch.tensor([2.5, 0.0, 2, 8]) @@ -98,4 +96,11 @@ def explained_variance( tensor([0.9677, 1.0000]) """ n_obs, sum_error, sum_squared_error, sum_target, sum_squared_target = _explained_variance_update(preds, target) - return _explained_variance_compute(n_obs, sum_error, sum_squared_error, sum_target, sum_squared_target, multioutput) + return _explained_variance_compute( + n_obs, + sum_error, + sum_squared_error, + sum_target, + sum_squared_target, + multioutput, + ) diff --git a/torchmetrics/functional/regression/mean_absolute_error.py b/torchmetrics/functional/regression/mean_absolute_error.py index 16413c33e20..d9fc9295cec 100644 --- a/torchmetrics/functional/regression/mean_absolute_error.py +++ b/torchmetrics/functional/regression/mean_absolute_error.py @@ -14,39 +14,39 @@ from typing import Tuple import torch +from torch import Tensor from torchmetrics.utilities.checks import _check_same_shape -def _mean_absolute_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: +def _mean_absolute_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: _check_same_shape(preds, target) sum_abs_error = torch.sum(torch.abs(preds - target)) n_obs = target.numel() return sum_abs_error, n_obs -def _mean_absolute_error_compute(sum_abs_error: torch.Tensor, n_obs: int) -> torch.Tensor: +def _mean_absolute_error_compute(sum_abs_error: Tensor, n_obs: int) -> Tensor: return sum_abs_error / n_obs -def mean_absolute_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: +def mean_absolute_error(preds: Tensor, target: Tensor) -> Tensor: """ Computes mean absolute error Args: - pred: estimated labels + preds: estimated labels target: ground truth labels Return: Tensor with MAE Example: - + >>> from torchmetrics.functional import mean_absolute_error >>> x = torch.tensor([0., 1, 2, 3]) >>> y = torch.tensor([0., 1, 2, 2]) >>> mean_absolute_error(x, y) tensor(0.2500) - """ sum_abs_error, n_obs = _mean_absolute_error_update(preds, target) return _mean_absolute_error_compute(sum_abs_error, n_obs) diff --git a/torchmetrics/functional/regression/mean_relative_error.py b/torchmetrics/functional/regression/mean_relative_error.py index 66f82e3b94e..286a5c9b687 100644 --- 
a/torchmetrics/functional/regression/mean_relative_error.py +++ b/torchmetrics/functional/regression/mean_relative_error.py @@ -14,11 +14,12 @@ from typing import Tuple import torch +from torch import Tensor from torchmetrics.utilities.checks import _check_same_shape -def _mean_relative_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: +def _mean_relative_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: _check_same_shape(preds, target) target_nz = target.clone() target_nz[target == 0] = 1 @@ -27,23 +28,23 @@ def _mean_relative_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tu return sum_rltv_error, n_obs -def _mean_relative_error_compute(sum_rltv_error: torch.Tensor, n_obs: int) -> torch.Tensor: +def _mean_relative_error_compute(sum_rltv_error: Tensor, n_obs: int) -> Tensor: return sum_rltv_error / n_obs -def mean_relative_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: +def mean_relative_error(preds: Tensor, target: Tensor) -> Tensor: """ Computes mean relative error Args: - pred: estimated labels + preds: estimated labels target: ground truth labels Return: Tensor with mean relative error Example: - + >>> from torchmetrics.functional import mean_relative_error >>> x = torch.tensor([0., 1, 2, 3]) >>> y = torch.tensor([0., 1, 2, 2]) >>> mean_relative_error(x, y) diff --git a/torchmetrics/functional/regression/mean_squared_error.py b/torchmetrics/functional/regression/mean_squared_error.py index bd88d736c95..d046cf76f4f 100644 --- a/torchmetrics/functional/regression/mean_squared_error.py +++ b/torchmetrics/functional/regression/mean_squared_error.py @@ -14,22 +14,23 @@ from typing import Tuple import torch +from torch import Tensor from torchmetrics.utilities.checks import _check_same_shape -def _mean_squared_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: +def _mean_squared_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: _check_same_shape(preds, target) sum_squared_error = torch.sum(torch.pow(preds - target, 2)) n_obs = target.numel() return sum_squared_error, n_obs -def _mean_squared_error_compute(sum_squared_error: torch.Tensor, n_obs: int) -> torch.Tensor: +def _mean_squared_error_compute(sum_squared_error: Tensor, n_obs: int) -> Tensor: return sum_squared_error / n_obs -def mean_squared_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: +def mean_squared_error(preds: Tensor, target: Tensor) -> Tensor: """ Computes mean squared error @@ -41,12 +42,11 @@ def mean_squared_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tenso Tensor with MSE Example: - + >>> from torchmetrics.functional import mean_squared_error >>> x = torch.tensor([0., 1, 2, 3]) >>> y = torch.tensor([0., 1, 2, 2]) >>> mean_squared_error(x, y) tensor(0.2500) - """ sum_squared_error, n_obs = _mean_squared_error_update(preds, target) return _mean_squared_error_compute(sum_squared_error, n_obs) diff --git a/torchmetrics/functional/regression/mean_squared_log_error.py b/torchmetrics/functional/regression/mean_squared_log_error.py index 7308549529a..212e17a73da 100644 --- a/torchmetrics/functional/regression/mean_squared_log_error.py +++ b/torchmetrics/functional/regression/mean_squared_log_error.py @@ -14,22 +14,23 @@ from typing import Tuple import torch +from torch import Tensor from torchmetrics.utilities.checks import _check_same_shape -def _mean_squared_log_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: +def 
_mean_squared_log_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: _check_same_shape(preds, target) sum_squared_log_error = torch.sum(torch.pow(torch.log1p(preds) - torch.log1p(target), 2)) n_obs = target.numel() return sum_squared_log_error, n_obs -def _mean_squared_log_error_compute(sum_squared_log_error: torch.Tensor, n_obs: int) -> torch.Tensor: +def _mean_squared_log_error_compute(sum_squared_log_error: Tensor, n_obs: int) -> Tensor: return sum_squared_log_error / n_obs -def mean_squared_log_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: +def mean_squared_log_error(preds: Tensor, target: Tensor) -> Tensor: """ Computes mean squared log error @@ -41,12 +42,11 @@ def mean_squared_log_error(preds: torch.Tensor, target: torch.Tensor) -> torch.T Tensor with RMSLE Example: - + >>> from torchmetrics.functional import mean_squared_log_error >>> x = torch.tensor([0., 1, 2, 3]) >>> y = torch.tensor([0., 1, 2, 2]) >>> mean_squared_log_error(x, y) tensor(0.0207) - """ sum_squared_log_error, n_obs = _mean_squared_log_error_update(preds, target) return _mean_squared_log_error_compute(sum_squared_log_error, n_obs) diff --git a/torchmetrics/functional/regression/psnr.py b/torchmetrics/functional/regression/psnr.py index c2cf6d86121..54c27647732 100644 --- a/torchmetrics/functional/regression/psnr.py +++ b/torchmetrics/functional/regression/psnr.py @@ -14,28 +14,31 @@ from typing import Optional, Tuple, Union import torch +from torch import Tensor, tensor from torchmetrics.utilities import rank_zero_warn, reduce def _psnr_compute( - sum_squared_error: torch.Tensor, - n_obs: torch.Tensor, - data_range: torch.Tensor, + sum_squared_error: Tensor, + n_obs: Tensor, + data_range: Tensor, base: float = 10.0, reduction: str = 'elementwise_mean', -) -> torch.Tensor: +) -> Tensor: psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs) - psnr = psnr_base_e * (10 / torch.log(torch.tensor(base))) + psnr = psnr_base_e * (10 / torch.log(tensor(base))) return reduce(psnr, reduction=reduction) -def _psnr_update(preds: torch.Tensor, - target: torch.Tensor, - dim: Optional[Union[int, Tuple[int, ...]]] = None) -> Tuple[torch.Tensor, torch.Tensor]: +def _psnr_update( + preds: Tensor, + target: Tensor, + dim: Optional[Union[int, Tuple[int, ...]]] = None, +) -> Tuple[Tensor, Tensor]: if dim is None: sum_squared_error = torch.sum(torch.pow(preds - target, 2)) - n_obs = torch.tensor(target.numel(), device=target.device) + n_obs = tensor(target.numel(), device=target.device) return sum_squared_error, n_obs sum_squared_error = torch.sum(torch.pow(preds - target, 2), dim=dim) @@ -45,22 +48,22 @@ def _psnr_update(preds: torch.Tensor, else: dim_list = list(dim) if not dim_list: - n_obs = torch.tensor(target.numel(), device=target.device) + n_obs = tensor(target.numel(), device=target.device) else: - n_obs = torch.tensor(target.size(), device=target.device)[dim_list].prod() + n_obs = tensor(target.size(), device=target.device)[dim_list].prod() n_obs = n_obs.expand_as(sum_squared_error) return sum_squared_error, n_obs def psnr( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, data_range: Optional[float] = None, base: float = 10.0, reduction: str = 'elementwise_mean', dim: Optional[Union[int, Tuple[int, ...]]] = None, -) -> torch.Tensor: +) -> Tensor: """ Computes the peak signal-to-noise ratio @@ -83,8 +86,12 @@ def psnr( Return: Tensor with PSNR score - Example: + Raises: + ValueError: + If ``dim`` is not ``None`` and ``data_range`` is not 
provided. + Example: + >>> from torchmetrics.functional import psnr >>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) >>> psnr(pred, target) @@ -102,6 +109,6 @@ def psnr( data_range = target.max() - target.min() else: - data_range = torch.tensor(float(data_range)) + data_range = tensor(float(data_range)) sum_squared_error, n_obs = _psnr_update(preds, target, dim=dim) return _psnr_compute(sum_squared_error, n_obs, data_range, base=base, reduction=reduction) diff --git a/torchmetrics/functional/regression/r2score.py b/torchmetrics/functional/regression/r2score.py index 8b226e076c9..ebafa94b61a 100644 --- a/torchmetrics/functional/regression/r2score.py +++ b/torchmetrics/functional/regression/r2score.py @@ -14,23 +14,21 @@ from typing import Tuple import torch +from torch import Tensor from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.checks import _check_same_shape -def _r2score_update( - preds: torch.tensor, - target: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +def _r2score_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: _check_same_shape(preds, target) if preds.ndim > 2: raise ValueError( 'Expected both prediction and target to be 1D or 2D tensors,' - f' but recevied tensors with dimension {preds.shape}' + f' but received tensors with dimension {preds.shape}' ) if len(preds) < 2: - raise ValueError('Needs atleast two samples to calculate r2 score.') + raise ValueError('Needs at least two samples to calculate r2 score.') sum_error = torch.sum(target, dim=0) sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0) @@ -41,13 +39,13 @@ def _r2score_update( def _r2score_compute( - sum_squared_error: torch.Tensor, - sum_error: torch.Tensor, - residual: torch.Tensor, - total: torch.Tensor, + sum_squared_error: Tensor, + sum_error: Tensor, + residual: Tensor, + total: Tensor, adjusted: int = 0, - multioutput: str = "uniform_average" -) -> torch.Tensor: + multioutput: str = "uniform_average", +) -> Tensor: mean_error = sum_error / total diff = sum_squared_error - sum_error * mean_error raw_scores = 1 - (residual / diff) @@ -71,7 +69,7 @@ def _r2score_compute( if adjusted != 0: if adjusted > total - 1: rank_zero_warn( - "More independent regressions than datapoints in" + "More independent regressions than data points in" " adjusted r2 score. Falls back to standard r2 score.", UserWarning ) elif adjusted == total - 1: @@ -82,11 +80,11 @@ def _r2score_compute( def r2score( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, adjusted: int = 0, multioutput: str = "uniform_average", -) -> torch.Tensor: +) -> Tensor: r""" Computes r2 score also known as `coefficient of determination `_: @@ -114,8 +112,19 @@ def r2score( * ``'uniform_average'`` scores are uniformly averaged * ``'variance_weighted'`` scores are weighted by their individual variances - Example: + Raises: + ValueError: + If both ``preds`` and ``targets`` are not ``1D`` or ``2D`` tensors. + ValueError: + If ``len(preds)`` is less than ``2`` + since at least ``2`` sampels are needed to calculate r2 score. + ValueError: + If ``multioutput`` is not one of ``raw_values``, + ``uniform_average`` or ``variance_weighted``. + ValueError: + If ``adjusted`` is not an ``integer`` greater than ``0``. 
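Looking back at the PSNR changes just above before moving on: `_psnr_compute` works in natural-log space, and with `base=10` it reduces to the familiar closed form `10 * log10(data_range ** 2 / MSE)`, where `data_range` defaults to `target.max() - target.min()` when it is not given. A quick cross-check against the public functional shown in this diff (the by-hand arithmetic is my own, not library code):

```python
import torch
from torchmetrics.functional import psnr

preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])

# Closed form: PSNR = 10 * log10(data_range ** 2 / MSE)
mse = torch.mean((preds - target) ** 2)            # 5.0 for these tensors
data_range = target.max() - target.min()           # 3.0, the default when data_range=None
by_hand = 10 * torch.log10(data_range ** 2 / mse)  # ~2.5527

print(by_hand)
print(psnr(preds, target))  # should agree with the closed form
```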
+ Example: >>> from torchmetrics.functional import r2score >>> target = torch.tensor([3, -0.5, 2, 7]) >>> preds = torch.tensor([2.5, 0.0, 2, 8]) diff --git a/torchmetrics/functional/regression/ssim.py b/torchmetrics/functional/regression/ssim.py index 630e0df490e..91204dc9428 100644 --- a/torchmetrics/functional/regression/ssim.py +++ b/torchmetrics/functional/regression/ssim.py @@ -14,6 +14,7 @@ from typing import Optional, Sequence, Tuple import torch +from torch import Tensor from torch.nn import functional as F from torchmetrics.utilities.checks import _check_same_shape @@ -36,10 +37,7 @@ def _gaussian_kernel( return kernel.expand(channel, 1, kernel_size[0], kernel_size[1]) -def _ssim_update( - preds: torch.Tensor, - target: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: +def _ssim_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: if preds.dtype != target.dtype: raise TypeError( "Expected `preds` and `target` to have the same data type." @@ -55,8 +53,8 @@ def _ssim_update( def _ssim_compute( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, kernel_size: Sequence[int] = (11, 11), sigma: Sequence[float] = (1.5, 1.5), reduction: str = "elementwise_mean", @@ -114,15 +112,15 @@ def _ssim_compute( def ssim( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, kernel_size: Sequence[int] = (11, 11), sigma: Sequence[float] = (1.5, 1.5), reduction: str = "elementwise_mean", data_range: Optional[float] = None, k1: float = 0.01, k2: float = 0.03, -) -> torch.Tensor: +) -> Tensor: """ Computes Structual Similarity Index Measure @@ -144,7 +142,20 @@ def ssim( Return: Tensor with SSIM score + Raises: + TypeError: + If ``preds`` and ``target`` don't have the same data type. + ValueError: + If ``preds`` and ``target`` don't have ``BxCxHxW shape``. + ValueError: + If the length of ``kernel_size`` or ``sigma`` is not ``2``. + ValueError: + If one of the elements of ``kernel_size`` is not an ``odd positive number``. + ValueError: + If one of the elements of ``sigma`` is not a ``positive number``. + Example: + >>> from torchmetrics.functional import ssim >>> preds = torch.rand([16, 1, 16, 16]) >>> target = preds * 0.75 >>> ssim(preds, target) diff --git a/torchmetrics/functional/retrieval/__init__.py b/torchmetrics/functional/retrieval/__init__.py new file mode 100644 index 00000000000..9ddcb8b729d --- /dev/null +++ b/torchmetrics/functional/retrieval/__init__.py @@ -0,0 +1,15 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401 diff --git a/torchmetrics/functional/retrieval/average_precision.py b/torchmetrics/functional/retrieval/average_precision.py new file mode 100644 index 00000000000..43482965325 --- /dev/null +++ b/torchmetrics/functional/retrieval/average_precision.py @@ -0,0 +1,55 @@ +# Copyright The PyTorch Lightning team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import Tensor + + +def retrieval_average_precision(preds: Tensor, target: Tensor) -> Tensor: + r""" + Computes average precision (for information retrieval), as explained + `here `_. + + `preds` and `target` should be of the same shape and live on the same device. If no `target` is ``True``, + 0 is returned. Target must be of type `bool` or `int`, otherwise an error is raised. + + Args: + preds: estimated probabilities of each document to be relevant. + target: ground truth about each document being relevant or not. Requires `bool` or `int` tensor. + + Return: + a single-value tensor with the average precision (AP) of the predictions `preds` wrt the labels `target`. + + Example: + >>> preds = torch.tensor([0.2, 0.3, 0.5]) + >>> target = torch.tensor([True, False, True]) + >>> retrieval_average_precision(preds, target) + tensor(0.8333) + """ + + if preds.shape != target.shape or preds.device != target.device: + raise ValueError("`preds` and `target` must have the same shape and live on the same device") + + if target.dtype not in (torch.bool, torch.int16, torch.int32, torch.int64): + raise ValueError("`target` must be a tensor of booleans or integers") + + if target.dtype is not torch.bool: + target = target.bool() + + if target.sum() == 0: + return torch.tensor(0, device=preds.device) + + target = target[torch.argsort(preds, dim=-1, descending=True)] + positions = torch.arange(1, len(target) + 1, device=target.device, dtype=torch.float32)[target > 0] + res = torch.div((torch.arange(len(positions), device=positions.device, dtype=torch.float32) + 1), positions).mean() + return res diff --git a/torchmetrics/functional/self_supervised.py b/torchmetrics/functional/self_supervised.py index de70ae5335f..5a634ffa74d 100644 --- a/torchmetrics/functional/self_supervised.py +++ b/torchmetrics/functional/self_supervised.py @@ -12,19 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch +from torch import Tensor def embedding_similarity( - batch: torch.Tensor, - similarity: str = 'cosine', - reduction: str = 'none', - zero_diagonal: bool = True -) -> torch.Tensor: + batch: Tensor, similarity: str = 'cosine', reduction: str = 'none', zero_diagonal: bool = True +) -> Tensor: """ Computes representation similarity Example: - + >>> from torchmetrics.functional import embedding_similarity >>> embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]]) >>> embedding_similarity(embeddings) tensor([[0.0000, 1.0000, 0.9759], diff --git a/torchmetrics/info.py b/torchmetrics/info.py new file mode 100644 index 00000000000..f380c8c14ff --- /dev/null +++ b/torchmetrics/info.py @@ -0,0 +1,16 @@ +__version__ = '0.2.1dev' +__author__ = 'PyTorchLightning et al.' +__author_email__ = 'name@pytorchlightning.ai' +__license__ = 'Apache-2.0' +__copyright__ = f'Copyright (c) 2020-2021, {__author__}.' 
+__homepage__ = 'https://github.com/PyTorchLightning/metrics' +__docs__ = "PyTorch native Metrics" +__long_doc__ = """ +Torchmetrics is a metrics API created for easy metric development and usage in both PyTorch and +[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/). It was originally a part of +PyTorch Lightning, but was split off so that users can take advantage of the large collection of implemented +metrics without having to install PyTorch Lightning (even though we would love for you to try it out). +We currently have more than 25 metrics implemented and are continuously adding more, both within +already covered domains (classification, regression, etc.) and in new domains (object detection, etc.). +We make sure that all our metrics are rigorously tested so that you can trust them. +""" diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index bbe7242ebe9..466ed293d0b 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -19,7 +19,7 @@ from typing import Any, Callable, Optional, Union import torch -from torch import nn +from torch import Tensor, nn from torchmetrics.utilities import apply_to_collection from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum @@ -95,7 +95,7 @@ def add_state( name: The name of the state variable. The variable will then be accessible at ``self.name``. default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be reset to this value when ``self.reset()`` is called. - dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode. + dist_reduce_fx (Optional): Function to reduce state across multiple processes in distributed mode. If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``, and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction only makes sense if the state is a list, and not a tensor. The user can also pass a custom @@ -120,9 +120,14 @@ def add_state( When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow the format discussed in the above note. + Raises: + ValueError: + If ``default`` is not a ``tensor`` or an ``empty list``. + ValueError: + If ``dist_reduce_fx`` is not callable or one of ``"mean"``, ``"sum"``, ``"cat"``, ``None``.
""" if ( - not isinstance(default, torch.Tensor) and not isinstance(default, list) # noqa: W503 + not isinstance(default, Tensor) and not isinstance(default, list) # noqa: W503 or (isinstance(default, list) and len(default) != 0) # noqa: W503 ): raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)") @@ -175,19 +180,19 @@ def _sync_dist(self, dist_sync_fn=gather_all_tensors): input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()} output_dict = apply_to_collection( input_dict, - torch.Tensor, + Tensor, dist_sync_fn, group=self.process_group, ) for attr, reduction_fn in self._reductions.items(): # pre-processing ops (stack or flatten for inputs) - if isinstance(output_dict[attr][0], torch.Tensor): + if isinstance(output_dict[attr][0], Tensor): output_dict[attr] = torch.stack(output_dict[attr]) elif isinstance(output_dict[attr][0], list): output_dict[attr] = _flatten(output_dict[attr]) - assert isinstance(reduction_fn, (Callable)) or reduction_fn is None + assert isinstance(reduction_fn, Callable) or reduction_fn is None reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr] setattr(self, attr, reduced) @@ -214,6 +219,7 @@ def wrapped_func(*args, **kwargs): dist_sync_fn = gather_all_tensors synced = False + cache = [] if self._to_sync and dist_sync_fn is not None: # cache prior to syncing cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} @@ -253,7 +259,7 @@ def reset(self): """ for attr, default in self._defaults.items(): current_val = getattr(self, attr) - if isinstance(default, torch.Tensor): + if isinstance(default, Tensor): setattr(self, attr, deepcopy(default).to(current_val.device)) else: setattr(self, attr, deepcopy(default)) @@ -276,20 +282,20 @@ def _apply(self, fn): """Overwrite _apply function such that we can also move metric states to the correct device when `.to`, `.cuda`, etc methods are called """ - self = super()._apply(fn) + this = super()._apply(fn) # Also apply fn to metric states - for key in self._defaults.keys(): - current_val = getattr(self, key) - if isinstance(current_val, torch.Tensor): - setattr(self, key, fn(current_val)) + for key in this._defaults.keys(): + current_val = getattr(this, key) + if isinstance(current_val, Tensor): + setattr(this, key, fn(current_val)) elif isinstance(current_val, Sequence): - setattr(self, key, [fn(cur_v) for cur_v in current_val]) + setattr(this, key, [fn(cur_v) for cur_v in current_val]) else: raise TypeError( - "Expected metric state to be either a torch.Tensor" - f"or a list of torch.Tensor, but encountered {current_val}" + "Expected metric state to be either a Tensor" + f"or a list of Tensor, but encountered {current_val}" ) - return self + return this def persistent(self, mode: bool = False): """Method for post-init to change if metric states should be saved to @@ -336,7 +342,7 @@ def __hash__(self): val = getattr(self, key) # Special case: allow list values, so long # as their elements are hashable - if hasattr(val, '__iter__') and not isinstance(val, torch.Tensor): + if hasattr(val, '__iter__') and not isinstance(val, Tensor): hash_vals.extend(val) else: hash_vals.append(val) @@ -444,18 +450,18 @@ def __pos__(self): return CompositionalMetric(torch.abs, self, None) -def _neg(tensor: torch.Tensor): +def _neg(tensor: Tensor): return -torch.abs(tensor) class CompositionalMetric(Metric): - """Composition of two metrics with a specific operator which will be executed upon metric's compute """ + 
"""Composition of two metrics with a specific operator which will be executed upon metrics compute """ def __init__( self, operator: Callable, - metric_a: Union[Metric, int, float, torch.Tensor], - metric_b: Union[Metric, int, float, torch.Tensor, None], + metric_a: Union[Metric, int, float, Tensor], + metric_b: Union[Metric, int, float, Tensor, None], ): """ Args: @@ -470,12 +476,12 @@ def __init__( self.op = operator - if isinstance(metric_a, torch.Tensor): + if isinstance(metric_a, Tensor): self.register_buffer("metric_a", metric_a) else: self.metric_a = metric_a - if isinstance(metric_b, torch.Tensor): + if isinstance(metric_b, Tensor): self.register_buffer("metric_b", metric_b) else: self.metric_b = metric_b diff --git a/torchmetrics/regression/explained_variance.py b/torchmetrics/regression/explained_variance.py index 350f91c75cc..a76980135dc 100644 --- a/torchmetrics/regression/explained_variance.py +++ b/torchmetrics/regression/explained_variance.py @@ -14,6 +14,7 @@ from typing import Any, Callable, Optional import torch +from torch import Tensor, tensor from torchmetrics.functional.regression.explained_variance import ( _explained_variance_compute, @@ -58,8 +59,11 @@ class ExplainedVariance(Metric): process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) - Example: + Raises: + ValueError: + If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``. + Example: >>> from torchmetrics import ExplainedVariance >>> target = torch.tensor([3, -0.5, 2, 7]) >>> preds = torch.tensor([2.5, 0.0, 2, 8]) @@ -94,13 +98,13 @@ def __init__( f"Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}" ) self.multioutput = multioutput - self.add_state("sum_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("sum_target", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("sum_squared_target", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("n_obs", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("sum_error", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("sum_target", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("sum_squared_target", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("n_obs", default=tensor(0.0), dist_reduce_fx="sum") - def update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: Tensor, target: Tensor): """ Update state with predictions and targets. diff --git a/torchmetrics/regression/mean_absolute_error.py b/torchmetrics/regression/mean_absolute_error.py index 96a1ada06db..470394a52b4 100644 --- a/torchmetrics/regression/mean_absolute_error.py +++ b/torchmetrics/regression/mean_absolute_error.py @@ -14,6 +14,7 @@ from typing import Any, Callable, Optional import torch +from torch import Tensor, tensor from torchmetrics.functional.regression.mean_absolute_error import ( _mean_absolute_error_compute, @@ -40,7 +41,6 @@ class MeanAbsoluteError(Metric): Specify the process group on which synchronization is called. 
default: None (which selects the entire world) Example: - >>> from torchmetrics import MeanAbsoluteError >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) @@ -63,10 +63,10 @@ def __init__( dist_sync_fn=dist_sync_fn, ) - self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("sum_abs_error", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: Tensor, target: Tensor): """ Update state with predictions and targets. diff --git a/torchmetrics/regression/mean_squared_error.py b/torchmetrics/regression/mean_squared_error.py index eae78b7f7ae..4a71639ea35 100644 --- a/torchmetrics/regression/mean_squared_error.py +++ b/torchmetrics/regression/mean_squared_error.py @@ -14,6 +14,7 @@ from typing import Any, Callable, Optional import torch +from torch import Tensor, tensor from torchmetrics.functional.regression.mean_squared_error import ( _mean_squared_error_compute, @@ -40,7 +41,6 @@ class MeanSquaredError(Metric): Specify the process group on which synchronization is called. default: None (which selects the entire world) Example: - >>> from torchmetrics import MeanSquaredError >>> target = torch.tensor([2.5, 5.0, 4.0, 8.0]) >>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0]) @@ -64,10 +64,10 @@ def __init__( dist_sync_fn=dist_sync_fn, ) - self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: Tensor, target: Tensor): """ Update state with predictions and targets. diff --git a/torchmetrics/regression/mean_squared_log_error.py b/torchmetrics/regression/mean_squared_log_error.py index 2adbda0a2c9..9d2ae75f6f5 100644 --- a/torchmetrics/regression/mean_squared_log_error.py +++ b/torchmetrics/regression/mean_squared_log_error.py @@ -14,6 +14,7 @@ from typing import Any, Callable, Optional import torch +from torch import Tensor, tensor from torchmetrics.functional.regression.mean_squared_log_error import ( _mean_squared_log_error_compute, @@ -42,7 +43,6 @@ class MeanSquaredLogError(Metric): Specify the process group on which synchronization is called. default: None (which selects the entire world) Example: - >>> from torchmetrics import MeanSquaredLogError >>> target = torch.tensor([2.5, 5, 4, 8]) >>> preds = torch.tensor([3, 5, 2.5, 7]) @@ -66,10 +66,10 @@ def __init__( dist_sync_fn=dist_sync_fn, ) - self.add_state("sum_squared_log_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("sum_squared_log_error", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: Tensor, target: Tensor): """ Update state with predictions and targets. 
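The regression metrics above all follow the same module-metric recipe: register summed states with `add_state(..., dist_reduce_fx="sum")`, accumulate them in `update`, and combine them in `compute`. The sketch below shows that recipe on a made-up metric; it is an illustration of the pattern only and not part of this diff (mean absolute percentage error is simply a convenient example, and it assumes non-zero targets).

```python
import torch
from torch import Tensor, tensor

from torchmetrics import Metric
from torchmetrics.utilities.checks import _check_same_shape


class MeanAbsolutePercentageError(Metric):
    """Illustrative metric following the add_state / update / compute pattern above."""

    def __init__(self):
        super().__init__()
        # Summed states are synchronized with "sum" across processes,
        # exactly like `sum_abs_error` / `total` in the classes in this diff.
        self.add_state("sum_ape", default=tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor):
        _check_same_shape(preds, target)
        self.sum_ape += torch.sum(torch.abs((preds - target) / target))
        self.total += target.numel()

    def compute(self) -> Tensor:
        return self.sum_ape / self.total


metric = MeanAbsolutePercentageError()
metric(torch.tensor([2.5, 5.0, 4.0]), torch.tensor([2.0, 5.0, 8.0]))
print(metric.compute())  # tensor(0.2500)
```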
diff --git a/torchmetrics/regression/psnr.py b/torchmetrics/regression/psnr.py index 1a9e42f8c7b..238295be681 100644 --- a/torchmetrics/regression/psnr.py +++ b/torchmetrics/regression/psnr.py @@ -14,6 +14,7 @@ from typing import Any, Optional, Sequence, Tuple, Union import torch +from torch import Tensor, tensor from torchmetrics.functional.regression.psnr import _psnr_compute, _psnr_update from torchmetrics.metric import Metric @@ -51,8 +52,11 @@ class PSNR(Metric): process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) - Example: + Raises: + ValueError: + If ``dim`` is not ``None`` and ``data_range`` is not given. + Example: >>> from torchmetrics import PSNR >>> psnr = PSNR() >>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) @@ -82,8 +86,8 @@ def __init__( rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.') if dim is None: - self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") else: self.add_state("sum_squared_error", default=[]) self.add_state("total", default=[]) @@ -95,15 +99,15 @@ def __init__( raise ValueError("The `data_range` must be given when `dim` is not None.") self.data_range = None - self.add_state("min_target", default=torch.tensor(0.0), dist_reduce_fx=torch.min) - self.add_state("max_target", default=torch.tensor(0.0), dist_reduce_fx=torch.max) + self.add_state("min_target", default=tensor(0.0), dist_reduce_fx=torch.min) + self.add_state("max_target", default=tensor(0.0), dist_reduce_fx=torch.max) else: - self.register_buffer("data_range", torch.tensor(float(data_range))) + self.register_buffer("data_range", tensor(float(data_range))) self.base = base self.reduction = reduction self.dim = tuple(dim) if isinstance(dim, Sequence) else dim - def update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: Tensor, target: Tensor): """ Update state with predictions and targets. diff --git a/torchmetrics/regression/r2score.py b/torchmetrics/regression/r2score.py index fb2309f5ee3..c5ee7a8b534 100644 --- a/torchmetrics/regression/r2score.py +++ b/torchmetrics/regression/r2score.py @@ -14,6 +14,7 @@ from typing import Any, Callable, Optional import torch +from torch import Tensor, tensor from torchmetrics.functional.regression.r2score import _r2score_compute, _r2score_update from torchmetrics.metric import Metric @@ -66,8 +67,13 @@ class R2Score(Metric): process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) - Example: + Raises: + ValueError: + If ``adjusted`` parameter is not an integer larger or equal to 0. + ValueError: + If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``. 
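`R2Score` validates the `adjusted` argument above, but the correction itself is applied inside `_r2score_compute`. Assuming the standard adjusted-R² formula `1 - (1 - R²) * (n - 1) / (n - k - 1)`, where `n` is the number of samples and `k` is the number of independent regressors passed as `adjusted` (an assumption on my part; the diff only shows the validation and warnings), the relationship can be checked like this:

```python
import torch
from torchmetrics import R2Score

target = torch.tensor([3.0, -0.5, 2.0, 7.0])
preds = torch.tensor([2.5, 0.0, 2.0, 8.0])

r2 = R2Score()(preds, target)             # plain R^2, tensor(0.9486)
adj = R2Score(adjusted=1)(preds, target)  # adjusted for one regressor

# Assumed standard correction: 1 - (1 - R^2) * (n - 1) / (n - k - 1)
n, k = target.numel(), 1
by_hand = 1 - (1 - r2) * (n - 1) / (n - k - 1)

print(adj, by_hand)  # expected to match if the standard formula is used
```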
+ Example: >>> from torchmetrics import R2Score >>> target = torch.tensor([3, -0.5, 2, 7]) >>> preds = torch.tensor([2.5, 0.0, 2, 8]) @@ -102,7 +108,7 @@ def __init__( self.num_outputs = num_outputs if adjusted < 0 or not isinstance(adjusted, int): - raise ValueError('`adjusted` parameter should be an integer larger or' ' equal to 0.') + raise ValueError('`adjusted` parameter should be an integer larger or equal to 0.') self.adjusted = adjusted allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted') @@ -115,9 +121,9 @@ def __init__( self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: Tensor, target: Tensor): """ Update state with predictions and targets. @@ -132,7 +138,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): self.residual += residual self.total += total - def compute(self) -> torch.Tensor: + def compute(self) -> Tensor: """ Computes r2 score over the metric states. """ diff --git a/torchmetrics/regression/ssim.py b/torchmetrics/regression/ssim.py index 78b93c6fe7e..886da1abf3a 100644 --- a/torchmetrics/regression/ssim.py +++ b/torchmetrics/regression/ssim.py @@ -14,6 +14,7 @@ from typing import Any, Optional, Sequence import torch +from torch import Tensor from torchmetrics.functional.regression.ssim import _ssim_compute, _ssim_update from torchmetrics.metric import Metric @@ -82,7 +83,7 @@ def __init__( self.k2 = k2 self.reduction = reduction - def update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: Tensor, target: Tensor): """ Update state with predictions and targets. diff --git a/torchmetrics/retrieval/__init__.py b/torchmetrics/retrieval/__init__.py new file mode 100644 index 00000000000..9db84787f6f --- /dev/null +++ b/torchmetrics/retrieval/__init__.py @@ -0,0 +1,15 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torchmetrics.retrieval.mean_average_precision import RetrievalMAP # noqa: F401 +from torchmetrics.retrieval.retrieval_metric import RetrievalMetric # noqa: F401 diff --git a/torchmetrics/retrieval/mean_average_precision.py b/torchmetrics/retrieval/mean_average_precision.py new file mode 100644 index 00000000000..383bea70a41 --- /dev/null +++ b/torchmetrics/retrieval/mean_average_precision.py @@ -0,0 +1,75 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import Tensor + +from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision +from torchmetrics.retrieval.retrieval_metric import RetrievalMetric + + +class RetrievalMAP(RetrievalMetric): + r""" + Computes `Mean Average Precision + `_. + + Works with binary data. Accepts integer or float predictions from a model output. + + Forward accepts + - ``indexes`` (long tensor): ``(N, ...)`` + - ``preds`` (float tensor): ``(N, ...)`` + - ``target`` (long or bool tensor): ``(N, ...)`` + + `indexes`, `preds` and `target` must have the same dimension. + `indexes` indicate to which query a prediction belongs. + Predictions will be first grouped by indexes and then MAP will be computed as the mean + of the Average Precisions over each query. + + Args: + query_without_relevant_docs: + Specify what to do with queries that do not have at least a positive target. Choose from: + + - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned + - ``'error'``: raise a ``ValueError`` + - ``'pos'``: score on those queries is counted as ``1.0`` + - ``'neg'``: score on those queries is counted as ``0.0`` + exclude: + Do not take into account predictions where the target is equal to this value. default `-100` + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects + the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. default: None + + Example: + >>> from torchmetrics import RetrievalMAP + >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1]) + >>> preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) + >>> target = torch.tensor([False, False, True, False, True, False, False]) + + >>> map = RetrievalMAP() + >>> map(indexes, preds, target) + tensor(0.7500) + >>> map.compute() + tensor(0.7500) + """ + + def _metric(self, preds: Tensor, target: Tensor) -> Tensor: + valid_indexes = target != self.exclude + return retrieval_average_precision(preds[valid_indexes], target[valid_indexes]) diff --git a/torchmetrics/retrieval/retrieval_metric.py b/torchmetrics/retrieval/retrieval_metric.py new file mode 100644 index 00000000000..3937343e7ca --- /dev/null +++ b/torchmetrics/retrieval/retrieval_metric.py @@ -0,0 +1,153 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional + +import torch +from torch import Tensor + +from torchmetrics import Metric +from torchmetrics.utilities.data import get_group_indexes + +#: get_group_indexes is used to group predictions belonging to the same query +IGNORE_IDX = -100 + + +class RetrievalMetric(Metric, ABC): + r""" + Works with binary data. Accepts integer or float predictions from a model output. + + Forward accepts + - ``indexes`` (long tensor): ``(N, ...)`` + - ``preds`` (float or int tensor): ``(N, ...)`` + - ``target`` (long or bool tensor): ``(N, ...)`` + + `indexes`, `preds` and `target` must have the same dimension and will be flatten + to single dimension once provided. + + `indexes` indicate to which query a prediction belongs. + Predictions will be first grouped by indexes. Then the + real metric, defined by overriding the `_metric` method, + will be computed as the mean of the scores over each query. + + Args: + query_without_relevant_docs: + Specify what to do with queries that do not have at least a positive target. Choose from: + + - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned + - ``'error'``: raise a ``ValueError`` + - ``'pos'``: score on those queries is counted as ``1.0`` + - ``'neg'``: score on those queries is counted as ``0.0`` + exclude: + Do not take into account predictions where the target is equal to this value. default `-100` + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects + the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. default: None + + """ + + def __init__( + self, + query_without_relevant_docs: str = 'skip', + exclude: int = IGNORE_IDX, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn + ) + + query_without_relevant_docs_options = ('error', 'skip', 'pos', 'neg') + if query_without_relevant_docs not in query_without_relevant_docs_options: + raise ValueError( + f"`query_without_relevant_docs` received a wrong value {query_without_relevant_docs}. 
" + f"Allowed values are {query_without_relevant_docs_options}" + ) + + self.query_without_relevant_docs = query_without_relevant_docs + self.exclude = exclude + + self.add_state("idx", default=[], dist_reduce_fx=None) + self.add_state("preds", default=[], dist_reduce_fx=None) + self.add_state("target", default=[], dist_reduce_fx=None) + + def update(self, idx: Tensor, preds: Tensor, target: Tensor) -> None: + if not (idx.shape == target.shape == preds.shape): + raise ValueError("`idx`, `preds` and `target` must be of the same shape") + + idx = idx.to(dtype=torch.int64).flatten() + preds = preds.to(dtype=torch.float32).flatten() + target = target.to(dtype=torch.int64).flatten() + + self.idx.append(idx) + self.preds.append(preds) + self.target.append(target) + + def compute(self) -> Tensor: + r""" + First concat state `idx`, `preds` and `target` since they were stored as lists. After that, + compute list of groups that will help in keeping together predictions about the same query. + Finally, for each group compute the `_metric` if the number of positive targets is at least + 1, otherwise behave as specified by `self.query_without_relevant_docs`. + """ + + idx = torch.cat(self.idx, dim=0) + preds = torch.cat(self.preds, dim=0) + target = torch.cat(self.target, dim=0) + + res = [] + kwargs = {'device': idx.device, 'dtype': torch.float32} + + groups = get_group_indexes(idx) + for group in groups: + + mini_preds = preds[group] + mini_target = target[group] + + if not mini_target.sum(): + if self.query_without_relevant_docs == 'error': + raise ValueError( + f"`{self.__class__.__name__}.compute()` was provided with " + f"a query without positive targets, indexes: {group}" + ) + if self.query_without_relevant_docs == 'pos': + res.append(torch.tensor(1.0, **kwargs)) + elif self.query_without_relevant_docs == 'neg': + res.append(torch.tensor(0.0, **kwargs)) + else: + res.append(self._metric(mini_preds, mini_target)) + + if len(res) > 0: + return torch.stack(res).mean() + return torch.tensor(0.0, **kwargs) + + @abstractmethod + def _metric(self, preds: Tensor, target: Tensor) -> Tensor: + r""" + Compute a metric over a predictions and target of a single group. + This method should be overridden by subclasses. + """ diff --git a/torchmetrics/setup_tools.py b/torchmetrics/setup_tools.py index dfde2bc6e42..60e0bf22fa5 100644 --- a/torchmetrics/setup_tools.py +++ b/torchmetrics/setup_tools.py @@ -15,14 +15,14 @@ import re from typing import List -from torchmetrics import _PROJECT_ROOT, __homepage__, __version__ +_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__)) def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comment_char: str = '#') -> List[str]: """Load requirements from a file >>> _load_requirements(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - ['torch...'] + ['numpy...', 'torch...'] """ with open(os.path.join(path_dir, file_name), 'r') as file: lines = [ln.strip() for ln in file.readlines()] @@ -39,10 +39,10 @@ def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comme return reqs -def _load_readme_description(path_dir: str, homepage: str = __homepage__, version: str = __version__) -> str: +def _load_readme_description(path_dir: str, homepage: str, version: str) -> str: """Load readme as decribtion - >>> _load_readme_description(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> _load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE '
...' """ path_readme = os.path.join(path_dir, "README.md") diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index 32dd1ddfb53..78b1e972d9e 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -14,18 +14,19 @@ from typing import Optional, Tuple import torch +from torch import Tensor from torchmetrics.utilities.data import select_topk, to_onehot from torchmetrics.utilities.enums import DataType -def _check_same_shape(pred: torch.Tensor, target: torch.Tensor): +def _check_same_shape(pred: Tensor, target: Tensor): """ Check that predictions and target have the same shape, else raise error """ if pred.shape != target.shape: raise RuntimeError("Predictions and targets are expected to have the same shape") -def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool): +def _basic_input_validation(preds: Tensor, target: Tensor, threshold: float, is_multiclass: bool): """ Perform basic validation of inputs that does not require deducing any information of the type of inputs. @@ -56,7 +57,7 @@ def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") -def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) -> Tuple[str, int]: +def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> Tuple[str, int]: """ This checks that the shape and type of inputs are consistent with each other and fall into one of the allowed input types (see the @@ -139,9 +140,7 @@ def _check_num_classes_binary(num_classes: int, is_multiclass: bool): ) -def _check_num_classes_mc( - preds: torch.Tensor, target: torch.Tensor, num_classes: int, is_multiclass: bool, implied_classes: int -): +def _check_num_classes_mc(preds: Tensor, target: Tensor, num_classes: int, is_multiclass: bool, implied_classes: int): """ This checks that the consistency of `num_classes` with the data and `is_multiclass` param for (multi-dimensional) multi-class data. @@ -206,8 +205,8 @@ def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Opt def _check_classification_inputs( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, threshold: float, num_classes: Optional[int], is_multiclass: bool, @@ -222,8 +221,8 @@ def _check_classification_inputs( In case where preds are floats (probabilities), it is checked whether they are in [0,1] interval. - When ``num_classes`` is given, it is checked that it is consitent with input cases (binary, - multi-label, ...), and that, if availible, the implied number of classes in the ``C`` + When ``num_classes`` is given, it is checked that it is consistent with input cases (binary, + multi-label, ...), and that, if available, the implied number of classes in the ``C`` dimension is consistent with it (as well as that max label in target is smaller than it). When ``num_classes`` is not specified in these cases, consistency of the highest target @@ -243,13 +242,13 @@ def _check_classification_inputs( Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. num_classes: - Number of classes. If not explicitly set, the number of classes will be infered + Number of classes. 
If not explicitly set, the number of classes will be inferred either from the shape of inputs, or the maximum label in the ``target`` and ``preds`` tensor, where applicable. top_k: Number of highest probability entries for each sample to convert to 1s - relevant only for inputs with probability predictions. The default value (``None``) will be - interepreted as 1 for these inputs. If this parameter is set for multi-label inputs, + interpreted as 1 for these inputs. If this parameter is set for multi-label inputs, it will take precedence over threshold. Should be left unset (``None``) for inputs with label predictions. @@ -265,7 +264,7 @@ def _check_classification_inputs( 'multi-dim multi-class' """ - # Baisc validation (that does not need case/type information) + # Basic validation (that does not need case/type information) _basic_input_validation(preds, target, threshold, is_multiclass) # Check that shape/types fall into one of the cases @@ -274,7 +273,7 @@ def _check_classification_inputs( # For (multi-dim) multi-class case with prob preds, check that preds sum up to 1 if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and preds.is_floating_point(): if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all(): - raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.") + raise ValueError("Probabilities in `preds` must sum up to 1 across the `C` dimension.") # Check consistency with the `C` dimension in case of multi-class data if preds.shape != target.shape: @@ -305,13 +304,13 @@ def _check_classification_inputs( def _input_format_classification( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, threshold: float = 0.5, top_k: Optional[int] = None, num_classes: Optional[int] = None, is_multiclass: Optional[bool] = None, -) -> Tuple[torch.Tensor, torch.Tensor, str]: +) -> Tuple[Tensor, Tensor, str]: """Convert preds and target tensors into common format. Preds and targets are supposed to fall into one of these categories (and are @@ -371,7 +370,7 @@ def _input_format_classification( Threshold probability value for transforming probability predictions to binary (0 or 1) predictions, in the case of binary or multi-label inputs. num_classes: - Number of classes. If not explicitly set, the number of classes will be infered + Number of classes. If not explicitly set, the number of classes will be inferred either from the shape of inputs, or the maximum label in the ``target`` and ``preds`` tensor, where applicable. 
top_k: @@ -439,7 +438,7 @@ def _input_format_classification( target = target.reshape(target.shape[0], -1) preds = preds.reshape(preds.shape[0], -1) - # Some operatins above create an extra dimension for MC/binary case - this removes it + # Some operations above create an extra dimension for MC/binary case - this removes it if preds.ndim > 2: preds, target = preds.squeeze(-1), target.squeeze(-1) @@ -448,11 +447,11 @@ def _input_format_classification( def _input_format_classification_one_hot( num_classes: int, - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, threshold: float = 0.5, multilabel: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[Tensor, Tensor]: """Convert preds and target tensors into one hot spare label tensors Args: @@ -462,6 +461,11 @@ def _input_format_classification_one_hot( threshold: float used for thresholding multilabel input multilabel: boolean flag indicating if input is multilabel + Raises: + ValueError: + If ``preds`` and ``target`` don't have the same number of dimensions + or one additional dimension for ``preds``. + Returns: preds: one hot tensor of shape [num_classes, -1] with predicted labels target: one hot tensors of shape [num_classes, -1] with true labels @@ -470,7 +474,7 @@ def _input_format_classification_one_hot( raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") if preds.ndim == target.ndim + 1: - # multi class probabilites + # multi class probabilities preds = torch.argmax(preds, dim=1) if preds.ndim == target.ndim and preds.dtype in (torch.long, torch.int) and num_classes > 1 and not multilabel: @@ -479,7 +483,7 @@ def _input_format_classification_one_hot( target = to_onehot(target, num_classes=num_classes) elif preds.ndim == target.ndim and preds.is_floating_point(): - # binary or multilabel probablities + # binary or multilabel probabilities preds = (preds >= threshold).long() # transpose class as first dim and reshape diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 2ca60beb366..32dfd01e44f 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Mapping, Optional, Sequence, Union +from typing import Any, Callable, List, Mapping, Optional, Sequence, Union import torch +from torch import Tensor from torchmetrics.utilities.prints import rank_zero_warn @@ -38,9 +39,9 @@ def _flatten(x): def to_onehot( - label_tensor: torch.Tensor, + label_tensor: Tensor, num_classes: Optional[int] = None, -) -> torch.Tensor: +) -> Tensor: """ Converts a dense label tensor to one-hot format @@ -48,11 +49,10 @@ def to_onehot( label_tensor: dense label tensor, with shape [N, d1, d2, ...] num_classes: number of classes C - Output: + Returns: A sparse label tensor with shape [N, C, d1, d2, ...] Example: - >>> x = torch.tensor([1, 2, 3]) >>> to_onehot(x) tensor([[0, 1, 0, 0], @@ -74,7 +74,7 @@ def to_onehot( return tensor_onehot.scatter_(1, index, 1.0) -def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: +def select_topk(prob_tensor: Tensor, topk: int = 1, dim: int = 1) -> Tensor: """ Convert a probability tensor to binary by selecting top-k highest entries. 
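The `top_k` / `threshold` precedence spelled out in the docstring above is easiest to see on a toy tensor. A minimal sketch, not part of the patch, assuming `select_topk` is imported from `torchmetrics.utilities.data` as refactored here; the probabilities are invented for illustration:

```python
import torch
from torchmetrics.utilities.data import select_topk

# Invented multi-label probability predictions: 2 samples, 3 classes.
preds = torch.tensor([[0.9, 0.6, 0.2],
                      [0.1, 0.3, 0.8]])

# Thresholding keeps every entry above the cut-off, so the number of
# positives per sample can vary (2 for the first sample, 1 for the second):
(preds >= 0.5).int()        # tensor([[1, 1, 0], [0, 0, 1]], dtype=torch.int32)

# select_topk keeps exactly `topk` entries per sample, regardless of the cut-off:
select_topk(preds, topk=2)  # tensor([[1, 1, 0], [0, 1, 1]], dtype=torch.int32)
```

In other words, setting `top_k` for multi-label probability inputs fixes the number of predicted positives per sample, which is why it takes precedence over `threshold`.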
@@ -84,11 +84,10 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch topk: number of highest entries to turn into 1s dim: dimension on which to compare entries - Output: + Returns: A binary tensor of the same shape as the input tensor of type torch.int32 Example: - >>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]]) >>> select_topk(x, topk=2) tensor([[0, 1, 1], @@ -99,7 +98,7 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch return topk_tensor.int() -def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: +def to_categorical(tensor: Tensor, argmax_dim: int = 1) -> Tensor: """ Converts a tensor of probabilities to a dense label tensor @@ -111,7 +110,6 @@ def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: A tensor with categorical labels [N, d2, ...] Example: - >>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) >>> to_categorical(x) tensor([1, 0]) @@ -121,15 +119,15 @@ def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: def get_num_classes( - pred: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, num_classes: Optional[int] = None, ) -> int: """ Calculates the number of classes for a given prediction and target tensor. Args: - pred: predicted values + preds: predicted values target: true labels num_classes: number of classes if known @@ -137,7 +135,7 @@ def get_num_classes( An integer that represents the number of classes. """ num_target_classes = int(target.max().detach().item() + 1) - num_pred_classes = int(pred.max().detach().item() + 1) + num_pred_classes = int(preds.max().detach().item() + 1) num_all_classes = max(num_target_classes, num_pred_classes) if num_classes is None: @@ -159,23 +157,26 @@ def _stable_1d_sort(x: torch, nb: int = 2049): makes the sort and returns the sorted array (with the padding removed) See this discussion: https://discuss.pytorch.org/t/is-torch-sort-stable/20714 + Raises: + ValueError: + If dim of ``x`` is greater than 1 since stable sort works with only 1d tensors. 
+ Example: >>> data = torch.tensor([8, 7, 2, 6, 4, 5, 3, 1, 9, 0]) >>> _stable_1d_sort(data) (tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), tensor([9, 7, 2, 6, 4, 5, 3, 1, 0, 8])) >>> _stable_1d_sort(data, nb=5) - (tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), tensor([9, 7, 2, 6, 4, 5, 3, 1, 0, 8])) + (tensor([0, 1, 2, 3, 4]), tensor([9, 7, 2, 6, 4])) """ if x.ndim > 1: raise ValueError('Stable sort only works on 1d tensors') n = x.numel() if n < nb: x_max = x.max() - x_pad = torch.cat([x, (x_max + 1) * torch.ones(2049 - n, dtype=x.dtype, device=x.device)], 0) - x_sort = x_pad.sort() - else: - x_sort = x.sort() - return x_sort.values[:n], x_sort.indices[:n] + x = torch.cat([x, (x_max + 1) * torch.ones(nb - n, dtype=x.dtype, device=x.device)], 0) + x_sort = x.sort() + i = min(nb, n) + return x_sort.values[:i], x_sort.indices[:i] def apply_to_collection( @@ -202,7 +203,7 @@ def apply_to_collection( the resulting collection Example: - >>> apply_to_collection(torch.tensor([8, 0, 2, 6, 7]), dtype=torch.Tensor, function=lambda x: x ** 2) + >>> apply_to_collection(torch.tensor([8, 0, 2, 6, 7]), dtype=Tensor, function=lambda x: x ** 2) tensor([64, 0, 4, 36, 49]) >>> apply_to_collection([8, 0, 2, 6, 7], dtype=int, function=lambda x: x ** 2) [64, 0, 4, 36, 49] @@ -227,3 +228,32 @@ def apply_to_collection( # data is neither of dtype, nor a collection return data + + +def get_group_indexes(idx: Tensor) -> List[Tensor]: + """ + Given an integer `torch.Tensor` `idx`, return a `torch.Tensor` of indexes for + each different value in `idx`. + + Args: + idx: a `torch.Tensor` of integers + + Return: + A list of integer `torch.Tensor`s + + Example: + + >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1]) + >>> groups = get_group_indexes(indexes) + >>> groups + [tensor([0, 1, 2]), tensor([3, 4, 5, 6])] + """ + + indexes = dict() + for i, _id in enumerate(idx): + _id = _id.item() + if _id in indexes: + indexes[_id] += [i] + else: + indexes[_id] = [i] + return [torch.tensor(x, dtype=torch.int64) for x in indexes.values()] diff --git a/torchmetrics/utilities/distributed.py b/torchmetrics/utilities/distributed.py index 76cdc589e6c..34180de2d28 100644 --- a/torchmetrics/utilities/distributed.py +++ b/torchmetrics/utilities/distributed.py @@ -14,9 +14,10 @@ from typing import Any, Optional, Union import torch +from torch import Tensor -def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: +def reduce(to_reduce: Tensor, reduction: str) -> Tensor: """ Reduces a given tensor by a given reduction method @@ -39,9 +40,7 @@ def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: raise ValueError("Reduction parameter unknown.") -def class_reduce( - num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none" -) -> torch.Tensor: +def class_reduce(num: Tensor, denom: Tensor, weights: Tensor, class_reduction: str = "none") -> Tensor: """ Function used to reduce classification metrics of the form `num / denom * weights`. For example for calculating standard accuracy the num would be number of @@ -59,6 +58,10 @@ def class_reduce( - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - ``'none'`` or ``None``: returns calculated metric per class + Raises: + ValueError: + If ``class_reduction`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"`` or ``None``. 
+ """ valid_reduction = ("micro", "macro", "weighted", "none", None) if class_reduction == "micro": @@ -85,7 +88,7 @@ def class_reduce( ) -def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None): +def gather_all_tensors(result: Union[Tensor], group: Optional[Any] = None): """ Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes diff --git a/torchmetrics/utilities/enums.py b/torchmetrics/utilities/enums.py index 41261deff17..9c865bc44f6 100644 --- a/torchmetrics/utilities/enums.py +++ b/torchmetrics/utilities/enums.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from enum import Enum -from typing import Union +from typing import Optional, Union class EnumStr(str, Enum): @@ -28,7 +28,7 @@ class EnumStr(str, Enum): """ @classmethod - def from_str(cls, value: str) -> 'EnumStr': + def from_str(cls, value: str) -> Optional['EnumStr']: statuses = [status for status in dir(cls) if not status.startswith('_')] for st in statuses: if st.lower() == value.lower(): @@ -63,7 +63,9 @@ class AverageMethod(EnumStr): >>> None in list(AverageMethod) True - >>> 'none' == AverageMethod.NONE == None + >>> AverageMethod.NONE == None + True + >>> AverageMethod.NONE == 'none' True """ diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py new file mode 100644 index 00000000000..9aa15dc8e82 --- /dev/null +++ b/torchmetrics/utilities/imports.py @@ -0,0 +1,7 @@ +from distutils.version import LooseVersion + +import torch + +_TORCH_LOWER_1_4 = LooseVersion(torch.__version__) < LooseVersion("1.4.0") +_TORCH_LOWER_1_5 = LooseVersion(torch.__version__) < LooseVersion("1.5.0") +_TORCH_LOWER_1_6 = LooseVersion(torch.__version__) < LooseVersion("1.6.0")
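For context on the retrieval base class added earlier in this patch: its `compute()` boils down to grouping the flattened predictions by query index with `get_group_indexes` and averaging a per-query score. A minimal sketch, not part of the patch; `toy_metric` is a made-up stand-in for the abstract `_metric` and the tensors are invented:

```python
import torch
from torchmetrics.utilities.data import get_group_indexes

def toy_metric(preds, target):
    # Made-up per-query score: relevance of the highest-scored document.
    return target[preds.argmax()].float()

# Seven documents scored against two queries (query ids 0 and 1).
idx = torch.tensor([0, 0, 0, 1, 1, 1, 1])
preds = torch.tensor([0.2, 0.9, 0.4, 0.1, 0.7, 0.6, 0.8])
target = torch.tensor([0, 1, 0, 1, 0, 0, 1])

groups = get_group_indexes(idx)    # [tensor([0, 1, 2]), tensor([3, 4, 5, 6])]
scores = [toy_metric(preds[g], target[g]) for g in groups]
torch.stack(scores).mean()         # mean over queries, as in `compute()`
```

A query whose group contains no positive targets is the case handled by `query_without_relevant_docs` (e.g. `'error'`, `'pos'` or `'neg'`) in the base class above.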
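The `class_reduce` helper touched in `torchmetrics/utilities/distributed.py` reduces per-class statistics of the form `num / denom * weights`, and the new `Raises` block documents the existing validation of `class_reduction`. A small sketch of the reduction modes with invented per-class counts (outputs shown approximately):

```python
import torch
from torchmetrics.utilities.distributed import class_reduce

# Invented statistics for a 3-class problem.
num = torch.tensor([3., 2., 1.])      # e.g. correct predictions per class
denom = torch.tensor([4., 4., 2.])    # e.g. support per class
weights = torch.tensor([4., 4., 2.])  # weights used by the 'weighted' mode

class_reduce(num, denom, weights, class_reduction="none")      # per class, ~[0.75, 0.50, 0.50]
class_reduce(num, denom, weights, class_reduction="macro")     # plain mean, ~0.583
class_reduce(num, denom, weights, class_reduction="micro")     # sum(num) / sum(denom), ~0.600
class_reduce(num, denom, weights, class_reduction="weighted")  # weighted mean, ~0.600
# Anything else, e.g. class_reduction="sum", raises the ValueError documented above.
```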
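The new constants in `torchmetrics/utilities/imports.py` are plain booleans evaluated once at import time, so callers can branch on the installed torch version. A hypothetical usage sketch, not part of the patch; the `searchsorted` fallback is invented for illustration (`torch.searchsorted` only exists from torch 1.6 on):

```python
import torch
from torchmetrics.utilities.imports import _TORCH_LOWER_1_6

def searchsorted_compat(sorted_seq: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    if _TORCH_LOWER_1_6:
        # Fallback for torch < 1.6: the left insertion index equals the number
        # of elements strictly smaller than each value.
        return (sorted_seq.unsqueeze(0) < values.unsqueeze(-1)).sum(dim=-1)
    return torch.searchsorted(sorted_seq, values)
```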