diff --git a/.github/assistant.py b/.github/assistant.py
index 3f49ccb0458..0dc82b17eb5 100644
--- a/.github/assistant.py
+++ b/.github/assistant.py
@@ -59,7 +59,7 @@ def prune_packages(req_file: str, *pkgs: str) -> None:
             lines = [ln for ln in lines if not ln.startswith(pkg)]
         logging.info(lines)
-        with open(req_file, "w") as fp:
+        with open(req_file, "w", encoding="utf-8") as fp:
             fp.writelines(lines)

     @staticmethod
@@ -71,17 +71,17 @@ def set_min_torch_by_python(fpath: str = "requirements.txt") -> None:
         with open(fpath) as fp:
             req = fp.read()
         req = re.sub(r"torch>=[\d\.]+", f"torch>={LUT_PYTHON_TORCH[py_ver]}", req)
-        with open(fpath, "w") as fp:
+        with open(fpath, "w", encoding="utf-8") as fp:
             fp.write(req)

     @staticmethod
     def replace_min_requirements(fpath: str) -> None:
         """Replace all `>=` by `==` in given file."""
         logging.info(f"processing: {fpath}")
-        with open(fpath) as fp:
+        with open(fpath, encoding="utf-8") as fp:
             req = fp.read()
         req = req.replace(">=", "==")
-        with open(fpath, "w") as fp:
+        with open(fpath, "w", encoding="utf-8") as fp:
             fp.write(req)

     @staticmethod
diff --git a/CHANGELOG.md b/CHANGELOG.md
index b3297261ac0..b5137f2b3cb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -41,6 +41,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 - Removed deprecated `compute_on_step` argument in Regression ([#967](https://github.com/PyTorchLightning/metrics/pull/967))

+- Removed deprecated `compute_on_step` argument in Image ([#979](https://github.com/PyTorchLightning/metrics/pull/979))
+
+
 ### Fixed

 - Fixed "Sort currently does not support bool dtype on CUDA" error in MAP for empty preds ([#983](https://github.com/PyTorchLightning/metrics/pull/983))
@@ -49,7 +52,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 - Fixed `BinnedPrecisionRecallCurve` when `thresholds` argument is not provided ([#968](https://github.com/PyTorchLightning/metrics/pull/968))

-- Fixed `CalibrationError` to work on logit input ([]())
+- Fixed `CalibrationError` to work on logit input ([#985](https://github.com/PyTorchLightning/metrics/pull/985))


 ## [0.8.0] - 2022-04-14
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 6db900fa9aa..dfd10440221 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -65,7 +65,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
         elif ln.startswith("### "):
             ln = ln.replace("###", f"### {chlog_ver} -")
             chlog_lines[i] = ln
-    with open(path_out, "w") as fp:
+    with open(path_out, "w", encoding="utf-8") as fp:
         fp.writelines(chlog_lines)


diff --git a/tests/audio/test_sdr.py b/tests/audio/test_sdr.py
index 15e31d80664..6f49caddded 100644
--- a/tests/audio/test_sdr.py
+++ b/tests/audio/test_sdr.py
@@ -87,7 +87,7 @@ def test_sdr(self, preds, target, sk_metric, ddp, dist_sync_on_step):
             SignalDistortionRatio,
             sk_metric=partial(average_metric, metric_func=sk_metric),
             dist_sync_on_step=dist_sync_on_step,
-            metric_args=dict(),
+            metric_args={},
         )

     def test_sdr_functional(self, preds, target, sk_metric):
@@ -96,7 +96,7 @@ def test_sdr_functional(self, preds, target, sk_metric):
             target,
             signal_distortion_ratio,
             sk_metric,
-            metric_args=dict(),
+            metric_args={},
         )

     def test_sdr_differentiability(self, preds, target, sk_metric):
@@ -105,7 +105,7 @@ def test_sdr_differentiability(self, preds, target, sk_metric):
             target=target,
             metric_module=SignalDistortionRatio,
             metric_functional=signal_distortion_ratio,
-            metric_args=dict(),
+            metric_args={},
         )

     @pytest.mark.skipif(
@@ -117,7 +117,7 @@ def test_sdr_half_cpu(self, preds, target, sk_metric):
             target=target,
             metric_module=SignalDistortionRatio,
             metric_functional=signal_distortion_ratio,
-            metric_args=dict(),
+            metric_args={},
         )

     @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
@@ -127,7 +127,7 @@ def test_sdr_half_gpu(self, preds, target, sk_metric):
             target=target,
             metric_module=SignalDistortionRatio,
             metric_functional=signal_distortion_ratio,
-            metric_args=dict(),
+            metric_args={},
         )
diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py
index 388d77f5a46..f63e3d80814 100644
--- a/tests/detection/test_map.py
+++ b/tests/detection/test_map.py
@@ -377,7 +377,7 @@ def test_error_on_wrong_input():
         metric.update([], torch.Tensor())  # type: ignore

     with pytest.raises(ValueError, match="Expected argument `preds` and `target` to have the same length"):
-        metric.update([dict()], [dict(), dict()])
+        metric.update([{}], [{}, {}])

     with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `boxes` key"):
         metric.update(
diff --git a/torchmetrics/functional/classification/auroc.py b/torchmetrics/functional/classification/auroc.py
index 548ee9af2f6..ecc1c7eac72 100644
--- a/torchmetrics/functional/classification/auroc.py
+++ b/torchmetrics/functional/classification/auroc.py
@@ -111,9 +111,8 @@ def _auroc_compute(
         # max_fpr parameter is only support for binary
         if mode != DataType.BINARY:
             raise ValueError(
-                f"Partial AUC computation not available in"
-                f" multilabel/multiclass setting, 'max_fpr' must be"
-                f" set to `None`, received `{max_fpr}`."
+                "Partial AUC computation not available in multilabel/multiclass setting,"
+                f" 'max_fpr' must be set to `None`, received `{max_fpr}`."
             )

     # calculate fpr, tpr
@@ -172,7 +171,7 @@ def _auroc_compute(

             allowed_average = (AverageMethod.NONE.value, AverageMethod.MACRO.value, AverageMethod.WEIGHTED.value)
             raise ValueError(
-                f"Argument `average` expected to be one of the following:" f" {allowed_average} but got {average}"
+                f"Argument `average` expected to be one of the following: {allowed_average} but got {average}"
             )

     return _auc_compute_without_check(fpr, tpr, 1.0)
diff --git a/torchmetrics/functional/text/rouge.py b/torchmetrics/functional/text/rouge.py
index d0e812c8312..1f3eb5fb950 100644
--- a/torchmetrics/functional/text/rouge.py
+++ b/torchmetrics/functional/text/rouge.py
@@ -458,7 +458,7 @@ def rouge_score(
     stemmer = nltk.stem.porter.PorterStemmer() if use_stemmer else None

     if not isinstance(rouge_keys, tuple):
-        rouge_keys = tuple([rouge_keys])
+        rouge_keys = (rouge_keys,)
     for key in rouge_keys:
         if key not in ALLOWED_ROUGE_KEYS.keys():
             raise ValueError(f"Got unknown rouge key {key}. Expected to be one of {list(ALLOWED_ROUGE_KEYS.keys())}")
diff --git a/torchmetrics/functional/text/squad.py b/torchmetrics/functional/text/squad.py
index f0f2a3e9ee0..9e1065ec701 100644
--- a/torchmetrics/functional/text/squad.py
+++ b/torchmetrics/functional/text/squad.py
@@ -133,7 +133,7 @@ def _squad_input_check(
     _fn_answer = lambda tgt: dict(
         answers=[dict(text=txt) for txt in tgt["answers"]["text"]], id=tgt["id"]  # type: ignore
     )
-    targets_dict = [dict(paragraphs=[dict(qas=[_fn_answer(target) for target in targets])])]
+    targets_dict = [{"paragraphs": [{"qas": [_fn_answer(target) for target in targets]}]}]
     return preds_dict, targets_dict


diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py
index 22a902eafb4..94255e77985 100644
--- a/torchmetrics/image/fid.py
+++ b/torchmetrics/image/fid.py
@@ -148,10 +148,6 @@ class FrechetInceptionDistance(Metric):
         is installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torch-fidelity``

-    .. note:: the ``forward`` method can be used but ``compute_on_step`` is disabled by default (oppesit of
-        all other metrics) as this metric does not really make sense to calculate on a single batch. This
-        means that by default ``forward`` will just call ``update`` underneat.
-
     Args:
         feature: Either an integer or ``nn.Module``:
@@ -164,13 +160,6 @@ class FrechetInceptionDistance(Metric):
         reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not
             change, the features can cached them to avoid recomputing them which is costly. Set this to ``False`` if
             your dataset does not change.
-
-        compute_on_step:
-            Forward only calls ``update()`` and returns None if this is set to False.
-
-            .. deprecated:: v0.8
-                Argument has no use anymore and will be removed v0.9.
-
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

     References:
@@ -215,10 +204,9 @@ def __init__(
         self,
         feature: Union[int, torch.nn.Module] = 2048,
         reset_real_features: bool = True,
-        compute_on_step: Optional[bool] = None,
         **kwargs: Dict[str, Any],
     ) -> None:
-        super().__init__(compute_on_step=compute_on_step, **kwargs)
+        super().__init__(**kwargs)

         rank_zero_warn(
             "Metric `FrechetInceptionDistance` will save all extracted features in buffer."
diff --git a/torchmetrics/image/inception.py b/torchmetrics/image/inception.py
index cb9550d91e1..4df77dc0814 100644
--- a/torchmetrics/image/inception.py
+++ b/torchmetrics/image/inception.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union

 import torch
 from torch import Tensor
@@ -46,10 +46,6 @@ class InceptionScore(Metric):
         is installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torch-fidelity``

-    .. note:: the ``forward`` method can be used but ``compute_on_step`` is disabled by default (oppesit of
-        all other metrics) as this metric does not really make sense to calculate on a single batch. This
-        means that by default ``forward`` will just call ``update`` underneat.
-
     Args:
         feature: Either an str, integer or ``nn.Module``:
@@ -60,13 +56,6 @@ class InceptionScore(Metric):
             an ``[N,d]`` matrix where ``N`` is the batch size and ``d`` is the feature size.

         splits: integer determining how many splits the inception score calculation should be split among
-
-        compute_on_step:
-            Forward only calls ``update()`` and returns None if this is set to False.
-
-            .. deprecated:: v0.8
-                Argument has no use anymore and will be removed v0.9.
-
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

     References:
@@ -105,10 +94,9 @@ def __init__(
         self,
         feature: Union[str, int, torch.nn.Module] = "logits_unbiased",
         splits: int = 10,
-        compute_on_step: Optional[bool] = None,
         **kwargs: Dict[str, Any],
     ) -> None:
-        super().__init__(compute_on_step=compute_on_step, **kwargs)
+        super().__init__(**kwargs)

         rank_zero_warn(
             "Metric `InceptionScore` will save all extracted features in buffer."
diff --git a/torchmetrics/image/kid.py b/torchmetrics/image/kid.py
index b071e791b49..ece20ea7a6f 100644
--- a/torchmetrics/image/kid.py
+++ b/torchmetrics/image/kid.py
@@ -89,10 +89,6 @@ class KernelInceptionDistance(Metric):
         is installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torch-fidelity``

-    .. note:: the ``forward`` method can be used but ``compute_on_step`` is disabled by default (oppesit of
-        all other metrics) as this metric does not really make sense to calculate on a single batch. This
-        means that by default ``forward`` will just call ``update`` underneat.
-
     Args:
         feature: Either an str, integer or ``nn.Module``:
@@ -109,12 +105,6 @@ class KernelInceptionDistance(Metric):
         reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not
             change, the features can cached them to avoid recomputing them which is costly. Set this to ``False`` if
             your dataset does not change.
-        compute_on_step:
-            Forward only calls ``update()`` and returns None if this is set to False.
-
-            .. deprecated:: v0.8
-                Argument has no use anymore and will be removed v0.9.
-
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

     References:
@@ -173,10 +163,9 @@ def __init__(
         gamma: Optional[float] = None,  # type: ignore
         coef: float = 1.0,
         reset_real_features: bool = True,
-        compute_on_step: Optional[bool] = None,
         **kwargs: Dict[str, Any],
     ) -> None:
-        super().__init__(compute_on_step=compute_on_step, **kwargs)
+        super().__init__(**kwargs)

         rank_zero_warn(
             "Metric `Kernel Inception Distance` will save all extracted features in buffer."
diff --git a/torchmetrics/image/lpip.py b/torchmetrics/image/lpip.py
index 09ba82d852d..54c920b89d5 100644
--- a/torchmetrics/image/lpip.py
+++ b/torchmetrics/image/lpip.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List

 import torch
 from torch import Tensor
@@ -59,12 +59,6 @@ class LearnedPerceptualImagePatchSimilarity(Metric):
     Args:
         net_type: str indicating backbone network type to use. Choose between `'alex'`, `'vgg'` or `'squeeze'`
         reduction: str indicating how to reduce over the batch dimension. Choose between `'sum'` or `'mean'`.
-        compute_on_step:
-            Forward only calls ``update()`` and returns None if this is set to False.
-
-            .. deprecated:: v0.8
-                Argument has no use anymore and will be removed v0.9.
-
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

     Raises:
@@ -98,10 +92,9 @@ def __init__(
         self,
         net_type: str = "alex",
         reduction: Literal["sum", "mean"] = "mean",
-        compute_on_step: Optional[bool] = None,
         **kwargs: Dict[str, Any],
     ) -> None:
-        super().__init__(compute_on_step=compute_on_step, **kwargs)
+        super().__init__(**kwargs)

         if not _LPIPS_AVAILABLE:
             raise ModuleNotFoundError(
diff --git a/torchmetrics/image/psnr.py b/torchmetrics/image/psnr.py
index 18b89452045..d2f2a7a8ccf 100644
--- a/torchmetrics/image/psnr.py
+++ b/torchmetrics/image/psnr.py
@@ -44,12 +44,6 @@ class PeakSignalNoiseRatio(Metric):
         dim:
             Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is
             None meaning scores will be reduced across all dimensions and all batches.
-        compute_on_step:
-            Forward only calls ``update()`` and returns None if this is set to False.
-
-            .. deprecated:: v0.8
-                Argument has no use anymore and will be removed v0.9.
-
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

     Raises:
@@ -78,10 +72,9 @@ def __init__(
         base: float = 10.0,
         reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
         dim: Optional[Union[int, Tuple[int, ...]]] = None,
-        compute_on_step: Optional[bool] = None,
         **kwargs: Dict[str, Any],
     ) -> None:
-        super().__init__(compute_on_step=compute_on_step, **kwargs)
+        super().__init__(**kwargs)

         if dim is None and reduction != "elementwise_mean":
             rank_zero_warn(f"The `reduction={reduction}` will not have any effect when `dim` is None.")
diff --git a/torchmetrics/image/ssim.py b/torchmetrics/image/ssim.py
index 1302e39dc72..5042f5a744d 100644
--- a/torchmetrics/image/ssim.py
+++ b/torchmetrics/image/ssim.py
@@ -47,12 +47,6 @@ class StructuralSimilarityIndexMeasure(Metric):
         return_contrast_sensitivity: If true, the constant term is returned as a second argument.
             The luminance term can be obtained with luminance=ssim/contrast
             Mutually exclusive with ``return_full_image``
-        compute_on_step:
-            Forward only calls ``update()`` and returns None if this is set to False.
-
-            .. deprecated:: v0.8
-                Argument has no use anymore and will be removed v0.9.
-
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

     Return:
@@ -81,12 +75,11 @@ def __init__(
         data_range: Optional[float] = None,
         k1: float = 0.01,
         k2: float = 0.03,
-        compute_on_step: Optional[bool] = None,
         return_full_image: bool = False,
         return_contrast_sensitivity: bool = False,
         **kwargs: Dict[str, Any],
     ) -> None:
-        super().__init__(compute_on_step=compute_on_step, **kwargs)
+        super().__init__(**kwargs)
         rank_zero_warn(
             "Metric `SSIM` will save all targets and"
             " predictions in buffer. For large datasets this may lead"
@@ -157,12 +150,6 @@ class MultiScaleStructuralSimilarityIndexMeasure(Metric):
         normalize: When MultiScaleStructuralSimilarityIndexMeasure loss is used for training, it is desirable to use
             normalizes to improve the training stability. This `normalize` argument is out of scope of the original
             implementation [1], and it is adapted from https://github.com/jorge-pessoa/pytorch-msssim instead.
-        compute_on_step:
-            Forward only calls ``update()`` and returns None if this is set to False.
-
-            .. deprecated:: v0.8
-                Argument has no use anymore and will be removed v0.9.
-
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

     Return:
@@ -206,10 +193,9 @@ def __init__(
         k2: float = 0.03,
         betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
         normalize: Literal["relu", "simple", None] = None,
-        compute_on_step: Optional[bool] = None,
         **kwargs: Dict[str, Any],
     ) -> None:
-        super().__init__(compute_on_step=compute_on_step, **kwargs)
+        super().__init__(**kwargs)
         rank_zero_warn(
             "Metric `MS_SSIM` will save all targets and"
             " predictions in buffer. For large datasets this may lead"
@@ -219,7 +205,7 @@ def __init__(
         self.add_state("preds", default=[], dist_reduce_fx="cat")
         self.add_state("target", default=[], dist_reduce_fx="cat")

-        if not (isinstance(kernel_size, Sequence) or isinstance(kernel_size, int)):
+        if not (isinstance(kernel_size, (Sequence, int))):
             raise ValueError(
                 f"Argument `kernel_size` expected to be an sequence or an int, or a single int. Got {kernel_size}"
             )
diff --git a/torchmetrics/image/uqi.py b/torchmetrics/image/uqi.py
index bd4d0f93981..904df2b501d 100644
--- a/torchmetrics/image/uqi.py
+++ b/torchmetrics/image/uqi.py
@@ -35,12 +35,6 @@ class UniversalImageQualityIndex(Metric):
             - ``'none'`` or ``None``: no reduction will be applied

         data_range: Range of the image. If ``None``, it is determined from the image (max - min)
-        compute_on_step:
-            Forward only calls ``update()`` and returns None if this is set to False.
-
-            .. deprecated:: v0.8
-                Argument has no use anymore and will be removed v0.9.
-
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.


@@ -67,10 +61,9 @@ def __init__(
         sigma: Sequence[float] = (1.5, 1.5),
         reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
         data_range: Optional[float] = None,
-        compute_on_step: Optional[bool] = None,
         **kwargs: Dict[str, Any],
     ) -> None:
-        super().__init__(compute_on_step=compute_on_step, **kwargs)
+        super().__init__(**kwargs)
         rank_zero_warn(
             "Metric `UniversalImageQualityIndex` will save all targets and"
             " predictions in buffer. For large datasets this may lead"
diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index 37db9015f92..5e4484ef119 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -583,7 +583,7 @@ def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
             k: v for k, v in kwargs.items() if (k in _sign_params.keys() and _sign_params[k].kind not in _params)
         }

-        exists_var_keyword = any([v.kind == inspect.Parameter.VAR_KEYWORD for v in _sign_params.values()])
+        exists_var_keyword = any(v.kind == inspect.Parameter.VAR_KEYWORD for v in _sign_params.values())
         # if no kwargs filtered, return all kwargs as default
         if not filtered_kwargs and not exists_var_keyword:
             # no kwargs in update signature -> don't return any kwargs
diff --git a/torchmetrics/text/rouge.py b/torchmetrics/text/rouge.py
index faa14f92855..3fda606844a 100644
--- a/torchmetrics/text/rouge.py
+++ b/torchmetrics/text/rouge.py
@@ -109,7 +109,7 @@ def __init__(
             import nltk

         if not isinstance(rouge_keys, tuple):
-            rouge_keys = tuple([rouge_keys])
+            rouge_keys = (rouge_keys,)
         for key in rouge_keys:
             if key not in ALLOWED_ROUGE_KEYS:
                 raise ValueError(f"Got unknown rouge key {key}. Expected to be one of {ALLOWED_ROUGE_KEYS}")
diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py
index cad5d73b0f9..f688a873c80 100644
--- a/torchmetrics/utilities/checks.py
+++ b/torchmetrics/utilities/checks.py
@@ -243,7 +243,7 @@ def _check_classification_inputs(
             either from the shape of inputs, or the maximum label in the ``target`` and ``preds`` tensor, where
             applicable.
         top_k:
-            Number of highest probability entries for each sample to convert to 1s - relevant
+            Number of the highest probability entries for each sample to convert to 1s - relevant
             only for inputs with probability predictions. The default value (``None``) will be interpreted
             as 1 for these inputs. If this parameter is set for multi-label inputs, it will take precedence over
             threshold.
@@ -342,7 +342,7 @@ def _input_format_classification(
     In binary case, targets are normally returned as ``(N,1)`` tensor, while preds are transformed
     into a binary tensor (elements become 1 if the probability is greater than or equal to
-    ``threshold`` or 0 otherwise). If ``multiclass=True``, then then both targets are preds
+    ``threshold`` or 0 otherwise). If ``multiclass=True``, then both targets and preds
     become ``(N, 2)`` tensors by a one-hot transformation; with the thresholding being applied to
     preds first.

@@ -461,7 +461,7 @@ def _input_format_classification_one_hot(
     Args:
         num_classes: number of classes
         preds: either tensor with labels, tensor with probabilities/logits or multilabel tensor
-        target: tensor with ground true labels
+        target: tensor with ground truth labels
         threshold: float used for thresholding multilabel input
         multilabel: boolean flag indicating if input is multilabel

@@ -503,7 +503,7 @@ def _check_retrieval_functional_inputs(
     target: Tensor,
     allow_non_binary_target: bool = False,
 ) -> Tuple[Tensor, Tensor]:
-    """Check ``preds`` and ``target`` tensors are of the same shape and of the correct dtype.
+    """Check ``preds`` and ``target`` tensors are of the same shape and of the correct data type.

     Args:
         preds: either tensor with scores/logits
@@ -535,7 +535,7 @@ def _check_retrieval_inputs(
     allow_non_binary_target: bool = False,
     ignore_index: Optional[int] = None,
 ) -> Tuple[Tensor, Tensor, Tensor]:
-    """Check ``indexes``, ``preds`` and ``target`` tensors are of the same shape and of the correct dtype.
+    """Check ``indexes``, ``preds`` and ``target`` tensors are of the same shape and of the correct data type.

     Args:
         indexes: tensor with queries indexes
@@ -580,7 +580,7 @@ def _check_retrieval_target_and_prediction_types(
     target: Tensor,
     allow_non_binary_target: bool = False,
 ) -> Tuple[Tensor, Tensor]:
-    """Check ``preds`` and ``target`` tensors are of the same shape and of the correct dtype.
+    """Check ``preds`` and ``target`` tensors are of the same shape and of the correct data type.

     Args:
         preds: either tensor with scores/logits
diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py
index 003f5a66d9c..7ed5242c2b1 100644
--- a/torchmetrics/utilities/data.py
+++ b/torchmetrics/utilities/data.py
@@ -157,10 +157,10 @@ def apply_to_collection(
         data: the collection to apply the function to
         dtype: the given function will be applied to all elements of this dtype
         function: the function to apply
-        *args: positional arguments (will be forwarded to calls of ``function``)
+        *args: positional arguments (will be forwarded to call of ``function``)
         wrong_dtype: the given function won't be applied if this type is specified and the given collections is of
             the :attr:`wrong_type` even if it is of type :attr`dtype`
-        **kwargs: keyword arguments (will be forwarded to calls of ``function``)
+        **kwargs: keyword arguments (will be forwarded to call of ``function``)

     Returns:
         the resulting collection
diff --git a/torchmetrics/utilities/distributed.py b/torchmetrics/utilities/distributed.py
index f33f864bb04..46517ac8ac3 100644
--- a/torchmetrics/utilities/distributed.py
+++ b/torchmetrics/utilities/distributed.py
@@ -19,11 +19,11 @@
 from typing_extensions import Literal


-def reduce(to_reduce: Tensor, reduction: Literal["elementwise_mean", "sum", "none", None]) -> Tensor:
+def reduce(x: Tensor, reduction: Literal["elementwise_mean", "sum", "none", None]) -> Tensor:
     """Reduces a given tensor by a given reduction method.

     Args:
-        to_reduce: the tensor, which shall be reduced
+        x: the tensor, which shall be reduced
         reduction: a string specifying the reduction method ('elementwise_mean', 'none', 'sum')

     Return:
@@ -33,11 +33,11 @@ def reduce(to_reduce: Tensor, reduction: Literal["elementwise_mean", "sum", "non
         ValueError if an invalid reduction parameter was given
     """
     if reduction == "elementwise_mean":
-        return torch.mean(to_reduce)
+        return torch.mean(x)
     if reduction == "none" or reduction is None:
-        return to_reduce
+        return x
     if reduction == "sum":
-        return torch.sum(to_reduce)
+        return torch.sum(x)
     raise ValueError("Reduction parameter unknown.")
diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py
index 29e596717d6..6598eb5ea86 100644
--- a/torchmetrics/utilities/imports.py
+++ b/torchmetrics/utilities/imports.py
@@ -14,6 +14,7 @@
 """Import utilities."""
 import operator
 from collections import OrderedDict  # noqa: F401
+from functools import lru_cache
 from importlib import import_module
 from importlib.util import find_spec
 from typing import Callable, Optional
@@ -22,6 +23,7 @@
 from pkg_resources import DistributionNotFound, get_distribution


+@lru_cache()
 def _package_available(package_name: str) -> bool:
     """Check if a package is available in your environment.

@@ -40,6 +42,7 @@ def _package_available(package_name: str) -> bool:
     return False


+@lru_cache()
 def _module_available(module_path: str) -> bool:
     """Check if a module path is available in your environment.

@@ -64,6 +67,7 @@ def _module_available(module_path: str) -> bool:
     return True


+@lru_cache()
 def _compare_version(package: str, op: Callable, version: str) -> Optional[bool]:
     """Compare package version with some requirements.