diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py
index 30a92f96d20..55342eeb2d6 100644
--- a/tests/bases/test_metric.py
+++ b/tests/bases/test_metric.py
@@ -20,7 +20,8 @@
 import psutil
 import pytest
 import torch
-from torch import Tensor, nn, tensor
+from torch import Tensor, tensor
+from torch.nn import Module
 
 from tests.helpers import seed_all
 from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
@@ -245,7 +246,7 @@ def test_load_state_dict(tmpdir):
 def test_child_metric_state_dict():
     """test that child metric states will be added to parent state dict."""
 
-    class TestModule(nn.Module):
+    class TestModule(Module):
         def __init__(self):
             super().__init__()
             self.metric = DummyMetric()
@@ -346,7 +347,7 @@ def test_forward_and_compute_to_device(metric_class):
 def test_device_if_child_module(metric_class):
     """Test that if a metric is a child module all values gets moved to the correct device."""
 
-    class TestModule(nn.Module):
+    class TestModule(Module):
         def __init__(self):
             super().__init__()
             self.metric = metric_class()
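Reviewer note: the refactor is purely mechanical — `from torch import Tensor` and `from torch.nn import Module` bind the very same objects as the `torch.Tensor` / `torch.nn.Module` attribute paths, so `isinstance` checks and subclassing are unaffected. A minimal sanity check (plain PyTorch, not part of this diff):

```python
import torch
from torch import Tensor
from torch.nn import Module

# Direct imports alias the same classes; nothing behavioural changes.
assert Tensor is torch.Tensor
assert Module is torch.nn.Module

# Caveat these tests rely on: Tensor(...) is the legacy constructor and always
# produces float32, while torch.tensor(...) infers the dtype from its data.
assert Tensor([1, 2]).dtype == torch.float32
assert torch.tensor([1, 2]).dtype == torch.int64
```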
diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py
index f63e3d80814..5c2a35da4d6 100644
--- a/tests/detection/test_map.py
+++ b/tests/detection/test_map.py
@@ -16,6 +16,7 @@
 
 import pytest
 import torch
+from torch import Tensor
 
 from tests.helpers.testers import MetricTester
 from torchmetrics.detection.mean_ap import MeanAveragePrecision
@@ -27,19 +28,19 @@
     preds=[
         [
             dict(
-                boxes=torch.Tensor([[258.15, 41.29, 606.41, 285.07]]),
-                scores=torch.Tensor([0.236]),
+                boxes=Tensor([[258.15, 41.29, 606.41, 285.07]]),
+                scores=Tensor([0.236]),
                 labels=torch.IntTensor([4]),
             ),  # coco image id 42
             dict(
-                boxes=torch.Tensor([[61.00, 22.75, 565.00, 632.42], [12.66, 3.32, 281.26, 275.23]]),
-                scores=torch.Tensor([0.318, 0.726]),
+                boxes=Tensor([[61.00, 22.75, 565.00, 632.42], [12.66, 3.32, 281.26, 275.23]]),
+                scores=Tensor([0.318, 0.726]),
                 labels=torch.IntTensor([3, 2]),
             ),  # coco image id 73
         ],
         [
             dict(
-                boxes=torch.Tensor(
+                boxes=Tensor(
                     [
                         [87.87, 276.25, 384.29, 379.43],
                         [0.00, 3.66, 142.15, 316.06],
@@ -50,12 +51,12 @@
                         [276.11, 103.84, 291.44, 150.72],
                     ]
                 ),
-                scores=torch.Tensor([0.546, 0.3, 0.407, 0.611, 0.335, 0.805, 0.953]),
+                scores=Tensor([0.546, 0.3, 0.407, 0.611, 0.335, 0.805, 0.953]),
                 labels=torch.IntTensor([4, 1, 0, 0, 0, 0, 0]),
             ),  # coco image id 74
             dict(
-                boxes=torch.Tensor([[0.00, 2.87, 601.00, 421.52]]),
-                scores=torch.Tensor([0.699]),
+                boxes=Tensor([[0.00, 2.87, 601.00, 421.52]]),
+                scores=Tensor([0.699]),
                 labels=torch.IntTensor([5]),
             ),  # coco image id 133
         ],
@@ -63,11 +64,11 @@
     target=[
         [
             dict(
-                boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]),
+                boxes=Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]),
                 labels=torch.IntTensor([4]),
             ),  # coco image id 42
             dict(
-                boxes=torch.Tensor(
+                boxes=Tensor(
                     [
                         [13.00, 22.75, 548.98, 632.42],
                         [1.66, 3.32, 270.26, 275.23],
@@ -78,7 +79,7 @@
         ],
         [
             dict(
-                boxes=torch.Tensor(
+                boxes=Tensor(
                     [
                         [61.87, 276.25, 358.29, 379.43],
                         [2.75, 3.66, 162.15, 316.06],
@@ -92,7 +93,7 @@
                 labels=torch.IntTensor([4, 1, 0, 0, 0, 0, 0]),
             ),  # coco image id 74
             dict(
-                boxes=torch.Tensor([[13.99, 2.87, 640.00, 421.52]]),
+                boxes=Tensor([[13.99, 2.87, 640.00, 421.52]]),
                 labels=torch.IntTensor([5]),
             ),  # coco image id 133
         ],
@@ -104,15 +105,15 @@
     preds=[
         [
             dict(
-                boxes=torch.Tensor([[258.0, 41.0, 606.0, 285.0]]),
-                scores=torch.Tensor([0.536]),
+                boxes=Tensor([[258.0, 41.0, 606.0, 285.0]]),
+                scores=Tensor([0.536]),
                 labels=torch.IntTensor([0]),
             ),
        ],
        [
            dict(
-                boxes=torch.Tensor([[258.0, 41.0, 606.0, 285.0]]),
-                scores=torch.Tensor([0.536]),
+                boxes=Tensor([[258.0, 41.0, 606.0, 285.0]]),
+                scores=Tensor([0.536]),
                 labels=torch.IntTensor([0]),
             )
         ],
@@ -120,13 +121,13 @@
     target=[
         [
             dict(
-                boxes=torch.Tensor([[214.0, 41.0, 562.0, 285.0]]),
+                boxes=Tensor([[214.0, 41.0, 562.0, 285.0]]),
                 labels=torch.IntTensor([0]),
             )
         ],
         [
             dict(
-                boxes=torch.Tensor([]),
+                boxes=Tensor([]),
                 labels=torch.IntTensor([]),
             )
         ],
@@ -196,20 +197,20 @@ def _compare_fn(preds, target) -> dict:
     Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.900
     """
     return {
-        "map": torch.Tensor([0.706]),
-        "map_50": torch.Tensor([0.901]),
-        "map_75": torch.Tensor([0.846]),
-        "map_small": torch.Tensor([0.689]),
-        "map_medium": torch.Tensor([0.800]),
-        "map_large": torch.Tensor([0.701]),
-        "mar_1": torch.Tensor([0.592]),
-        "mar_10": torch.Tensor([0.716]),
-        "mar_100": torch.Tensor([0.716]),
-        "mar_small": torch.Tensor([0.767]),
-        "mar_medium": torch.Tensor([0.800]),
-        "mar_large": torch.Tensor([0.700]),
-        "map_per_class": torch.Tensor([0.725, 0.800, 0.454, -1.000, 0.650, 0.900]),
-        "mar_100_per_class": torch.Tensor([0.780, 0.800, 0.450, -1.000, 0.650, 0.900]),
+        "map": Tensor([0.706]),
+        "map_50": Tensor([0.901]),
+        "map_75": Tensor([0.846]),
+        "map_small": Tensor([0.689]),
+        "map_medium": Tensor([0.800]),
+        "map_large": Tensor([0.701]),
+        "mar_1": Tensor([0.592]),
+        "mar_10": Tensor([0.716]),
+        "mar_100": Tensor([0.716]),
+        "mar_small": Tensor([0.767]),
+        "mar_medium": Tensor([0.800]),
+        "mar_large": Tensor([0.700]),
+        "map_per_class": Tensor([0.725, 0.800, 0.454, -1.000, 0.650, 0.900]),
+        "mar_100_per_class": Tensor([0.780, 0.800, 0.450, -1.000, 0.650, 0.900]),
     }
 
 
@@ -260,7 +261,7 @@ def test_empty_preds():
 
     metric.update(
         [
-            dict(boxes=torch.Tensor([]), scores=torch.Tensor([]), labels=torch.IntTensor([])),
+            dict(boxes=Tensor([]), scores=torch.Tensor([]), labels=torch.IntTensor([])),
         ],
         [
             dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=torch.IntTensor([4])),
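For context on the fixtures above, here is a minimal end-to-end use of the metric under test (a sketch, not part of this diff; assumes the `pycocotools` backend is installed):

```python
import torch
from torch import Tensor
from torchmetrics.detection.mean_ap import MeanAveragePrecision

# One image: a single predicted box with its confidence, against one ground-truth box.
preds = [dict(boxes=Tensor([[258.0, 41.0, 606.0, 285.0]]), scores=Tensor([0.536]), labels=torch.IntTensor([0]))]
target = [dict(boxes=Tensor([[214.0, 41.0, 562.0, 285.0]]), labels=torch.IntTensor([0]))]

metric = MeanAveragePrecision()
metric.update(preds, target)
result = metric.compute()  # dict with the "map", "map_50", ..., "mar_100" keys checked in _compare_fn
```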
diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py
index a33b59a8663..aeda9a95c98 100644
--- a/tests/helpers/testers.py
+++ b/tests/helpers/testers.py
@@ -178,7 +178,7 @@ def _class_test(
         batch_result = metric(preds[i], target[i], **batch_kwargs_update)
 
         if metric.dist_sync_on_step and check_dist_sync_on_step and rank == 0:
-            if isinstance(preds, torch.Tensor):
+            if isinstance(preds, Tensor):
                 ddp_preds = torch.cat([preds[i + r] for r in range(worldsize)]).cpu()
                 ddp_target = torch.cat([target[i + r] for r in range(worldsize)]).cpu()
             else:
@@ -201,8 +201,8 @@
                 k: v.cpu() if isinstance(v, Tensor) else v
                 for k, v in (batch_kwargs_update if fragment_kwargs else kwargs_update).items()
             }
-            preds_ = preds[i].cpu() if isinstance(preds, torch.Tensor) else preds[i]
-            target_ = target[i].cpu() if isinstance(target, torch.Tensor) else target[i]
+            preds_ = preds[i].cpu() if isinstance(preds, Tensor) else preds[i]
+            target_ = target[i].cpu() if isinstance(target, Tensor) else target[i]
             sk_batch_result = sk_metric(preds_, target_, **batch_kwargs_update)
             if isinstance(batch_result, dict):
                 for key in batch_result.keys():
@@ -221,7 +221,7 @@
         else:
             _assert_tensor(result)
 
-    if isinstance(preds, torch.Tensor):
+    if isinstance(preds, Tensor):
         total_preds = torch.cat([preds[i] for i in range(num_batches)]).cpu()
         total_target = torch.cat([target[i] for i in range(num_batches)]).cpu()
     else:
diff --git a/tests/image/test_fid.py b/tests/image/test_fid.py
index dd44af83c50..dfabd378da4 100644
--- a/tests/image/test_fid.py
+++ b/tests/image/test_fid.py
@@ -16,6 +16,7 @@
 import pytest
 import torch
 from scipy.linalg import sqrtm as scipy_sqrtm
+from torch.nn import Module
 from torch.utils.data import Dataset
 
 from torchmetrics.image.fid import FrechetInceptionDistance, sqrtm
@@ -44,7 +45,7 @@ def generate_cov(n):
 def test_no_train():
     """Assert that metric never leaves evaluation mode."""
 
-    class MyModel(torch.nn.Module):
+    class MyModel(Module):
         def __init__(self):
             super().__init__()
             self.metric = FrechetInceptionDistance()
diff --git a/tests/image/test_inception.py b/tests/image/test_inception.py
index e4985b3eac2..147dedd0d4e 100644
--- a/tests/image/test_inception.py
+++ b/tests/image/test_inception.py
@@ -15,6 +15,7 @@
 
 import pytest
 import torch
+from torch.nn import Module
 from torch.utils.data import Dataset
 
 from torchmetrics.image.inception import InceptionScore
@@ -27,7 +28,7 @@
 def test_no_train():
     """Assert that metric never leaves evaluation mode."""
 
-    class MyModel(torch.nn.Module):
+    class MyModel(Module):
         def __init__(self):
             super().__init__()
             self.metric = InceptionScore()
diff --git a/tests/image/test_kid.py b/tests/image/test_kid.py
index dca29cd1c97..263c1b55c86 100644
--- a/tests/image/test_kid.py
+++ b/tests/image/test_kid.py
@@ -15,6 +15,7 @@
 
 import pytest
 import torch
+from torch.nn import Module
 from torch.utils.data import Dataset
 
 from torchmetrics.image.kid import KernelInceptionDistance
@@ -27,7 +28,7 @@
 def test_no_train():
     """Assert that metric never leaves evaluation mode."""
 
-    class MyModel(torch.nn.Module):
+    class MyModel(Module):
         def __init__(self):
             super().__init__()
             self.metric = KernelInceptionDistance()
diff --git a/tests/text/test_bertscore.py b/tests/text/test_bertscore.py
index 1d899c1a3d4..874a232f5cd 100644
--- a/tests/text/test_bertscore.py
+++ b/tests/text/test_bertscore.py
@@ -6,6 +6,7 @@
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+from torch import Tensor
 
 from tests.text.helpers import skip_on_connection_issues
 from torchmetrics.functional.text.bert import bert_score as metrics_bert_score
@@ -45,7 +46,7 @@ def _assert_list(preds: Any, targets: Any, threshold: float = 1e-8):
     assert np.allclose(preds, targets, atol=threshold, equal_nan=True)
 
 
-def _parse_original_bert_score(score: torch.Tensor) -> Dict[str, List[float]]:
+def _parse_original_bert_score(score: Tensor) -> Dict[str, List[float]]:
     """Parse the BERT score returned by the original `bert-score` package."""
     score_dict = {metric: value.tolist() for metric, value in zip(_METRICS, score)}
     return score_dict
diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py
index c1b113f8843..5c0c11d562c 100644
--- a/tests/wrappers/test_minmax.py
+++ b/tests/wrappers/test_minmax.py
@@ -3,6 +3,7 @@
 
 import pytest
 import torch
+from torch import Tensor
 
 from tests.helpers import seed_all
 from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester
@@ -97,10 +98,10 @@ def test_basic_example(preds, labels, raws, maxs, mins) -> None:
     """tests that both min and max versions of MinMaxMetric operate correctly after calling compute."""
     acc = Accuracy()
     min_max_acc = MinMaxMetric(acc)
-    labels = torch.Tensor(labels).long()
+    labels = Tensor(labels).long()
 
     for i in range(3):
-        preds_ = torch.Tensor(preds[i])
+        preds_ = Tensor(preds[i])
         min_max_acc(preds_, labels)
         acc = min_max_acc.compute()
         assert acc["raw"] == raws[i]
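The wrapper exercised by `test_basic_example` keeps running extrema alongside the current value; roughly (a sketch, not part of this diff):

```python
from torch import Tensor
from torchmetrics import Accuracy
from torchmetrics.wrappers import MinMaxMetric

min_max_acc = MinMaxMetric(Accuracy())
labels = Tensor([0, 1, 1]).long()

min_max_acc(Tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]), labels)  # accuracy 2/3
out = min_max_acc.compute()
# "raw" is the latest value; "min"/"max" track the extremes across calls.
print(out["raw"], out["min"], out["max"])
```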
diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py
index 75e30c29e4a..08b6618684f 100644
--- a/tests/wrappers/test_multioutput.py
+++ b/tests/wrappers/test_multioutput.py
@@ -5,6 +5,7 @@
 import torch
 from sklearn.metrics import accuracy_score
 from sklearn.metrics import r2_score as sk_r2score
+from torch import Tensor
 
 from tests.helpers import seed_all
 from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester
@@ -31,11 +32,11 @@ def __init__(
             num_outputs=num_outputs,
         )
 
-    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+    def update(self, preds: Tensor, target: Tensor) -> None:
         """Update the each pair of outputs and predictions."""
         return self.metric.update(preds, target)
 
-    def compute(self) -> torch.Tensor:
+    def compute(self) -> Tensor:
         """Compute the R2 score between each pair of outputs and predictions."""
         return self.metric.compute()
 
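What the dummy metric above wraps, in plain usage (a sketch, not part of this diff):

```python
import torch
from torchmetrics import R2Score
from torchmetrics.wrappers import MultioutputWrapper

# Track R2 independently for each of two regression outputs.
metric = MultioutputWrapper(R2Score(), num_outputs=2)
metric.update(torch.randn(16, 2), torch.randn(16, 2))
scores = metric.compute()  # a list with one result per output, as tested above
```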
""" return model(batch["input_ids"]) diff --git a/tm_examples/detection_map.py b/tm_examples/detection_map.py index 66a3f1fe036..957b09f5fc5 100644 --- a/tm_examples/detection_map.py +++ b/tm_examples/detection_map.py @@ -17,6 +17,7 @@ """ import torch +from torch import Tensor from torchmetrics.detection.mean_ap import MeanAveragePrecision @@ -27,10 +28,10 @@ # The boxes keyword should contain an [N,4] tensor, # where N is the number of detected boxes with boxes of the format # [xmin, ymin, xmax, ymax] in absolute image coordinates - boxes=torch.Tensor([[258.0, 41.0, 606.0, 285.0]]), + boxes=Tensor([[258.0, 41.0, 606.0, 285.0]]), # The scores keyword should contain an [N,] tensor where # each element is confidence score between 0 and 1 - scores=torch.Tensor([0.536]), + scores=Tensor([0.536]), # The labels keyword should contain an [N,] tensor # with integers of the predicted classes labels=torch.IntTensor([0]), @@ -43,7 +44,7 @@ # target need to match target = [ dict( - boxes=torch.Tensor([[214.0, 41.0, 562.0, 285.0]]), + boxes=Tensor([[214.0, 41.0, 562.0, 285.0]]), labels=torch.IntTensor([0]), ) ] diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py index 3c8b58df1c7..55492bdd3bd 100644 --- a/torchmetrics/collections.py +++ b/torchmetrics/collections.py @@ -15,7 +15,8 @@ from typing import Any, Dict, Hashable, Iterable, List, Optional, Sequence, Tuple, Union import torch -from torch import Tensor, nn +from torch import Tensor +from torch.nn import Module, ModuleDict from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn @@ -25,7 +26,7 @@ from torchmetrics.utilities.imports import OrderedDict -class MetricCollection(nn.ModuleDict): +class MetricCollection(ModuleDict): """MetricCollection class can be used to chain metrics that have the same call pattern into one single class. Args: @@ -347,7 +348,7 @@ def keys(self, keep_base: bool = False) -> Iterable[Hashable]: return self._modules.keys() return self._to_renamed_ordered_dict().keys() - def items(self, keep_base: bool = False) -> Iterable[Tuple[str, nn.Module]]: + def items(self, keep_base: bool = False) -> Iterable[Tuple[str, Module]]: r"""Return an iterable of the ModuleDict key/value pairs. Args: keep_base: Whether to add prefix/postfix on the items collection. diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 3d3e2b542bb..b2a2b7d87bd 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -26,7 +26,7 @@ def _find_best_perm_by_linear_sum_assignment( - metric_mtx: torch.Tensor, + metric_mtx: Tensor, eval_func: Union[torch.min, torch.max], ) -> Tuple[Tensor, Tensor]: """Solves the linear sum assignment problem using scipy, and returns the best metric values and the @@ -50,7 +50,7 @@ def _find_best_perm_by_linear_sum_assignment( def _find_best_perm_by_exhaustive_method( - metric_mtx: torch.Tensor, + metric_mtx: Tensor, eval_func: Union[torch.min, torch.max], ) -> Tuple[Tensor, Tensor]: """Solves the linear sum assignment problem using exhaustive method, i.e. exhaustively calculates the metric @@ -93,7 +93,7 @@ def _find_best_perm_by_exhaustive_method( def permutation_invariant_training( - preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_func: str = "max", **kwargs: Dict[str, Any] + preds: Tensor, target: torch.Tensor, metric_func: Callable, eval_func: str = "max", **kwargs: Dict[str, Any] ) -> Tuple[Tensor, Tensor]: """Permutation invariant training (PIT). 
diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py
index 3d3e2b542bb..b2a2b7d87bd 100644
--- a/torchmetrics/functional/audio/pit.py
+++ b/torchmetrics/functional/audio/pit.py
@@ -26,7 +26,7 @@
 
 
 def _find_best_perm_by_linear_sum_assignment(
-    metric_mtx: torch.Tensor,
+    metric_mtx: Tensor,
     eval_func: Union[torch.min, torch.max],
 ) -> Tuple[Tensor, Tensor]:
     """Solves the linear sum assignment problem using scipy, and returns the best metric values and the
@@ -50,7 +50,7 @@
 
 
 def _find_best_perm_by_exhaustive_method(
-    metric_mtx: torch.Tensor,
+    metric_mtx: Tensor,
     eval_func: Union[torch.min, torch.max],
 ) -> Tuple[Tensor, Tensor]:
     """Solves the linear sum assignment problem using exhaustive method, i.e. exhaustively calculates the metric
@@ -93,7 +93,7 @@
 
 
 def permutation_invariant_training(
-    preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_func: str = "max", **kwargs: Dict[str, Any]
+    preds: Tensor, target: torch.Tensor, metric_func: Callable, eval_func: str = "max", **kwargs: Dict[str, Any]
 ) -> Tuple[Tensor, Tensor]:
     """Permutation invariant training (PIT).
 
     The ``permutation_invariant_training`` implements the famous Permutation Invariant Training method.
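How `permutation_invariant_training` is consumed: `metric_func` is evaluated on `[batch, ...]` slices for every prediction/target speaker pairing. A sketch with a hypothetical per-sample metric (not part of this diff):

```python
import torch
from torch import Tensor
from torchmetrics.functional.audio.pit import permutation_invariant_training


def neg_mse(preds: Tensor, target: Tensor) -> Tensor:
    # Hypothetical pairwise metric: per-sample negative MSE (higher is better).
    return -torch.mean((preds - target) ** 2, dim=-1)


preds = torch.randn(2, 3, 100)  # [batch, spk, time]
target = torch.randn(2, 3, 100)
best_metric, best_perm = permutation_invariant_training(preds, target, neg_mse, eval_func="max")
# best_metric: [batch] values under the best permutation; best_perm: [batch, spk] speaker order.
```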
diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py
index 8a3b9490a9e..f8d6c0d85c8 100644
--- a/torchmetrics/functional/audio/sdr.py
+++ b/torchmetrics/functional/audio/sdr.py
@@ -69,11 +69,7 @@ def _symmetric_toeplitz(vector: Tensor) -> Tensor:
     ).flip(dims=(-1,))
 
 
-def _compute_autocorr_crosscorr(
-    target: torch.Tensor,
-    preds: torch.Tensor,
-    corr_len: int,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+def _compute_autocorr_crosscorr(target: Tensor, preds: Tensor, corr_len: int) -> Tuple[Tensor, Tensor]:
     r"""Compute the auto correlation of `target` and the cross correlation of `target` and `preds` using the fast
     Fourier transform (FFT). Let's denotes the symmetric Toeplitz matric of the auto correlation of `target` as
     `R`, the cross correlation as 'b', then solving the equation `Rh=b` could have `h` as the coordinate of
diff --git a/torchmetrics/functional/image/gradients.py b/torchmetrics/functional/image/gradients.py
index f123f1a37b5..f22ad3a1945 100644
--- a/torchmetrics/functional/image/gradients.py
+++ b/torchmetrics/functional/image/gradients.py
@@ -56,7 +56,7 @@ def image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]:
 
     Raises:
         TypeError:
-            If ``img`` is not of the type .
+            If ``img`` is not of the type ``torch.Tensor``.
         RuntimeError:
             If ``img`` is not a 4D tensor.
 
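The corrected `Raises` entry matches the actual runtime checks; usage for reference (a sketch, not part of this diff):

```python
import torch
from torchmetrics.functional import image_gradients

img = torch.arange(25, dtype=torch.float32).reshape(1, 1, 5, 5)  # a 4D [N, C, H, W] tensor is required
dy, dx = image_gradients(img)
# dy/dx hold per-pixel differences along height/width, with the same shape as img.
assert dy.shape == img.shape and dx.shape == img.shape
```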
diff --git a/torchmetrics/functional/image/ssim.py b/torchmetrics/functional/image/ssim.py
index d4eca37d910..2c775994018 100644
--- a/torchmetrics/functional/image/ssim.py
+++ b/torchmetrics/functional/image/ssim.py
@@ -150,7 +150,7 @@ def _ssim_compute(
     kernel = _gaussian_kernel_2d(channel, gauss_kernel_size, sigma, dtype, device)
 
     if not gaussian_kernel:
-        kernel = torch.ones((1, 1, *kernel_size)) / torch.prod(torch.Tensor(kernel_size))
+        kernel = torch.ones((1, 1, *kernel_size)) / torch.prod(Tensor(kernel_size))
 
     input_list = torch.cat((preds, target, preds * preds, target * target, preds * target))  # (5 * B, C, H, W)
 
diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py
index 7d1c5ff53fc..30e778cf302 100644
--- a/torchmetrics/functional/text/bert.py
+++ b/torchmetrics/functional/text/bert.py
@@ -20,6 +20,7 @@
 
 import torch
 from torch import Tensor
+from torch.nn import Module
 from torch.utils.data import DataLoader, Dataset
 
 from torchmetrics.utilities.imports import _TQDM_AVAILABLE, _TRANSFORMERS_AUTO_AVAILABLE
@@ -197,7 +198,7 @@ def _get_tokens_idf_default_value(self) -> float:
 
     @staticmethod
     def _set_of_tokens(input_ids: Tensor) -> Set:
-        """Return set of tokens from the ``input_ids`` ``torch.Tensor``."""
+        """Return set of tokens from the ``input_ids``."""
         return set(input_ids.tolist())
 
 
@@ -248,13 +249,13 @@ def _check_shape_of_model_output(output: Tensor, input_ids: Tensor) -> None:
 def _get_embeddings_and_idf_scale(
     dataloader: DataLoader,
     target_len: int,
-    model: torch.nn.Module,
+    model: Module,
     device: Optional[Union[str, torch.device]] = None,
     num_layers: Optional[int] = None,
     all_layers: bool = False,
     idf: bool = False,
     verbose: bool = False,
-    user_forward_fn: Callable[[torch.nn.Module, Dict[str, Tensor]], Tensor] = None,
+    user_forward_fn: Callable[[Module, Dict[str, Tensor]], Tensor] = None,
 ) -> Tuple[Tensor, Tensor]:
     """Calculate sentence embeddings and the inverse-document-frequency scaling factor.
     Args:
@@ -273,7 +274,7 @@ def _get_embeddings_and_idf_scale(
             ``torch.Tensor``.
 
     Return:
-        A tuple of torch.Tensors containing the model's embeddings and the normalized tokens IDF.
+        A tuple of ``torch.Tensor``s containing the model's embeddings and the normalized tokens IDF.
         When ``idf = False``, tokens IDF is not calculated, and a matrix of mean weights is returned instead.
         For a single sentence, ``mean_weight = 1/seq_len``, where ``seq_len`` is a sum over the corresponding
         ``attention_mask``.
@@ -440,9 +441,9 @@ def bert_score(
     model_name_or_path: Optional[str] = None,
     num_layers: Optional[int] = None,
     all_layers: bool = False,
-    model: Optional[torch.nn.Module] = None,
+    model: Optional[Module] = None,
     user_tokenizer: Any = None,
-    user_forward_fn: Callable[[torch.nn.Module, Dict[str, Tensor]], Tensor] = None,
+    user_forward_fn: Callable[[Module, Dict[str, Tensor]], Tensor] = None,
     verbose: bool = False,
     idf: bool = False,
     device: Optional[Union[str, torch.device]] = None,
diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py
index 94255e77985..350cf73dd77 100644
--- a/torchmetrics/image/fid.py
+++ b/torchmetrics/image/fid.py
@@ -17,6 +17,7 @@
 import torch
 from torch import Tensor
 from torch.autograd import Function
+from torch.nn import Module
 
 from torchmetrics.metric import Metric
 from torchmetrics.utilities import rank_zero_info, rank_zero_warn
@@ -27,7 +28,7 @@
     from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3
 else:
 
-    class FeatureExtractorInceptionV3(torch.nn.Module):  # type: ignore
+    class FeatureExtractorInceptionV3(Module):  # type: ignore
         pass
 
     __doctest_skip__ = ["FrechetInceptionDistance", "FID"]
@@ -202,7 +203,7 @@ class FrechetInceptionDistance(Metric):
 
     def __init__(
         self,
-        feature: Union[int, torch.nn.Module] = 2048,
+        feature: Union[int, Module] = 2048,
         reset_real_features: bool = True,
         **kwargs: Dict[str, Any],
     ) -> None:
@@ -227,7 +228,7 @@ def __init__(
             )
             self.inception = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)])
 
-        elif isinstance(feature, torch.nn.Module):
+        elif isinstance(feature, Module):
             self.inception = feature
         else:
             raise TypeError("Got unknown input to argument `feature`")
diff --git a/torchmetrics/image/inception.py b/torchmetrics/image/inception.py
index 4df77dc0814..9040fc70df4 100644
--- a/torchmetrics/image/inception.py
+++ b/torchmetrics/image/inception.py
@@ -15,6 +15,7 @@
 
 import torch
 from torch import Tensor
+from torch.nn import Module
 
 from torchmetrics.image.fid import NoTrainInceptionV3
 from torchmetrics.metric import Metric
@@ -92,7 +93,7 @@ class InceptionScore(Metric):
 
     def __init__(
         self,
-        feature: Union[str, int, torch.nn.Module] = "logits_unbiased",
+        feature: Union[str, int, Module] = "logits_unbiased",
         splits: int = 10,
         **kwargs: Dict[str, Any],
     ) -> None:
@@ -117,7 +118,7 @@ def __init__(
             )
             self.inception = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)])
 
-        elif isinstance(feature, torch.nn.Module):
+        elif isinstance(feature, Module):
             self.inception = feature
         else:
             raise TypeError("Got unknown input to argument `feature`")
diff --git a/torchmetrics/image/kid.py b/torchmetrics/image/kid.py
index ece20ea7a6f..d8f02df3a5d 100644
--- a/torchmetrics/image/kid.py
+++ b/torchmetrics/image/kid.py
@@ -156,7 +156,7 @@ class KernelInceptionDistance(Metric):
 
     def __init__(
         self,
-        feature: Union[str, int, torch.nn.Module] = 2048,
+        feature: Union[str, int, Module] = 2048,
         subsets: int = 100,
         subset_size: int = 1000,
         degree: int = 3,
diff --git a/torchmetrics/image/lpip.py b/torchmetrics/image/lpip.py
index 54c920b89d5..d0f5e59ae27 100644
--- a/torchmetrics/image/lpip.py
+++ b/torchmetrics/image/lpip.py
@@ -15,6 +15,7 @@
 
 import torch
 from torch import Tensor
+from torch.nn import Module
 from typing_extensions import Literal
 
 from torchmetrics.metric import Metric
@@ -24,7 +25,7 @@
     from lpips import LPIPS as _LPIPS
 else:
 
-    class _LPIPS(torch.nn.Module):  # type: ignore
+    class _LPIPS(Module):  # type: ignore
         pass
 
     __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity", "LPIPS"]
diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py
index c477a7d44e7..99fcad1306f 100644
--- a/torchmetrics/text/bert.py
+++ b/torchmetrics/text/bert.py
@@ -16,6 +16,7 @@
 
 import torch
 from torch import Tensor
+from torch.nn import Module
 
 from torchmetrics.functional.text.bert import _preprocess_text, bert_score
 from torchmetrics.metric import Metric
@@ -65,7 +66,7 @@ class BERTScore(Metric):
             A user's own forward function used in a combination with `user_model`. This function must take `user_model`
             and a python dictionary of containing `"input_ids"` and `"attention_mask"` represented by `torch.Tensor`
             as an input and return the model's output represented by the single `torch.Tensor`.
-        verbose: An indication of whether a progress bar to be displayed during the embeddings calculation.
+        verbose: An indication of whether a progress bar to be displayed during the embeddings' calculation.
         idf: An indication whether normalization using inverse document frequencies should be used.
         device: A device to be used for calculation.
         max_length: A maximum length of input sequences. Sequences longer than `max_length` are to be trimmed.
@@ -116,9 +117,9 @@ def __init__(
         model_name_or_path: Optional[str] = None,
         num_layers: Optional[int] = None,
         all_layers: bool = False,
-        model: Optional[torch.nn.Module] = None,
+        model: Optional[Module] = None,
         user_tokenizer: Optional[Any] = None,
-        user_forward_fn: Callable[[torch.nn.Module, Dict[str, torch.Tensor]], torch.Tensor] = None,
+        user_forward_fn: Callable[[Module, Dict[str, Tensor]], Tensor] = None,
         verbose: bool = False,
         idf: bool = False,
         device: Optional[Union[str, torch.device]] = None,
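Typical use of the retyped `BERTScore` module (a sketch, not part of this diff; it downloads a pretrained encoder on first run, and `roberta-large` here is only an illustrative choice):

```python
from torchmetrics.text.bert import BERTScore

preds = ["hello there", "general kenobi"]
target = ["hello there", "master kenobi"]

bertscore = BERTScore(model_name_or_path="roberta-large")
score = bertscore(preds, target)  # dict with "precision", "recall" and "f1" values
```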
diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py
index 7ed5242c2b1..5a991138105 100644
--- a/torchmetrics/utilities/data.py
+++ b/torchmetrics/utilities/data.py
@@ -194,11 +194,10 @@ def apply_to_collection(
 
 
 def get_group_indexes(indexes: Tensor) -> List[Tensor]:
-    """Given an integer ``torch.Tensor`` ``indexes``, return a ``torch.Tensor`` of indexes for each different value
-    in ``indexes``.
+    """Given an integer ``indexes``, return indexes for each different value in ``indexes``.
 
     Args:
-        indexes: a ``torch.Tensor``
+        indexes:
 
     Return:
         A list of integer ``torch.Tensor``s
diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py
index c1fcafa88dd..fd43cfb7c6e 100644
--- a/torchmetrics/wrappers/bootstrapping.py
+++ b/torchmetrics/wrappers/bootstrapping.py
@@ -15,7 +15,8 @@
 from typing import Any, Dict, Optional, Union
 
 import torch
-from torch import Tensor, nn
+from torch import Tensor
+from torch.nn import ModuleList
 
 from torchmetrics.metric import Metric
 from torchmetrics.utilities import apply_to_collection
@@ -98,7 +99,7 @@ def __init__(
                 "Expected base metric to be an instance of torchmetrics.Metric"
                 f" but received {base_metric}"
             )
-        self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_bootstraps)])
+        self.metrics = ModuleList([deepcopy(base_metric) for _ in range(num_bootstraps)])
         self.num_bootstraps = num_bootstraps
 
         self.mean = mean
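Usage of the wrapper whose internal list just changed type (a sketch, not part of this diff):

```python
import torch
from torchmetrics import Accuracy
from torchmetrics.wrappers import BootStrapper

# Keeps num_bootstraps copies of the base metric, each updated on a resampled batch.
bootstrap = BootStrapper(Accuracy(), num_bootstraps=10, mean=True, std=True)
bootstrap.update(torch.randn(32, 3).softmax(dim=-1), torch.randint(3, (32,)))
out = bootstrap.compute()  # {"mean": tensor(...), "std": tensor(...)}
```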
diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py
index fafca22ea50..c51cc0cb02f 100644
--- a/torchmetrics/wrappers/multioutput.py
+++ b/torchmetrics/wrappers/multioutput.py
@@ -2,13 +2,14 @@
 from typing import Any, List, Tuple
 
 import torch
-from torch import nn
+from torch import Tensor
+from torch.nn import ModuleList
 
 from torchmetrics import Metric
 from torchmetrics.utilities import apply_to_collection
 
 
-def _get_nan_indices(*tensors: torch.Tensor) -> torch.Tensor:
+def _get_nan_indices(*tensors: Tensor) -> Tensor:
     """Get indices of rows along dim 0 which have NaN values."""
     if len(tensors) == 0:
         raise ValueError("Must pass at least one tensor as argument")
@@ -88,22 +89,20 @@ def __init__(
         squeeze_outputs: bool = True,
     ):
         super().__init__()
-        self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_outputs)])
+        self.metrics = ModuleList([deepcopy(base_metric) for _ in range(num_outputs)])
         self.output_dim = output_dim
         self.remove_nans = remove_nans
         self.squeeze_outputs = squeeze_outputs
 
-    def _get_args_kwargs_by_output(
-        self, *args: torch.Tensor, **kwargs: torch.Tensor
-    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+    def _get_args_kwargs_by_output(self, *args: Tensor, **kwargs: Tensor) -> List[Tuple[Tensor, Tensor]]:
         """Get args and kwargs reshaped to be output-specific and (maybe) having NaNs stripped out."""
         args_kwargs_by_output = []
         for i in range(len(self.metrics)):
             selected_args = apply_to_collection(
-                args, torch.Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device)
+                args, Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device)
             )
             selected_kwargs = apply_to_collection(
-                kwargs, torch.Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device)
+                kwargs, Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device)
             )
             if self.remove_nans:
                 args_kwargs = selected_args + tuple(selected_kwargs.values())
@@ -122,7 +121,7 @@ def update(self, *args: Any, **kwargs: Any) -> None:
         for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs):
             metric.update(*selected_args, **selected_kwargs)
 
-    def compute(self) -> List[torch.Tensor]:
+    def compute(self) -> List[Tensor]:
         """Compute metrics."""
         return [m.compute() for m in self.metrics]
 
diff --git a/torchmetrics/wrappers/tracker.py b/torchmetrics/wrappers/tracker.py
index eb9f393f190..0b222dd91b1 100644
--- a/torchmetrics/wrappers/tracker.py
+++ b/torchmetrics/wrappers/tracker.py
@@ -16,13 +16,14 @@
 from typing import Any, Dict, List, Tuple, Union
 
 import torch
-from torch import Tensor, nn
+from torch import Tensor
+from torch.nn import ModuleList
 
 from torchmetrics.collections import MetricCollection
 from torchmetrics.metric import Metric
 
 
-class MetricTracker(nn.ModuleList):
+class MetricTracker(ModuleList):
     """A wrapper class that can help keeping track of a metric or metric collection over time and implement useful
     methods. The wrapper implements the standard ``.update()``, ``.compute()``, ``.reset()`` methods that just calls
     corresponding method of the currently tracked metric. However, the following additional methods are