From 7fa5f75273dbcf2f691d7b716401a7b40ce50f3d Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 18 Aug 2021 09:51:52 +0200 Subject: [PATCH 01/31] import --- tests/classification/test_accuracy.py | 2 +- .../classification/test_calibration_error.py | 2 +- tests/classification/test_f_beta.py | 3 +- tests/classification/test_hamming_distance.py | 2 +- tests/classification/test_hinge.py | 2 +- tests/classification/test_precision_recall.py | 3 +- tests/classification/test_specificity.py | 3 +- tests/classification/test_stat_scores.py | 2 +- tests/wrappers/test_tracker.py | 3 +- torchmetrics/__init__.py | 122 +++--------------- 10 files changed, 31 insertions(+), 113 deletions(-) diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py index 0ff54b020c7..dd2da650a6a 100644 --- a/tests/classification/test_accuracy.py +++ b/tests/classification/test_accuracy.py @@ -33,7 +33,7 @@ from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester -from torchmetrics import Accuracy +from torchmetrics.classification import Accuracy from torchmetrics.functional import accuracy from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import AverageMethod, DataType diff --git a/tests/classification/test_calibration_error.py b/tests/classification/test_calibration_error.py index 68822b863f8..4e6405d4e1c 100644 --- a/tests/classification/test_calibration_error.py +++ b/tests/classification/test_calibration_error.py @@ -13,7 +13,7 @@ # TODO: replace this with official sklearn implementation after next sklearn release from tests.helpers.non_sklearn_metrics import calibration_error as sk_calib from tests.helpers.testers import THRESHOLD, MetricTester -from torchmetrics import CalibrationError +from torchmetrics.classification import CalibrationError from torchmetrics.functional import calibration_error from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import DataType diff --git a/tests/classification/test_f_beta.py b/tests/classification/test_f_beta.py index e453057b8ed..628b73582ae 100644 --- a/tests/classification/test_f_beta.py +++ b/tests/classification/test_f_beta.py @@ -32,7 +32,8 @@ from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester -from torchmetrics import F1, FBeta, Metric +from torchmetrics import Metric +from torchmetrics.classification import F1, FBeta from torchmetrics.functional import f1, fbeta from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import AverageMethod diff --git a/tests/classification/test_hamming_distance.py b/tests/classification/test_hamming_distance.py index eeac1bce8c9..32f17b90839 100644 --- a/tests/classification/test_hamming_distance.py +++ b/tests/classification/test_hamming_distance.py @@ -27,7 +27,7 @@ from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all from tests.helpers.testers import THRESHOLD, MetricTester -from torchmetrics import HammingDistance +from torchmetrics.classification import HammingDistance from torchmetrics.functional import hamming_distance from torchmetrics.utilities.checks import _input_format_classification diff --git a/tests/classification/test_hinge.py b/tests/classification/test_hinge.py index 07e9f81de9c..20d04b950c3 100644 --- a/tests/classification/test_hinge.py +++ b/tests/classification/test_hinge.py @@ -21,7 +21,7 @@ from tests.classification.inputs import Input from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester -from torchmetrics import Hinge +from torchmetrics.classification import Hinge from torchmetrics.functional import hinge from torchmetrics.functional.classification.hinge import MulticlassMode diff --git a/tests/classification/test_precision_recall.py b/tests/classification/test_precision_recall.py index 5f98730b248..4314c93cb9b 100644 --- a/tests/classification/test_precision_recall.py +++ b/tests/classification/test_precision_recall.py @@ -32,7 +32,8 @@ from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester -from torchmetrics import Metric, Precision, Recall +from torchmetrics import Metric +from torchmetrics.classification import Precision, Recall from torchmetrics.functional import precision, precision_recall, recall from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import AverageMethod diff --git a/tests/classification/test_specificity.py b/tests/classification/test_specificity.py index ada49edb332..da5e20fbcde 100644 --- a/tests/classification/test_specificity.py +++ b/tests/classification/test_specificity.py @@ -30,7 +30,8 @@ from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester -from torchmetrics import Metric, Specificity +from torchmetrics import Metric +from torchmetrics.classification import Specificity from torchmetrics.functional import specificity from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores from torchmetrics.utilities.checks import _input_format_classification diff --git a/tests/classification/test_stat_scores.py b/tests/classification/test_stat_scores.py index 5a550db19cd..cb472e73e1b 100644 --- a/tests/classification/test_stat_scores.py +++ b/tests/classification/test_stat_scores.py @@ -30,7 +30,7 @@ from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all from tests.helpers.testers import NUM_CLASSES, MetricTester -from torchmetrics import StatScores +from torchmetrics.classification import StatScores from torchmetrics.functional import stat_scores from torchmetrics.utilities.checks import _input_format_classification diff --git a/tests/wrappers/test_tracker.py b/tests/wrappers/test_tracker.py index ce3f977811c..cab44a3a823 100644 --- a/tests/wrappers/test_tracker.py +++ b/tests/wrappers/test_tracker.py @@ -17,7 +17,8 @@ import torch from tests.helpers import seed_all -from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError, Precision, Recall +from torchmetrics.classification import Accuracy, Precision, Recall +from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError from torchmetrics.wrappers import MetricTracker seed_all(42) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 9f9dd038235..5c7254bff3a 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -11,117 +11,31 @@ _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) -from torchmetrics import functional # noqa: E402 -from torchmetrics.audio import PIT, SI_SDR, SI_SNR, SNR # noqa: E402 from torchmetrics.average import AverageMeter # noqa: E402 -from torchmetrics.classification import ( # noqa: E402 - AUC, - AUROC, - F1, - ROC, - Accuracy, - AveragePrecision, - BinnedAveragePrecision, - BinnedPrecisionRecallCurve, - BinnedRecallAtFixedPrecision, - CalibrationError, - CohenKappa, - ConfusionMatrix, - FBeta, - HammingDistance, - Hinge, - IoU, - KLDivergence, - MatthewsCorrcoef, - Precision, - PrecisionRecallCurve, - Recall, - Specificity, - StatScores, -) from torchmetrics.collections import MetricCollection # noqa: E402 -from torchmetrics.image import FID, IS, KID, PSNR, SSIM # noqa: E402 from torchmetrics.metric import Metric # noqa: E402 -from torchmetrics.regression import ( # noqa: E402 - CosineSimilarity, - ExplainedVariance, - MeanAbsoluteError, - MeanAbsolutePercentageError, - MeanSquaredError, - MeanSquaredLogError, - PearsonCorrcoef, - R2Score, - SpearmanCorrcoef, - SymmetricMeanAbsolutePercentageError, -) -from torchmetrics.retrieval import ( # noqa: E402 - RetrievalFallOut, - RetrievalMAP, - RetrievalMRR, - RetrievalNormalizedDCG, - RetrievalPrecision, - RetrievalRecall, -) -from torchmetrics.text import WER, BERTScore, BLEUScore, ROUGEScore # noqa: E402 -from torchmetrics.wrappers import BootStrapper, MetricTracker # noqa: E402 + +from torchmetrics import ( + audio, + classification, + functional, + image, + regression, + retrieval, + text, + wrappers +) # noqa: E402 __all__ = [ + "audio", + "classification", "functional", - "Accuracy", - "AUC", - "AUROC", + "image", + "regression", + "retrieval", + "text", + "wrappers", "AverageMeter", - "AveragePrecision", - "BinnedAveragePrecision", - "BinnedPrecisionRecallCurve", - "BinnedRecallAtFixedPrecision", - "BERTScore", - "BLEUScore", - "BootStrapper", - "CalibrationError", - "CohenKappa", - "ConfusionMatrix", - "CosineSimilarity", - "ExplainedVariance", - "F1", - "FBeta", - "FID", - "HammingDistance", - "Hinge", - "IoU", - "IS", - "KID", - "KLDivergence", - "MatthewsCorrcoef", - "MeanAbsoluteError", - "MeanAbsolutePercentageError", - "MeanSquaredError", - "MeanSquaredLogError", "Metric", "MetricCollection", - "MetricTracker", - "PearsonCorrcoef", - "PIT", - "Precision", - "PrecisionRecallCurve", - "PSNR", - "R2Score", - "Recall", - "RetrievalFallOut", - "RetrievalMAP", - "RetrievalMRR", - "RetrievalNormalizedDCG", - "RetrievalPrecision", - "RetrievalRecall", - "ROC", - "ROUGEScore", - "SI_SDR", - "SI_SNR", - "SNR", - "SpearmanCorrcoef", - "Specificity", - "SSIM", - "StatScores", - "SymmetricMeanAbsolutePercentageError", - "WER", ] From 0fcab09513d76daa582cc05ee7c7e6035fb18a5d Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 18 Aug 2021 17:12:14 +0200 Subject: [PATCH 02/31] remove from init --- torchmetrics/__init__.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 5c7254bff3a..83f4509d4e9 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -15,26 +15,8 @@ from torchmetrics.collections import MetricCollection # noqa: E402 from torchmetrics.metric import Metric # noqa: E402 -from torchmetrics import ( - audio, - classification, - functional, - image, - regression, - retrieval, - text, - wrappers -) # noqa: E402 __all__ = [ - "audio", - "classification", - "functional", - "image", - "regression", - "retrieval", - "text", - "wrappers", "AverageMeter", "Metric", "MetricCollection", From e06301d8aeac8b70f5ee169bb8c286e67098c996 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 18 Aug 2021 17:13:53 +0200 Subject: [PATCH 03/31] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a4ed7a56d0..7f898ed9296 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Change import structure from root to submodule level ([#459](https://github.com/PyTorchLightning/metrics/issues/459)) + ### Deprecated From 9f39c6d99e34a507798e74ad0922f05df7617b32 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 19 Aug 2021 10:08:56 +0200 Subject: [PATCH 04/31] change docs --- docs/source/pages/brief_intro.rst | 2 +- docs/source/pages/lightning.rst | 6 +++--- docs/source/pages/overview.rst | 9 ++++++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/docs/source/pages/brief_intro.rst b/docs/source/pages/brief_intro.rst index d8088a5bc8f..b10489f2cfd 100644 --- a/docs/source/pages/brief_intro.rst +++ b/docs/source/pages/brief_intro.rst @@ -25,7 +25,7 @@ Module metrics import torchmetrics # initialize metric - metric = torchmetrics.Accuracy() + metric = torchmetrics.classification.Accuracy() n_batches = 10 for i in range(n_batches): diff --git a/docs/source/pages/lightning.rst b/docs/source/pages/lightning.rst index 0261ce91d65..13c07b4b21a 100644 --- a/docs/source/pages/lightning.rst +++ b/docs/source/pages/lightning.rst @@ -25,7 +25,7 @@ The example below shows how to use a metric in your `LightningModule Date: Sat, 28 Aug 2021 08:52:26 +0000 Subject: [PATCH 05/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 5fbc68f7899..8103904df57 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -2,7 +2,7 @@ import logging as __logging import os -from torchmetrics.__about__ import * # noqa: F401, F403 +from torchmetrics.__about__ import * # noqa: F403 _logger = __logging.getLogger("torchmetrics") _logger.addHandler(__logging.StreamHandler()) @@ -16,7 +16,6 @@ from torchmetrics.image import FID, IS, KID, LPIPS, PSNR, SSIM # noqa: E402 from torchmetrics.metric import Metric # noqa: E402 - __all__ = [ "AverageMeter", "AveragePrecision", From 3ffacb228b558e01c635d1bf53a3e2cf46d0944c Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sat, 28 Aug 2021 11:09:25 +0200 Subject: [PATCH 06/31] docs --- docs/source/references/modules.rst | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 63ea6f5c32b..49f409875ed 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -4,6 +4,29 @@ Module metrics .. include:: ../links.rst +All modular metrics are by default availble to import as + +```python +from torchmetrics import Accuracy, MeanSquaredError, SSIM # ect +``` + +However, metrics that requires some additionaly dependencies (other than pytorch) +such as some of the image and text metrics need to be imported from their respective submodule + +``python +from torchmetrics.image import FID, KID +from torchmetrics.text import RougeScore +``` + +Metrics have this clearly stated in their docstring and additionaly dependencies can always +be installed as + +```bash +pip install torchmetrics[image] +pip install torchmetrics[text] +pip install torchmetrics[all] # install all of the above +``` + ********** Base class ********** From 5b42f8ed360f52b7b3ac6f0f6226b472cc4dfce5 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sat, 28 Aug 2021 11:12:02 +0200 Subject: [PATCH 07/31] revert --- docs/source/pages/brief_intro.rst | 2 +- docs/source/pages/lightning.rst | 6 +++--- docs/source/pages/overview.rst | 13 +++---------- tests/classification/test_accuracy.py | 2 +- tests/classification/test_calibration_error.py | 2 +- tests/classification/test_f_beta.py | 3 +-- tests/classification/test_hamming_distance.py | 2 +- tests/classification/test_hinge.py | 2 +- tests/classification/test_precision_recall.py | 3 +-- tests/classification/test_specificity.py | 3 +-- tests/classification/test_stat_scores.py | 2 +- tests/wrappers/test_tracker.py | 3 +-- 12 files changed, 16 insertions(+), 27 deletions(-) diff --git a/docs/source/pages/brief_intro.rst b/docs/source/pages/brief_intro.rst index b10489f2cfd..d8088a5bc8f 100644 --- a/docs/source/pages/brief_intro.rst +++ b/docs/source/pages/brief_intro.rst @@ -25,7 +25,7 @@ Module metrics import torchmetrics # initialize metric - metric = torchmetrics.classification.Accuracy() + metric = torchmetrics.Accuracy() n_batches = 10 for i in range(n_batches): diff --git a/docs/source/pages/lightning.rst b/docs/source/pages/lightning.rst index 13c07b4b21a..0261ce91d65 100644 --- a/docs/source/pages/lightning.rst +++ b/docs/source/pages/lightning.rst @@ -25,7 +25,7 @@ The example below shows how to use a metric in your `LightningModule Date: Sat, 28 Aug 2021 09:18:46 +0000 Subject: [PATCH 08/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/references/modules.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 49f409875ed..8c7e8a20394 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -10,7 +10,7 @@ All modular metrics are by default availble to import as from torchmetrics import Accuracy, MeanSquaredError, SSIM # ect ``` -However, metrics that requires some additionaly dependencies (other than pytorch) +However, metrics that requires some additionaly dependencies (other than pytorch) such as some of the image and text metrics need to be imported from their respective submodule ``python From 5bbc2d95c1ea4bb5e5ff6f21549d584bfb466826 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sat, 28 Aug 2021 11:19:00 +0200 Subject: [PATCH 09/31] change --- torchmetrics/__init__.py | 88 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 80 insertions(+), 8 deletions(-) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 8103904df57..43b64e28670 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -2,7 +2,7 @@ import logging as __logging import os -from torchmetrics.__about__ import * # noqa: F403 +from torchmetrics.__about__ import * # noqa: F401, F403 _logger = __logging.getLogger("torchmetrics") _logger.addHandler(__logging.StreamHandler()) @@ -11,18 +11,70 @@ _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) +from torchmetrics import functional # noqa: E402 +from torchmetrics.audio import PIT, SI_SDR, SI_SNR, SNR # noqa: E402 from torchmetrics.average import AverageMeter # noqa: E402 +from torchmetrics.classification import ( # noqa: E402 + AUC, + AUROC, + F1, + ROC, + Accuracy, + AveragePrecision, + BinnedAveragePrecision, + BinnedPrecisionRecallCurve, + BinnedRecallAtFixedPrecision, + CalibrationError, + CohenKappa, + ConfusionMatrix, + FBeta, + HammingDistance, + Hinge, + IoU, + KLDivergence, + MatthewsCorrcoef, + Precision, + PrecisionRecallCurve, + Recall, + Specificity, + StatScores, +) from torchmetrics.collections import MetricCollection # noqa: E402 -from torchmetrics.image import FID, IS, KID, LPIPS, PSNR, SSIM # noqa: E402 +from torchmetrics.image import PSNR, SSIM # noqa: E402 from torchmetrics.metric import Metric # noqa: E402 +from torchmetrics.regression import ( # noqa: E402 + CosineSimilarity, + ExplainedVariance, + MeanAbsoluteError, + MeanAbsolutePercentageError, + MeanSquaredError, + MeanSquaredLogError, + PearsonCorrcoef, + R2Score, + SpearmanCorrcoef, + SymmetricMeanAbsolutePercentageError, +) +from torchmetrics.retrieval import ( # noqa: E402 + RetrievalFallOut, + RetrievalMAP, + RetrievalMRR, + RetrievalNormalizedDCG, + RetrievalPrecision, + RetrievalRecall, +) +from torchmetrics.text import BLEUScore, WER # noqa: E402 +from torchmetrics.wrappers import BootStrapper, MetricTracker # noqa: E402 __all__ = [ + "functional", + "Accuracy", + "AUC", + "AUROC", "AverageMeter", "AveragePrecision", "BinnedAveragePrecision", "BinnedPrecisionRecallCurve", "BinnedRecallAtFixedPrecision", - "BERTScore", "BLEUScore", "BootStrapper", "CalibrationError", @@ -32,14 +84,10 @@ "ExplainedVariance", "F1", "FBeta", - "FID", "HammingDistance", "Hinge", "IoU", - "IS", - "KID", "KLDivergence", - "LPIPS", "MatthewsCorrcoef", "MeanAbsoluteError", "MeanAbsolutePercentageError", @@ -47,4 +95,28 @@ "MeanSquaredLogError", "Metric", "MetricCollection", -] + "MetricTracker", + "PearsonCorrcoef", + "PIT", + "Precision", + "PrecisionRecallCurve", + "PSNR", + "R2Score", + "Recall", + "RetrievalFallOut", + "RetrievalMAP", + "RetrievalMRR", + "RetrievalNormalizedDCG", + "RetrievalPrecision", + "RetrievalRecall", + "ROC", + "SI_SDR", + "SI_SNR", + "SNR", + "SpearmanCorrcoef", + "Specificity", + "SSIM", + "StatScores", + "SymmetricMeanAbsolutePercentageError", + "WER", +] \ No newline at end of file From 9e04e23862d2d7038931c2f0c7bfe0fa3ce1c2eb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 28 Aug 2021 09:19:35 +0000 Subject: [PATCH 10/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 43b64e28670..7ba277f33d8 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -62,7 +62,7 @@ RetrievalPrecision, RetrievalRecall, ) -from torchmetrics.text import BLEUScore, WER # noqa: E402 +from torchmetrics.text import WER, BLEUScore # noqa: E402 from torchmetrics.wrappers import BootStrapper, MetricTracker # noqa: E402 __all__ = [ @@ -119,4 +119,4 @@ "StatScores", "SymmetricMeanAbsolutePercentageError", "WER", -] \ No newline at end of file +] From f63b6261fccf61cd00ab2c619aba55bdfddb7e80 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sat, 28 Aug 2021 11:19:56 +0200 Subject: [PATCH 11/31] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c83fdec0251..bb235ad05d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- Change import structure from root to submodule level ([#459](https://github.com/PyTorchLightning/metrics/issues/459)) +- Change import structure from root to submodule level for metrics having additional requirements ([#459](https://github.com/PyTorchLightning/metrics/issues/459)) ### Deprecated From 6ecedc4f4ea64d241eff6688d2e03dc9b65174f2 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sat, 28 Aug 2021 11:21:01 +0200 Subject: [PATCH 12/31] revert --- docs/source/pages/overview.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst index d161c3a5f73..73bf6c16e01 100644 --- a/docs/source/pages/overview.rst +++ b/docs/source/pages/overview.rst @@ -113,6 +113,8 @@ the native `MetricCollection`_ module can also be used to wrap multiple metrics. val3 = self.metric3['accuracy'](preds, target) val4 = self.metric4(preds, target) +You can always check which device the metric is located on using the `.device` property. + Metrics in Dataparallel (DP) mode ================================= @@ -169,6 +171,8 @@ the following limitations: - :ref:`references/modules:SSIM` and :ref:`references/functional:ssim [func]` - :ref:`references/modules:KLDivergence` and :ref:`references/functional:kl_divergence [func]` +You can always check the precision/dtype of the metric by checking the `.dtype` property. + ****************** Metric Arithmetics ****************** From 730dfc257d1ec14e522ff7ff634f792b34da371f Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 8 Sep 2021 10:59:26 +0200 Subject: [PATCH 13/31] update --- docs/source/references/modules.rst | 1 - torchmetrics/__init__.py | 7 +++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 6677e8368f7..8b178ee8315 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -15,7 +15,6 @@ such as some of the image and text metrics need to be imported from their respec ``python from torchmetrics.image import FID, KID -from torchmetrics.text import RougeScore ``` Metrics have this clearly stated in their docstring and additionaly dependencies can always diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index da6debd4516..59bd3bb603d 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -40,7 +40,8 @@ StatScores, ) from torchmetrics.collections import MetricCollection # noqa: E402 -from torchmetrics.image import PSNR, SSIM # noqa: E402 +from torchmetrics.image.psnr import PSNR # noqa: E402 +from torchmetrics.image.ssim import SSIM # noqa: E402 from torchmetrics.metric import Metric # noqa: E402 from torchmetrics.regression import ( # noqa: E402 CosineSimilarity, @@ -63,7 +64,7 @@ RetrievalPrecision, RetrievalRecall, ) -from torchmetrics.text import WER, BLEUScore # noqa: E402 +from torchmetrics.text import WER, BERTScore, BLEUScore, ROUGEScore # noqa: E402 from torchmetrics.wrappers import BootStrapper, MetricTracker # noqa: E402 __all__ = [ @@ -76,6 +77,7 @@ "BinnedAveragePrecision", "BinnedPrecisionRecallCurve", "BinnedRecallAtFixedPrecision", + "BERTScore", "BLEUScore", "BootStrapper", "CalibrationError", @@ -112,6 +114,7 @@ "RetrievalPrecision", "RetrievalRecall", "ROC", + "ROUGEScore", "SI_SDR", "SI_SNR", "SNR", From e7e30a96e17ae88cd989410d194d416cb619f912 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 8 Sep 2021 11:01:21 +0200 Subject: [PATCH 14/31] update imports --- torchmetrics/image/fid.py | 2 +- torchmetrics/image/inception.py | 2 +- torchmetrics/image/kid.py | 2 +- torchmetrics/image/lpip_similarity.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py index 2950d200fce..5ca029e7033 100644 --- a/torchmetrics/image/fid.py +++ b/torchmetrics/image/fid.py @@ -192,7 +192,7 @@ class FID(Metric): Example: >>> import torch >>> _ = torch.manual_seed(123) - >>> from torchmetrics import FID + >>> from torchmetrics.image import FID >>> fid = FID(feature=64) # doctest: +SKIP >>> # generate two slightly overlapping image intensity distributions >>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8) # doctest: +SKIP diff --git a/torchmetrics/image/inception.py b/torchmetrics/image/inception.py index 3d70d90939c..a616f0d7d84 100644 --- a/torchmetrics/image/inception.py +++ b/torchmetrics/image/inception.py @@ -91,7 +91,7 @@ class IS(Metric): Example: >>> import torch >>> _ = torch.manual_seed(123) - >>> from torchmetrics import IS + >>> from torchmetrics.image import IS >>> inception = IS() # doctest: +SKIP >>> # generate some images >>> imgs = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8) # doctest: +SKIP diff --git a/torchmetrics/image/kid.py b/torchmetrics/image/kid.py index 78c9c3502e2..dda345f16ff 100644 --- a/torchmetrics/image/kid.py +++ b/torchmetrics/image/kid.py @@ -150,7 +150,7 @@ class KID(Metric): Example: >>> import torch >>> _ = torch.manual_seed(123) - >>> from torchmetrics import KID + >>> from torchmetrics.image import KID >>> kid = KID(subset_size=50) # doctest: +SKIP >>> # generate two slightly overlapping image intensity distributions >>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8) # doctest: +SKIP diff --git a/torchmetrics/image/lpip_similarity.py b/torchmetrics/image/lpip_similarity.py index 52418187059..c3c4096e6b1 100644 --- a/torchmetrics/image/lpip_similarity.py +++ b/torchmetrics/image/lpip_similarity.py @@ -79,7 +79,7 @@ class LPIPS(Metric): Example: >>> import torch >>> _ = torch.manual_seed(123) - >>> from torchmetrics import LPIPS + >>> from torchmetrics.image import LPIPS >>> lpips = LPIPS(net_type='vgg') >>> img1 = torch.rand(10, 3, 100, 100) >>> img2 = torch.rand(10, 3, 100, 100) From 0b6b6eec7a07c664030fe6cf3fc7af539de1d434 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 21 Sep 2021 18:17:19 +0200 Subject: [PATCH 15/31] adjusting PT 1.8.2 --- requirements/adjust-versions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/requirements/adjust-versions.py b/requirements/adjust-versions.py index e13c55a61ee..97bb3ea1d61 100644 --- a/requirements/adjust-versions.py +++ b/requirements/adjust-versions.py @@ -4,10 +4,13 @@ import sys from typing import Dict, Optional +from packaging.version import Version + VERSIONS = [ dict(torch="1.10.0", torchvision="0.11.0", torchtext=""), # nightly dict(torch="1.9.1", torchvision="0.10.1", torchtext="0.10.1"), dict(torch="1.9.0", torchvision="0.10.0", torchtext="0.10.0"), + dict(torch="1.8.2", torchvision="0.9.1", torchtext="0.9.1"), dict(torch="1.8.1", torchvision="0.9.1", torchtext="0.9.1"), dict(torch="1.8.0", torchvision="0.9.0", torchtext="0.9.0"), dict(torch="1.7.1", torchvision="0.8.2", torchtext="0.8.1"), @@ -19,7 +22,7 @@ dict(torch="1.3.1", torchvision="0.4.2", torchtext="0.4"), dict(torch="1.3.0", torchvision="0.4.1", torchtext="0.4"), ] -VERSIONS.sort(key=lambda v: v["torch"], reverse=True) +VERSIONS.sort(key=lambda v: Version(v["torch"]), reverse=True) def find_latest(ver: str) -> Dict[str, str]: From 15f31e26cc3604901a1e8df9555ed373926e988f Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 25 Oct 2021 14:10:28 +0200 Subject: [PATCH 16/31] prepare 0.6 RC --- CHANGELOG.md | 64 ++++--------------- .../classification/test_average_precision.py | 2 +- torchmetrics/__about__.py | 2 +- torchmetrics/classification/__init__.py | 2 +- ...{average_precision.py => avg_precision.py} | 0 5 files changed, 17 insertions(+), 53 deletions(-) rename torchmetrics/classification/{average_precision.py => avg_precision.py} (100%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 49ebaf56417..14bf326b303 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,67 +6,38 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 **Note: we move fast, but still we preserve 0.1 version (one feature release) back compatibility.** -## [unReleased] - 2021-MM-DD +## [0.6.0] - 2021-10-DD ### Added -- Added Learned Perceptual Image Patch Similarity (LPIPS) ([#431](https://github.com/PyTorchLightning/metrics/issues/431)) - - -- Added Tweedie Deviance Score ([#499](https://github.com/PyTorchLightning/metrics/pull/499)) - - +- Added audio metrics: + - Perceptual Evaluation of Speech Quality (PESQ) ([#353](https://github.com/PyTorchLightning/metrics/issues/353)) + - Short Term Objective Intelligibility (STOI) ([#353](https://github.com/PyTorchLightning/metrics/issues/353)) +- Added Information retrieval metrics: + - `RetrievalRPrecision` ([#577](https://github.com/PyTorchLightning/metrics/pull/577/)) + - `RetrievalHitRate` ([#576](https://github.com/PyTorchLightning/metrics/pull/576)) +- Added NLP metrics: + - `SacreBLEUScore` ([#546](https://github.com/PyTorchLightning/metrics/pull/546)) + - `CharErrorRate` ([#575](https://github.com/PyTorchLightning/metrics/pull/575)) +- Added other metrics: + - Tweedie Deviance Score ([#499](https://github.com/PyTorchLightning/metrics/pull/499)) + - Learned Perceptual Image Patch Similarity (LPIPS) ([#431](https://github.com/PyTorchLightning/metrics/issues/431)) - Added support for float targets in `nDCG` metric ([#437](https://github.com/PyTorchLightning/metrics/pull/437)) - - -- Added `average` argument to `AveragePrecision` metric for reducing multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477)) - - -- Added Perceptual Evaluation of Speech Quality (PESQ) ([#353](https://github.com/PyTorchLightning/metrics/issues/353)) - - +- Added `average` argument to `AveragePrecision` metric for reducing multi-label and multi-class problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477)) - Added `MultioutputWrapper` ([#510](https://github.com/PyTorchLightning/metrics/pull/510)) - - - Added metric sweeping `higher_is_better` as constant attribute ([#544](https://github.com/PyTorchLightning/metrics/pull/544)) - - -- Added `SacreBLEUScore` metric to text package ([#546](https://github.com/PyTorchLightning/metrics/pull/546)) - - - Added simple aggregation metrics: `SumMetric`, `MeanMetric`, `CatMetric`, `MinMetric`, `MaxMetric` ([#506](https://github.com/PyTorchLightning/metrics/pull/506)) - - - Added pairwise submodule with metrics ([#553](https://github.com/PyTorchLightning/metrics/pull/553)) - `pairwise_cosine_similarity` - `pairwise_euclidean_distance` - `pairwise_linear_similarity` - `pairwise_manhatten_distance` - -- Added Short Term Objective Intelligibility (`STOI`) ([#353](https://github.com/PyTorchLightning/metrics/issues/353)) - - -- Added `RetrievalRPrecision` metric to retrieval package ([#577](https://github.com/PyTorchLightning/metrics/pull/577/)) - - -- Added `RetrievalHitRate` metric to retrieval package ([#576](https://github.com/PyTorchLightning/metrics/pull/576)) - - -- Added `CharErrorRate` metric to text package ([#575](https://github.com/PyTorchLightning/metrics/pull/575)) - - ### Changed - `AveragePrecision` will now as default output the `macro` average for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477)) - - - `half`, `double`, `float` will no longer change the dtype of the metric states. Use `metric.set_dtype` instead ([#493](https://github.com/PyTorchLightning/metrics/pull/493)) - - - Renamed `AverageMeter` to `MeanMetric` ([#506](https://github.com/PyTorchLightning/metrics/pull/506)) - - - Changed `is_differentiable` from property to a constant attribute ([#551](https://github.com/PyTorchLightning/metrics/pull/551)) ### Deprecated @@ -77,18 +48,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Removed `dtype` property ([#493](https://github.com/PyTorchLightning/metrics/pull/493)) - ### Fixed - Fixed bug in `F1` with `average='macro'` and `ignore_index!=None` ([#495](https://github.com/PyTorchLightning/metrics/pull/495)) - - - Fixed bug in `pit` by using the returned first result to initialize device and type ([#533](https://github.com/PyTorchLightning/metrics/pull/533)) - - - Fixed `SSIM` metric using too much memory ([#539](https://github.com/PyTorchLightning/metrics/pull/539)) - - - Fixed bug where `device` property was not properly update when metric was a child of a module ([#542](https://github.com/PyTorchLightning/metrics/pull/542)) ## [0.5.1] - 2021-08-30 diff --git a/tests/classification/test_average_precision.py b/tests/classification/test_average_precision.py index 5c65a2256cf..557fb92b154 100644 --- a/tests/classification/test_average_precision.py +++ b/tests/classification/test_average_precision.py @@ -24,7 +24,7 @@ from tests.classification.inputs import _input_multilabel from tests.helpers import seed_all from tests.helpers.testers import NUM_CLASSES, MetricTester -from torchmetrics.classification.average_precision import AveragePrecision +from torchmetrics.classification.avg_precision import AveragePrecision from torchmetrics.functional import average_precision seed_all(42) diff --git a/torchmetrics/__about__.py b/torchmetrics/__about__.py index ea1d5dcd3c0..334914962d1 100644 --- a/torchmetrics/__about__.py +++ b/torchmetrics/__about__.py @@ -1,4 +1,4 @@ -__version__ = "0.6.0dev" +__version__ = "0.6.0rc0" __author__ = "PyTorchLightning et al." __author_email__ = "name@pytorchlightning.ai" __license__ = "Apache-2.0" diff --git a/torchmetrics/classification/__init__.py b/torchmetrics/classification/__init__.py index 35476172b06..0ed2d3d8d8b 100644 --- a/torchmetrics/classification/__init__.py +++ b/torchmetrics/classification/__init__.py @@ -14,7 +14,7 @@ from torchmetrics.classification.accuracy import Accuracy # noqa: F401 from torchmetrics.classification.auc import AUC # noqa: F401 from torchmetrics.classification.auroc import AUROC # noqa: F401 -from torchmetrics.classification.average_precision import AveragePrecision # noqa: F401 +from torchmetrics.classification.avg_precision import AveragePrecision # noqa: F401 from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision # noqa: F401 from torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve # noqa: F401 from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision # noqa: F401 diff --git a/torchmetrics/classification/average_precision.py b/torchmetrics/classification/avg_precision.py similarity index 100% rename from torchmetrics/classification/average_precision.py rename to torchmetrics/classification/avg_precision.py From b2e60742f68750eb8be59a254d303f498ffbcb42 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 24 Nov 2021 13:19:18 +0100 Subject: [PATCH 17/31] audio --- torchmetrics/__init__.py | 4 +--- torchmetrics/audio/__init__.py | 4 +--- torchmetrics/audio/pesq.py | 2 +- torchmetrics/audio/stoi.py | 2 +- torchmetrics/functional/__init__.py | 4 ---- torchmetrics/functional/audio/__init__.py | 2 -- torchmetrics/functional/audio/pesq.py | 2 +- torchmetrics/functional/audio/stoi.py | 2 +- 8 files changed, 6 insertions(+), 16 deletions(-) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 144033b808f..bb3385e9b0d 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -13,7 +13,7 @@ from torchmetrics import functional # noqa: E402 from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric # noqa: E402 -from torchmetrics.audio import PESQ, PIT, SI_SDR, SI_SNR, SNR, STOI # noqa: E402 +from torchmetrics.audio import PIT, SI_SDR, SI_SNR, SNR # noqa: E402 from torchmetrics.classification import ( # noqa: E402 AUC, AUROC, @@ -122,7 +122,6 @@ "MinMetric", "MultioutputWrapper", "PearsonCorrcoef", - "PESQ", "PIT", "Precision", "PrecisionRecallCurve", @@ -148,7 +147,6 @@ "SQuAD", "SSIM", "StatScores", - "STOI", "SumMetric", "SymmetricMeanAbsolutePercentageError", "WER", diff --git a/torchmetrics/audio/__init__.py b/torchmetrics/audio/__init__.py index fe1dd7e4901..3f70622cc22 100644 --- a/torchmetrics/audio/__init__.py +++ b/torchmetrics/audio/__init__.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.audio.pesq import PESQ # noqa: F401 from torchmetrics.audio.pit import PIT # noqa: F401 from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401 from torchmetrics.audio.si_snr import SI_SNR # noqa: F401 -from torchmetrics.audio.snr import SNR # noqa: F401 -from torchmetrics.audio.stoi import STOI # noqa: F401 +from torchmetrics.audio.snr import SNR # noqa: F401 \ No newline at end of file diff --git a/torchmetrics/audio/pesq.py b/torchmetrics/audio/pesq.py index eaff3dd8e93..374dab1ca31 100644 --- a/torchmetrics/audio/pesq.py +++ b/torchmetrics/audio/pesq.py @@ -62,7 +62,7 @@ class PESQ(Metric): If ``mode`` is not either ``"wb"`` or ``"nb"`` Example: - >>> from torchmetrics.audio import PESQ + >>> from torchmetrics.audio.pesq import PESQ >>> import torch >>> g = torch.manual_seed(1) >>> preds = torch.randn(8000) diff --git a/torchmetrics/audio/stoi.py b/torchmetrics/audio/stoi.py index 1c2148b9b4c..126bc7eddf6 100644 --- a/torchmetrics/audio/stoi.py +++ b/torchmetrics/audio/stoi.py @@ -63,7 +63,7 @@ class STOI(Metric): If ``pystoi`` package is not installed Example: - >>> from torchmetrics.audio import STOI + >>> from torchmetrics.audio.stoi import STOI >>> import torch >>> g = torch.manual_seed(1) >>> preds = torch.randn(8000) diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index d359d48b368..05f9d4fe4e8 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -11,12 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.functional.audio.pesq import pesq from torchmetrics.functional.audio.pit import pit, pit_permutate from torchmetrics.functional.audio.si_sdr import si_sdr from torchmetrics.functional.audio.si_snr import si_snr from torchmetrics.functional.audio.snr import snr -from torchmetrics.functional.audio.stoi import stoi from torchmetrics.functional.classification.accuracy import accuracy from torchmetrics.functional.classification.auc import auc from torchmetrics.functional.classification.auroc import auroc @@ -106,7 +104,6 @@ "pairwise_linear_similarity", "pairwise_manhatten_distance", "pearson_corrcoef", - "pesq", "pit", "pit_permutate", "precision", @@ -134,7 +131,6 @@ "squad", "ssim", "stat_scores", - "stoi", "symmetric_mean_absolute_percentage_error", "wer", "char_error_rate", diff --git a/torchmetrics/functional/audio/__init__.py b/torchmetrics/functional/audio/__init__.py index 678f45419db..f701dad2f11 100644 --- a/torchmetrics/functional/audio/__init__.py +++ b/torchmetrics/functional/audio/__init__.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.functional.audio.pesq import pesq # noqa: F401 from torchmetrics.functional.audio.pit import pit, pit_permutate # noqa: F401 from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401 from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401 from torchmetrics.functional.audio.snr import snr # noqa: F401 -from torchmetrics.functional.audio.stoi import stoi # noqa: F401 diff --git a/torchmetrics/functional/audio/pesq.py b/torchmetrics/functional/audio/pesq.py index 268002712dc..c45eafbe9df 100644 --- a/torchmetrics/functional/audio/pesq.py +++ b/torchmetrics/functional/audio/pesq.py @@ -58,7 +58,7 @@ def pesq(preds: Tensor, target: Tensor, fs: int, mode: str, keep_same_device: bo If ``mode`` is not either ``"wb"`` or ``"nb"`` Example: - >>> from torchmetrics.functional.audio import pesq + >>> from torchmetrics.functional.audio.pesq import pesq >>> import torch >>> g = torch.manual_seed(1) >>> preds = torch.randn(8000) diff --git a/torchmetrics/functional/audio/stoi.py b/torchmetrics/functional/audio/stoi.py index 71e36bf9c54..8cfb2435991 100644 --- a/torchmetrics/functional/audio/stoi.py +++ b/torchmetrics/functional/audio/stoi.py @@ -59,7 +59,7 @@ def stoi(preds: Tensor, target: Tensor, fs: int, extended: bool = False, keep_sa If ``pystoi`` package is not installed Example: - >>> from torchmetrics.functional.audio import stoi + >>> from torchmetrics.functional.audio.stoi import stoi >>> import torch >>> g = torch.manual_seed(1) >>> preds = torch.randn(8000) From 8002fdd8997a4d93d966d725e664ca0f3c2304e6 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 24 Nov 2021 18:42:55 +0100 Subject: [PATCH 18/31] detection --- torchmetrics/__init__.py | 2 -- torchmetrics/detection/__init__.py | 1 - torchmetrics/detection/map.py | 20 ++++++++++++++++++++ 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index bb3385e9b0d..43f73c4e4fc 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -40,7 +40,6 @@ StatScores, ) from torchmetrics.collections import MetricCollection # noqa: E402 -from torchmetrics.detection import MAP # noqa: E402 from torchmetrics.image import FID, IS, KID, LPIPS, PSNR, SSIM # noqa: E402 from torchmetrics.metric import Metric # noqa: E402 from torchmetrics.regression import ( # noqa: E402 @@ -107,7 +106,6 @@ "KID", "KLDivergence", "LPIPS", - "MAP", "MatthewsCorrcoef", "MaxMetric", "MeanAbsoluteError", diff --git a/torchmetrics/detection/__init__.py b/torchmetrics/detection/__init__.py index f8d01bdb293..d7aa17d7f84 100644 --- a/torchmetrics/detection/__init__.py +++ b/torchmetrics/detection/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.detection.map import MAP # noqa: F401 diff --git a/torchmetrics/detection/map.py b/torchmetrics/detection/map.py index eb4786d5eb7..8b492802874 100644 --- a/torchmetrics/detection/map.py +++ b/torchmetrics/detection/map.py @@ -184,6 +184,26 @@ class MAP(Metric): Callback that performs the allgather operation on the metric state. When ``None``, DDP will be used to perform the allgather + Example: + >>> from torchmetrics.detection.map import MAP + >>> import torch + >>> preds = [ + ... dict( + ... boxes=torch.Tensor([[258.0, 41.0, 606.0, 285.0]]), + ... scores=torch.Tensor([0.536]), + ... labels=torch.IntTensor([0]), + ... ) + ... ] + >>> target = [ + ... dict( + ... boxes=torch.Tensor([[214.0, 41.0, 562.0, 285.0]]), + ... labels=torch.IntTensor([0]), + ... ) + ... ] + >>> metric = MAP() + >>> metric.update(preds, target) + >>> metric.compute() + Raises: ImportError: If ``pycocotools`` is not installed From ae7ed7ad9390e885a19793a3f08a12f8c29fe86f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Nov 2021 17:46:18 +0000 Subject: [PATCH 19/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/audio/__init__.py | 2 +- torchmetrics/detection/map.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/audio/__init__.py b/torchmetrics/audio/__init__.py index 3f70622cc22..94d20384e0a 100644 --- a/torchmetrics/audio/__init__.py +++ b/torchmetrics/audio/__init__.py @@ -14,4 +14,4 @@ from torchmetrics.audio.pit import PIT # noqa: F401 from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401 from torchmetrics.audio.si_snr import SI_SNR # noqa: F401 -from torchmetrics.audio.snr import SNR # noqa: F401 \ No newline at end of file +from torchmetrics.audio.snr import SNR # noqa: F401 diff --git a/torchmetrics/detection/map.py b/torchmetrics/detection/map.py index 8b492802874..fe34f77a26f 100644 --- a/torchmetrics/detection/map.py +++ b/torchmetrics/detection/map.py @@ -190,7 +190,7 @@ class MAP(Metric): >>> preds = [ ... dict( ... boxes=torch.Tensor([[258.0, 41.0, 606.0, 285.0]]), - ... scores=torch.Tensor([0.536]), + ... scores=torch.Tensor([0.536]), ... labels=torch.IntTensor([0]), ... ) ... ] From 7d816216e01048cd4e0f402b65b4711983917f01 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 24 Nov 2021 18:50:35 +0100 Subject: [PATCH 20/31] image --- torchmetrics/__init__.py | 6 +----- torchmetrics/image/__init__.py | 4 ---- torchmetrics/image/fid.py | 2 +- torchmetrics/image/inception.py | 2 +- torchmetrics/image/kid.py | 2 +- torchmetrics/image/lpip_similarity.py | 2 +- 6 files changed, 5 insertions(+), 13 deletions(-) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 78659d3ffe7..3d198bb5c40 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -40,7 +40,7 @@ StatScores, ) from torchmetrics.collections import MetricCollection # noqa: E402 -from torchmetrics.image import FID, IS, KID, LPIPS, PSNR, SSIM # noqa: E402 +from torchmetrics.image import PSNR, SSIM # noqa: E402 from torchmetrics.metric import Metric # noqa: E402 from torchmetrics.regression import ( # noqa: E402 CosineSimilarity, @@ -100,14 +100,10 @@ "ExplainedVariance", "F1", "FBeta", - "FID", "HammingDistance", "Hinge", "IoU", - "IS", - "KID", "KLDivergence", - "LPIPS", "MatthewsCorrcoef", "MaxMetric", "MeanAbsoluteError", diff --git a/torchmetrics/image/__init__.py b/torchmetrics/image/__init__.py index 8ee5d0c5107..b3595139bc6 100644 --- a/torchmetrics/image/__init__.py +++ b/torchmetrics/image/__init__.py @@ -11,9 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.image.fid import FID # noqa: F401 -from torchmetrics.image.inception import IS # noqa: F401 -from torchmetrics.image.kid import KID # noqa: F401 -from torchmetrics.image.lpip_similarity import LPIPS # noqa: F401 from torchmetrics.image.psnr import PSNR # noqa: F401 from torchmetrics.image.ssim import SSIM # noqa: F401 diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py index 3b891ec3145..a88033c92a1 100644 --- a/torchmetrics/image/fid.py +++ b/torchmetrics/image/fid.py @@ -191,7 +191,7 @@ class FID(Metric): Example: >>> import torch >>> _ = torch.manual_seed(123) - >>> from torchmetrics import FID + >>> from torchmetrics.image.fid import FID >>> fid = FID(feature=64) # doctest: +SKIP >>> # generate two slightly overlapping image intensity distributions >>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8) # doctest: +SKIP diff --git a/torchmetrics/image/inception.py b/torchmetrics/image/inception.py index d2c4504c1d7..9738934a615 100644 --- a/torchmetrics/image/inception.py +++ b/torchmetrics/image/inception.py @@ -91,7 +91,7 @@ class IS(Metric): Example: >>> import torch >>> _ = torch.manual_seed(123) - >>> from torchmetrics import IS + >>> from torchmetrics.image.inception import IS >>> inception = IS() # doctest: +SKIP >>> # generate some images >>> imgs = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8) # doctest: +SKIP diff --git a/torchmetrics/image/kid.py b/torchmetrics/image/kid.py index d860a2706b8..6e691e06cb3 100644 --- a/torchmetrics/image/kid.py +++ b/torchmetrics/image/kid.py @@ -150,7 +150,7 @@ class KID(Metric): Example: >>> import torch >>> _ = torch.manual_seed(123) - >>> from torchmetrics import KID + >>> from torchmetrics.image.kid import KID >>> kid = KID(subset_size=50) # doctest: +SKIP >>> # generate two slightly overlapping image intensity distributions >>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8) # doctest: +SKIP diff --git a/torchmetrics/image/lpip_similarity.py b/torchmetrics/image/lpip_similarity.py index 990c2bea1a3..48e4a758ed3 100644 --- a/torchmetrics/image/lpip_similarity.py +++ b/torchmetrics/image/lpip_similarity.py @@ -79,7 +79,7 @@ class LPIPS(Metric): Example: >>> import torch >>> _ = torch.manual_seed(123) - >>> from torchmetrics import LPIPS + >>> from torchmetrics.image.lpip_similarity import LPIPS >>> lpips = LPIPS(net_type='vgg') >>> img1 = torch.rand(10, 3, 100, 100) >>> img2 = torch.rand(10, 3, 100, 100) From a0a4aede5e657721995ff3e3b4fc2fa613978364 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 24 Nov 2021 18:58:08 +0100 Subject: [PATCH 21/31] text --- torchmetrics/__init__.py | 2 -- torchmetrics/functional/text/bert.py | 1 + torchmetrics/functional/text/rouge.py | 1 + torchmetrics/text/__init__.py | 2 -- torchmetrics/text/bert.py | 1 + torchmetrics/text/rouge.py | 2 +- 6 files changed, 4 insertions(+), 5 deletions(-) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 3d198bb5c40..27fd8113224 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -88,7 +88,6 @@ "BinnedAveragePrecision", "BinnedPrecisionRecallCurve", "BinnedRecallAtFixedPrecision", - "BERTScore", "BLEUScore", "BootStrapper", "CalibrationError", @@ -133,7 +132,6 @@ "RetrievalRecall", "RetrievalRPrecision", "ROC", - "ROUGEScore", "SacreBLEUScore", "SI_SDR", "SI_SNR", diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py index 98017eaa72a..9a80cf84362 100644 --- a/torchmetrics/functional/text/bert.py +++ b/torchmetrics/functional/text/bert.py @@ -547,6 +547,7 @@ def bert_score( If invalid input is provided. Example: + >>> from torchmetrics.functional.text.bert import bert_score >>> predictions = ["hello there", "general kenobi"] >>> references = ["hello there", "master kenobi"] >>> bert_score(predictions=predictions, references=references, lang="en") # doctest: +SKIP diff --git a/torchmetrics/functional/text/rouge.py b/torchmetrics/functional/text/rouge.py index b1602a5097d..17e673c3d43 100644 --- a/torchmetrics/functional/text/rouge.py +++ b/torchmetrics/functional/text/rouge.py @@ -261,6 +261,7 @@ def rouge_score( Python dictionary of rouge scores for each input rouge key. Example: + >>> from torchmetrics.functional.text.rouge import rouge_score >>> targets = "Is your name John" >>> preds = "My name is John" >>> from pprint import pprint diff --git a/torchmetrics/text/__init__.py b/torchmetrics/text/__init__.py index bef443fe644..ff78c668360 100644 --- a/torchmetrics/text/__init__.py +++ b/torchmetrics/text/__init__.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.text.bert import BERTScore # noqa: F401 from torchmetrics.text.bleu import BLEUScore # noqa: F401 from torchmetrics.text.cer import CharErrorRate # noqa: F401 from torchmetrics.text.mer import MatchErrorRate # noqa: F401 -from torchmetrics.text.rouge import ROUGEScore # noqa: F401 from torchmetrics.text.sacre_bleu import SacreBLEUScore # noqa: F401 from torchmetrics.text.squad import SQuAD # noqa: F401 from torchmetrics.text.wer import WER # noqa: F401 diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index ff0059e929a..2324d4d47e9 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -111,6 +111,7 @@ class BERTScore(Metric): Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values. Example: + >>> from torchmetrics.text.bert import BertScore >>> predictions = ["hello there", "general kenobi"] >>> references = ["hello there", "master kenobi"] >>> bertscore = BERTScore() diff --git a/torchmetrics/text/rouge.py b/torchmetrics/text/rouge.py index fa5b987641b..14d423ae5e7 100644 --- a/torchmetrics/text/rouge.py +++ b/torchmetrics/text/rouge.py @@ -49,7 +49,7 @@ class ROUGEScore(Metric): will be used to perform the allgather. Example: - + >>> from torchmetrics.text.rouge import ROUGEScore >>> targets = "Is your name John" >>> preds = "My name is John" >>> rouge = ROUGEScore() # doctest: +SKIP From db90fd35810441e470eb27df98dd4cb024c8e8cc Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 24 Nov 2021 18:59:52 +0100 Subject: [PATCH 22/31] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a35b5b78a0..e009be5ffbf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Use `torch.topk` instead of `torch.argsort` in retrieval precision for speedup ([#627](https://github.com/PyTorchLightning/metrics/pull/627)) +- Metrics having third party dependencies removed from global import ([#463](https://github.com/PyTorchLightning/metrics/pull/463)) + + ### Deprecated From 3d0d8c3614f2360106f08fedb0dcc8554049dde5 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 24 Nov 2021 19:03:24 +0100 Subject: [PATCH 23/31] fix import --- torchmetrics/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 27fd8113224..279df1c887e 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -67,11 +67,9 @@ ) from torchmetrics.text import ( # noqa: E402 WER, - BERTScore, BLEUScore, CharErrorRate, MatchErrorRate, - ROUGEScore, SacreBLEUScore, SQuAD, WordInfoLost, From 4f99eb8916d571a68cde9dba560d8b95d3857b98 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 24 Nov 2021 19:08:09 +0100 Subject: [PATCH 24/31] fix import --- torchmetrics/text/bert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index 2324d4d47e9..64d64c9eedd 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -111,7 +111,7 @@ class BERTScore(Metric): Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values. Example: - >>> from torchmetrics.text.bert import BertScore + >>> from torchmetrics.text.bert import BERTScore >>> predictions = ["hello there", "general kenobi"] >>> references = ["hello there", "master kenobi"] >>> bertscore = BERTScore() From 0f2a09e78dd5251d69009b9d7254a1c239ff5314 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 24 Nov 2021 19:54:20 +0100 Subject: [PATCH 25/31] fix docs --- docs/source/references/functional.rst | 8 ++++---- docs/source/references/modules.rst | 18 +++++++++--------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index fcdcca823b0..ae09ac2daee 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -14,7 +14,7 @@ Audio Metrics pesq [func] ~~~~~~~~~~~ -.. autofunction:: torchmetrics.functional.pesq +.. autofunction:: torchmetrics.functional.audio.pesq.pesq pit [func] @@ -48,7 +48,7 @@ snr [func] stoi [func] ~~~~~~~~~~~ -.. autofunction:: torchmetrics.functional.stoi +.. autofunction:: torchmetrics.functional.audio.stoi.stoi :noindex: @@ -426,7 +426,7 @@ Text bert_score [func] ~~~~~~~~~~~~~~~~~~ -.. autofunction:: torchmetrics.functional.bert_score +.. autofunction:: torchmetrics.functional.text.bert.bert_score bleu_score [func] ~~~~~~~~~~~~~~~~~ @@ -449,7 +449,7 @@ match_error_rate [func] rouge_score [func] ~~~~~~~~~~~~~~~~~~ -.. autofunction:: torchmetrics.functional.rouge_score +.. autofunction:: torchmetrics.functional.text.rouge.rouge_score :noindex: sacre_bleu_score [func] diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 4acc7e2544f..7a86580429f 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -77,7 +77,7 @@ the metric will be computed over the ``time`` dimension. PESQ ~~~~ -.. autoclass:: torchmetrics.PESQ +.. autoclass:: torchmetrics.audio.pesq.PESQ PIT ~~~ @@ -106,7 +106,7 @@ SNR STOI ~~~~ -.. autoclass:: torchmetrics.STOI +.. autoclass:: torchmetrics.audio.stoi.STOI :noindex: @@ -363,25 +363,25 @@ learning algorithms such as `Generative Adverserial Networks (GANs) Date: Wed, 24 Nov 2021 20:10:05 +0100 Subject: [PATCH 26/31] fix doctest --- torchmetrics/detection/map.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/torchmetrics/detection/map.py b/torchmetrics/detection/map.py index fe34f77a26f..c49f1ba68b2 100644 --- a/torchmetrics/detection/map.py +++ b/torchmetrics/detection/map.py @@ -186,6 +186,7 @@ class MAP(Metric): Example: >>> from torchmetrics.detection.map import MAP + >>> from pprint import pprint >>> import torch >>> preds = [ ... dict( @@ -202,7 +203,22 @@ class MAP(Metric): ... ] >>> metric = MAP() >>> metric.update(preds, target) - >>> metric.compute() + >>> pprint(metric.compute()) # doctest: +NORMALIZE_WHITESPACE +SKIP + {'map': tensor(0.6000), + 'map_50': tensor(1.), + 'map_75': tensor(1.), + 'map_small': tensor(-1.), + 'map_medium': tensor(-1.), + 'map_large': tensor(0.6000), + 'mar_1': tensor(0.6000), + 'mar_10': tensor(0.6000), + 'mar_100': tensor(0.6000), + 'mar_small': tensor(-1.), + 'mar_medium': tensor(-1.), + 'mar_large': tensor(0.6000), + 'map_per_class': tensor(-1.), + 'mar_100_per_class': tensor(-1.) + } Raises: ImportError: From ee17b482102a7d7643e8d436bf9df31e1f80e79f Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 24 Nov 2021 20:44:32 +0100 Subject: [PATCH 27/31] fix test imports --- tests/audio/test_pesq.py | 4 ++-- tests/audio/test_stoi.py | 4 ++-- tests/text/test_bertscore.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/audio/test_pesq.py b/tests/audio/test_pesq.py index ec65ea38ae2..93fce3fe365 100644 --- a/tests/audio/test_pesq.py +++ b/tests/audio/test_pesq.py @@ -21,8 +21,8 @@ from tests.helpers import seed_all from tests.helpers.testers import MetricTester -from torchmetrics.audio import PESQ -from torchmetrics.functional import pesq +from torchmetrics.audio.pesq import PESQ +from torchmetrics.functional.audio.pesq import pesq from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 seed_all(42) diff --git a/tests/audio/test_stoi.py b/tests/audio/test_stoi.py index 9f98bc9b5ed..cd4192e83d7 100644 --- a/tests/audio/test_stoi.py +++ b/tests/audio/test_stoi.py @@ -21,8 +21,8 @@ from tests.helpers import seed_all from tests.helpers.testers import MetricTester -from torchmetrics.audio import STOI -from torchmetrics.functional import stoi +from torchmetrics.audio.stoi import STOI +from torchmetrics.functional.audio.stoi import stoi from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 seed_all(42) diff --git a/tests/text/test_bertscore.py b/tests/text/test_bertscore.py index 8bcdf69a6a0..fe707fa8994 100644 --- a/tests/text/test_bertscore.py +++ b/tests/text/test_bertscore.py @@ -7,8 +7,8 @@ import torch.distributed as dist import torch.multiprocessing as mp -from torchmetrics.functional import bert_score as metrics_bert_score -from torchmetrics.text import BERTScore +from torchmetrics.functional.text.bert import bert_score as metrics_bert_score +from torchmetrics.text.bert import BERTScore from torchmetrics.utilities.imports import _BERTSCORE_AVAILABLE if _BERTSCORE_AVAILABLE: From ea070b0e67bc7d588edfbc9ca35e14ed15b144eb Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 25 Nov 2021 14:55:04 +0100 Subject: [PATCH 28/31] changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c636414acba..92032761926 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed +- Removed `embedding_similarity` metric ([#638](https://github.com/PyTorchLightning/metrics/pull/638)) +- Removed argument `concatenate_texts` from `wer` metric ([#638](https://github.com/PyTorchLightning/metrics/pull/638)) +- Removed arguments `newline_sep` and `decimal_places` from `rouge` metric ([#638](https://github.com/PyTorchLightning/metrics/pull/638)) + ### Fixed From fe2b16934d7ba9d4a8e1a4f045dcca448eee6690 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 25 Nov 2021 14:57:32 +0100 Subject: [PATCH 29/31] revert --- CHANGELOG.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 92032761926..c636414acba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,10 +36,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed -- Removed `embedding_similarity` metric ([#638](https://github.com/PyTorchLightning/metrics/pull/638)) -- Removed argument `concatenate_texts` from `wer` metric ([#638](https://github.com/PyTorchLightning/metrics/pull/638)) -- Removed arguments `newline_sep` and `decimal_places` from `rouge` metric ([#638](https://github.com/PyTorchLightning/metrics/pull/638)) - ### Fixed From 8f1e40674a2f9f969676ed3882454239bbe84d38 Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 6 Dec 2021 11:48:00 +0100 Subject: [PATCH 30/31] skip --- torchmetrics/detection/map.py | 9 ++++----- torchmetrics/functional/text/rouge.py | 4 ++-- torchmetrics/text/rouge.py | 2 +- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/torchmetrics/detection/map.py b/torchmetrics/detection/map.py index ccc376d3c67..c1cf6c96572 100644 --- a/torchmetrics/detection/map.py +++ b/torchmetrics/detection/map.py @@ -180,9 +180,8 @@ class MAP(Metric): will be used to perform the allgather Example: - >>> from torchmetrics.detection.map import MAP - >>> from pprint import pprint >>> import torch + >>> from torchmetrics.detection.map import MAP >>> preds = [ ... dict( ... boxes=torch.Tensor([[258.0, 41.0, 606.0, 285.0]]), @@ -196,9 +195,9 @@ class MAP(Metric): ... labels=torch.IntTensor([0]), ... ) ... ] - >>> metric = MAP() - >>> metric.update(preds, target) - >>> pprint(metric.compute()) # doctest: +NORMALIZE_WHITESPACE +SKIP + >>> metric = MAP() # doctest: +SKIP + >>> metric.update(preds, target) # doctest: +SKIP + >>> pprint(metric.compute()) # doctest: +SKIP {'map': tensor(0.6000), 'map_50': tensor(1.), 'map_75': tensor(1.), diff --git a/torchmetrics/functional/text/rouge.py b/torchmetrics/functional/text/rouge.py index 17e673c3d43..efaf66a8b04 100644 --- a/torchmetrics/functional/text/rouge.py +++ b/torchmetrics/functional/text/rouge.py @@ -180,7 +180,7 @@ def _rouge_score_update( >>> preds = "My name is John".split() >>> from pprint import pprint >>> score = _rouge_score_update(preds, targets, rouge_keys_values=[1, 2, 3, 'L']) - >>> pprint(score) # doctest: +NORMALIZE_WHITESPACE +SKIP + >>> pprint(score) # doctest: +SKIP {1: [{'fmeasure': tensor(0.), 'precision': tensor(0.), 'recall': tensor(0.)}, {'fmeasure': tensor(0.), 'precision': tensor(0.), 'recall': tensor(0.)}, {'fmeasure': tensor(0.), 'precision': tensor(0.), 'recall': tensor(0.)}, @@ -265,7 +265,7 @@ def rouge_score( >>> targets = "Is your name John" >>> preds = "My name is John" >>> from pprint import pprint - >>> pprint(rouge_score(preds, targets)) # doctest: +NORMALIZE_WHITESPACE +SKIP + >>> pprint(rouge_score(preds, targets)) # doctest: +SKIP {'rouge1_fmeasure': 0.25, 'rouge1_precision': 0.25, 'rouge1_recall': 0.25, diff --git a/torchmetrics/text/rouge.py b/torchmetrics/text/rouge.py index f24d66a95ad..0dc9875c9c7 100644 --- a/torchmetrics/text/rouge.py +++ b/torchmetrics/text/rouge.py @@ -47,7 +47,7 @@ class ROUGEScore(Metric): >>> preds = "My name is John" >>> rouge = ROUGEScore() # doctest: +SKIP >>> from pprint import pprint - >>> pprint(rouge(preds, targets)) # doctest: +NORMALIZE_WHITESPACE +SKIP + >>> pprint(rouge(preds, targets)) # doctest: +SKIP {'rouge1_fmeasure': 0.25, 'rouge1_precision': 0.25, 'rouge1_recall': 0.25, From 64d87022033094ec247af455fef96299b8d12059 Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 6 Dec 2021 11:51:29 +0100 Subject: [PATCH 31/31] skip --- torchmetrics/detection/map.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmetrics/detection/map.py b/torchmetrics/detection/map.py index c1cf6c96572..4d18f5be2cd 100644 --- a/torchmetrics/detection/map.py +++ b/torchmetrics/detection/map.py @@ -197,6 +197,7 @@ class MAP(Metric): ... ] >>> metric = MAP() # doctest: +SKIP >>> metric.update(preds, target) # doctest: +SKIP + >>> from pprint import pprint >>> pprint(metric.compute()) # doctest: +SKIP {'map': tensor(0.6000), 'map_50': tensor(1.),