
Commit: cleaning long torch imports (#996)

* Tensor
* Module

Borda authored Apr 29, 2022 · 1 parent 1bc6c47 · commit 3a141ae

Showing 26 changed files with 115 additions and 107 deletions.
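
The pattern applied across these files is a plain import cleanup: names such as `Tensor` and `Module` are imported directly from `torch` / `torch.nn` instead of being referenced through the longer `torch.Tensor` and `torch.nn.Module` paths. A minimal before/after sketch of that style (the `ToyModule` class below is a hypothetical illustration, not code from this repository):

# Before: fully qualified references
import torch
import torch.nn as nn

class ToyModule(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

# After: the style this commit moves the test suite and examples to
from torch import Tensor
from torch.nn import Module

class ToyModule(Module):
    def forward(self, x: Tensor) -> Tensor:
        return x * 2
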
7 changes: 4 additions & 3 deletions tests/bases/test_metric.py
@@ -20,7 +20,8 @@
 import psutil
 import pytest
 import torch
-from torch import Tensor, nn, tensor
+from torch import Tensor, tensor
+from torch.nn import Module

 from tests.helpers import seed_all
 from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
@@ -245,7 +246,7 @@ def test_load_state_dict(tmpdir):
 def test_child_metric_state_dict():
     """test that child metric states will be added to parent state dict."""

-    class TestModule(nn.Module):
+    class TestModule(Module):
         def __init__(self):
             super().__init__()
             self.metric = DummyMetric()
@@ -346,7 +347,7 @@ def test_forward_and_compute_to_device(metric_class):
 def test_device_if_child_module(metric_class):
     """Test that if a metric is a child module all values gets moved to the correct device."""

-    class TestModule(nn.Module):
+    class TestModule(Module):
         def __init__(self):
             super().__init__()
             self.metric = metric_class()
67 changes: 34 additions & 33 deletions tests/detection/test_map.py
@@ -16,6 +16,7 @@

 import pytest
 import torch
+from torch import Tensor

 from tests.helpers.testers import MetricTester
 from torchmetrics.detection.mean_ap import MeanAveragePrecision
@@ -27,19 +28,19 @@
     preds=[
         [
             dict(
-                boxes=torch.Tensor([[258.15, 41.29, 606.41, 285.07]]),
-                scores=torch.Tensor([0.236]),
+                boxes=Tensor([[258.15, 41.29, 606.41, 285.07]]),
+                scores=Tensor([0.236]),
                 labels=torch.IntTensor([4]),
             ), # coco image id 42
             dict(
-                boxes=torch.Tensor([[61.00, 22.75, 565.00, 632.42], [12.66, 3.32, 281.26, 275.23]]),
-                scores=torch.Tensor([0.318, 0.726]),
+                boxes=Tensor([[61.00, 22.75, 565.00, 632.42], [12.66, 3.32, 281.26, 275.23]]),
+                scores=Tensor([0.318, 0.726]),
                 labels=torch.IntTensor([3, 2]),
             ), # coco image id 73
         ],
         [
             dict(
-                boxes=torch.Tensor(
+                boxes=Tensor(
                     [
                         [87.87, 276.25, 384.29, 379.43],
                         [0.00, 3.66, 142.15, 316.06],
@@ -50,24 +51,24 @@
                         [276.11, 103.84, 291.44, 150.72],
                     ]
                 ),
-                scores=torch.Tensor([0.546, 0.3, 0.407, 0.611, 0.335, 0.805, 0.953]),
+                scores=Tensor([0.546, 0.3, 0.407, 0.611, 0.335, 0.805, 0.953]),
                 labels=torch.IntTensor([4, 1, 0, 0, 0, 0, 0]),
             ), # coco image id 74
             dict(
-                boxes=torch.Tensor([[0.00, 2.87, 601.00, 421.52]]),
-                scores=torch.Tensor([0.699]),
+                boxes=Tensor([[0.00, 2.87, 601.00, 421.52]]),
+                scores=Tensor([0.699]),
                 labels=torch.IntTensor([5]),
             ), # coco image id 133
         ],
     ],
     target=[
         [
             dict(
-                boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]),
+                boxes=Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]),
                 labels=torch.IntTensor([4]),
             ), # coco image id 42
             dict(
-                boxes=torch.Tensor(
+                boxes=Tensor(
                     [
                         [13.00, 22.75, 548.98, 632.42],
                         [1.66, 3.32, 270.26, 275.23],
@@ -78,7 +79,7 @@
         ],
         [
             dict(
-                boxes=torch.Tensor(
+                boxes=Tensor(
                     [
                         [61.87, 276.25, 358.29, 379.43],
                         [2.75, 3.66, 162.15, 316.06],
@@ -92,7 +93,7 @@
                 labels=torch.IntTensor([4, 1, 0, 0, 0, 0, 0]),
             ), # coco image id 74
             dict(
-                boxes=torch.Tensor([[13.99, 2.87, 640.00, 421.52]]),
+                boxes=Tensor([[13.99, 2.87, 640.00, 421.52]]),
                 labels=torch.IntTensor([5]),
            ), # coco image id 133
         ],
@@ -104,29 +105,29 @@
     preds=[
         [
             dict(
-                boxes=torch.Tensor([[258.0, 41.0, 606.0, 285.0]]),
-                scores=torch.Tensor([0.536]),
+                boxes=Tensor([[258.0, 41.0, 606.0, 285.0]]),
+                scores=Tensor([0.536]),
                 labels=torch.IntTensor([0]),
            ),
         ],
         [
             dict(
-                boxes=torch.Tensor([[258.0, 41.0, 606.0, 285.0]]),
-                scores=torch.Tensor([0.536]),
+                boxes=Tensor([[258.0, 41.0, 606.0, 285.0]]),
+                scores=Tensor([0.536]),
                 labels=torch.IntTensor([0]),
             )
         ],
     ],
     target=[
         [
             dict(
-                boxes=torch.Tensor([[214.0, 41.0, 562.0, 285.0]]),
+                boxes=Tensor([[214.0, 41.0, 562.0, 285.0]]),
                 labels=torch.IntTensor([0]),
             )
         ],
        [
             dict(
-                boxes=torch.Tensor([]),
+                boxes=Tensor([]),
                 labels=torch.IntTensor([]),
             )
         ],
@@ -196,20 +197,20 @@ def _compare_fn(preds, target) -> dict:
     Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.900
     """
     return {
-        "map": torch.Tensor([0.706]),
-        "map_50": torch.Tensor([0.901]),
-        "map_75": torch.Tensor([0.846]),
-        "map_small": torch.Tensor([0.689]),
-        "map_medium": torch.Tensor([0.800]),
-        "map_large": torch.Tensor([0.701]),
-        "mar_1": torch.Tensor([0.592]),
-        "mar_10": torch.Tensor([0.716]),
-        "mar_100": torch.Tensor([0.716]),
-        "mar_small": torch.Tensor([0.767]),
-        "mar_medium": torch.Tensor([0.800]),
-        "mar_large": torch.Tensor([0.700]),
-        "map_per_class": torch.Tensor([0.725, 0.800, 0.454, -1.000, 0.650, 0.900]),
-        "mar_100_per_class": torch.Tensor([0.780, 0.800, 0.450, -1.000, 0.650, 0.900]),
+        "map": Tensor([0.706]),
+        "map_50": Tensor([0.901]),
+        "map_75": Tensor([0.846]),
+        "map_small": Tensor([0.689]),
+        "map_medium": Tensor([0.800]),
+        "map_large": Tensor([0.701]),
+        "mar_1": Tensor([0.592]),
+        "mar_10": Tensor([0.716]),
+        "mar_100": Tensor([0.716]),
+        "mar_small": Tensor([0.767]),
+        "mar_medium": Tensor([0.800]),
+        "mar_large": Tensor([0.700]),
+        "map_per_class": Tensor([0.725, 0.800, 0.454, -1.000, 0.650, 0.900]),
+        "mar_100_per_class": Tensor([0.780, 0.800, 0.450, -1.000, 0.650, 0.900]),
     }


@@ -260,7 +261,7 @@ def test_empty_preds():

     metric.update(
         [
-            dict(boxes=torch.Tensor([]), scores=torch.Tensor([]), labels=torch.IntTensor([])),
+            dict(boxes=Tensor([]), scores=torch.Tensor([]), labels=torch.IntTensor([])),
         ],
         [
             dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=torch.IntTensor([4])),
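
As a sanity check on the fixture rewrites above: `Tensor` imported directly from `torch` is the same class as `torch.Tensor`, so constructing the boxes and scores either way yields identical tensors. A small self-contained snippet (not part of the commit) illustrating this:

import torch
from torch import Tensor

# The direct import is an alias for the same class, so the diff above is purely cosmetic.
assert Tensor is torch.Tensor

boxes_old = torch.Tensor([[258.15, 41.29, 606.41, 285.07]])
boxes_new = Tensor([[258.15, 41.29, 606.41, 285.07]])
assert torch.equal(boxes_old, boxes_new)
assert boxes_new.dtype == torch.float32  # Tensor(...) constructs float32 by default
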
8 changes: 4 additions & 4 deletions tests/helpers/testers.py
@@ -178,7 +178,7 @@ def _class_test(
         batch_result = metric(preds[i], target[i], **batch_kwargs_update)

         if metric.dist_sync_on_step and check_dist_sync_on_step and rank == 0:
-            if isinstance(preds, torch.Tensor):
+            if isinstance(preds, Tensor):
                 ddp_preds = torch.cat([preds[i + r] for r in range(worldsize)]).cpu()
                 ddp_target = torch.cat([target[i + r] for r in range(worldsize)]).cpu()
             else:
@@ -201,8 +201,8 @@ def _class_test(
                 k: v.cpu() if isinstance(v, Tensor) else v
                 for k, v in (batch_kwargs_update if fragment_kwargs else kwargs_update).items()
             }
-            preds_ = preds[i].cpu() if isinstance(preds, torch.Tensor) else preds[i]
-            target_ = target[i].cpu() if isinstance(target, torch.Tensor) else target[i]
+            preds_ = preds[i].cpu() if isinstance(preds, Tensor) else preds[i]
+            target_ = target[i].cpu() if isinstance(target, Tensor) else target[i]
             sk_batch_result = sk_metric(preds_, target_, **batch_kwargs_update)
             if isinstance(batch_result, dict):
                 for key in batch_result.keys():
@@ -221,7 +221,7 @@ def _class_test(
     else:
         _assert_tensor(result)

-    if isinstance(preds, torch.Tensor):
+    if isinstance(preds, Tensor):
         total_preds = torch.cat([preds[i] for i in range(num_batches)]).cpu()
         total_target = torch.cat([target[i] for i in range(num_batches)]).cpu()
     else:
3 changes: 2 additions & 1 deletion tests/image/test_fid.py
@@ -16,6 +16,7 @@
 import pytest
 import torch
 from scipy.linalg import sqrtm as scipy_sqrtm
+from torch.nn import Module
 from torch.utils.data import Dataset

 from torchmetrics.image.fid import FrechetInceptionDistance, sqrtm
@@ -44,7 +45,7 @@ def generate_cov(n):
 def test_no_train():
     """Assert that metric never leaves evaluation mode."""

-    class MyModel(torch.nn.Module):
+    class MyModel(Module):
         def __init__(self):
             super().__init__()
             self.metric = FrechetInceptionDistance()
3 changes: 2 additions & 1 deletion tests/image/test_inception.py
@@ -15,6 +15,7 @@

 import pytest
 import torch
+from torch.nn import Module
 from torch.utils.data import Dataset

 from torchmetrics.image.inception import InceptionScore
@@ -27,7 +28,7 @@
 def test_no_train():
     """Assert that metric never leaves evaluation mode."""

-    class MyModel(torch.nn.Module):
+    class MyModel(Module):
         def __init__(self):
             super().__init__()
             self.metric = InceptionScore()
3 changes: 2 additions & 1 deletion tests/image/test_kid.py
@@ -15,6 +15,7 @@

 import pytest
 import torch
+from torch.nn import Module
 from torch.utils.data import Dataset

 from torchmetrics.image.kid import KernelInceptionDistance
@@ -27,7 +28,7 @@
 def test_no_train():
     """Assert that metric never leaves evaluation mode."""

-    class MyModel(torch.nn.Module):
+    class MyModel(Module):
         def __init__(self):
             super().__init__()
             self.metric = KernelInceptionDistance()
3 changes: 2 additions & 1 deletion tests/text/test_bertscore.py
@@ -6,6 +6,7 @@
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+from torch import Tensor

 from tests.text.helpers import skip_on_connection_issues
 from torchmetrics.functional.text.bert import bert_score as metrics_bert_score
@@ -45,7 +46,7 @@ def _assert_list(preds: Any, targets: Any, threshold: float = 1e-8):
     assert np.allclose(preds, targets, atol=threshold, equal_nan=True)


-def _parse_original_bert_score(score: torch.Tensor) -> Dict[str, List[float]]:
+def _parse_original_bert_score(score: Tensor) -> Dict[str, List[float]]:
     """Parse the BERT score returned by the original `bert-score` package."""
     score_dict = {metric: value.tolist() for metric, value in zip(_METRICS, score)}
     return score_dict
5 changes: 3 additions & 2 deletions tests/wrappers/test_minmax.py
@@ -3,6 +3,7 @@

 import pytest
 import torch
+from torch import Tensor

 from tests.helpers import seed_all
 from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester
@@ -97,10 +98,10 @@ def test_basic_example(preds, labels, raws, maxs, mins) -> None:
     """tests that both min and max versions of MinMaxMetric operate correctly after calling compute."""
     acc = Accuracy()
     min_max_acc = MinMaxMetric(acc)
-    labels = torch.Tensor(labels).long()
+    labels = Tensor(labels).long()

     for i in range(3):
-        preds_ = torch.Tensor(preds[i])
+        preds_ = Tensor(preds[i])
         min_max_acc(preds_, labels)
         acc = min_max_acc.compute()
         assert acc["raw"] == raws[i]
5 changes: 3 additions & 2 deletions tests/wrappers/test_multioutput.py
@@ -5,6 +5,7 @@
 import torch
 from sklearn.metrics import accuracy_score
 from sklearn.metrics import r2_score as sk_r2score
+from torch import Tensor

 from tests.helpers import seed_all
 from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester
@@ -31,11 +32,11 @@ def __init__(
             num_outputs=num_outputs,
         )

-    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+    def update(self, preds: Tensor, target: Tensor) -> None:
         """Update the each pair of outputs and predictions."""
         return self.metric.update(preds, target)

-    def compute(self) -> torch.Tensor:
+    def compute(self) -> Tensor:
         """Compute the R2 score between each pair of outputs and predictions."""
         return self.metric.compute()

19 changes: 8 additions & 11 deletions tm_examples/bert_score-own_model.py
@@ -21,6 +21,8 @@

 import torch
 import torch.nn as nn
+from torch import Tensor
+from torch.nn import Module

 from torchmetrics.text.bert import BERTScore

@@ -52,7 +54,7 @@ def __init__(self) -> None:
             self.PAD_TOKEN: torch.zeros(1, _MODEL_DIM),
         }

-    def __call__(self, sentences: Union[str, List[str]], max_len: int = _MAX_LEN) -> Dict[str, torch.Tensor]:
+    def __call__(self, sentences: Union[str, List[str]], max_len: int = _MAX_LEN) -> Dict[str, Tensor]:
         """The `__call__` method must be defined for this class. To ensure the functionality, the `__call__` method
         should obey the input/output arguments structure described below.

@@ -63,10 +65,9 @@ def __call__(self, sentences: Union[str, List[str]], max_len: int = _MAX_LEN) ->
                 Maximum length of pre-processed text. `int`

         Return:
-            Python dictionary containing the keys `input_ids` and `attention_mask` with corresponding `torch.Tensor`
-            values.
+            Python dictionary containing the keys `input_ids` and `attention_mask` with corresponding values.
         """
-        output_dict: Dict[str, torch.Tensor] = {}
+        output_dict: Dict[str, Tensor] = {}
         if isinstance(sentences, str):
             sentences = [sentences]
         # Add special tokens
@@ -89,29 +90,25 @@ def __call__(self, sentences: Union[str, List[str]], max_len: int = _MAX_LEN) ->
         return output_dict


-def get_user_model_encoder(
-    num_layers: int = _NUM_LAYERS, d_model: int = _MODEL_DIM, nhead: int = _NHEAD
-) -> torch.nn.Module:
+def get_user_model_encoder(num_layers: int = _NUM_LAYERS, d_model: int = _MODEL_DIM, nhead: int = _NHEAD) -> Module:
     """Initialize the Transformer encoder."""
     encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
     transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
     return transformer_encoder


-def user_forward_fn(model: torch.nn.Module, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
+def user_forward_fn(model: Module, batch: Dict[str, Tensor]) -> Tensor:
     """User forward function used for the computation of model embeddings.

     This function might be arbitrarily complicated inside. However, to ensure functionality, it should obey the
     input/output argument structure described below.

     Args:
         model:
-            `torch.nn.Module`
         batch:
-            `Dict[str, torch.Tensor]`

     Return:
-        The model output. `torch.Tensor`
+        The model output.
     """
     return model(batch["input_ids"])
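
For completeness, a rough sketch of how the pieces defined in this example file are expected to be wired into `BERTScore`. The keyword arguments `model`, `user_tokenizer` and `user_forward_fn` are assumptions based on this example's intent, not confirmed by the diff; check the `BERTScore` signature of the torchmetrics version you actually use:

# Hypothetical usage of the example's UserTokenizer, get_user_model_encoder and
# user_forward_fn (all defined in tm_examples/bert_score-own_model.py, shown above).
from torchmetrics.text.bert import BERTScore

preds = ["hello world", "this is a test"]
target = ["hello world", "this is the test"]

bert_score = BERTScore(
    model=get_user_model_encoder(),    # the small Transformer encoder from this file
    user_tokenizer=UserTokenizer(),    # the toy word-to-vector tokenizer from this file
    user_forward_fn=user_forward_fn,   # the forward function shown in the diff above
)
print(bert_score(preds, target))
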
(Diffs for the remaining 16 changed files are not rendered here.)
