Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 15, 2024
1 parent 649dfdb commit 71ca0b2
Show file tree
Hide file tree
Showing 30 changed files with 455 additions and 345 deletions.
16 changes: 9 additions & 7 deletions examples/bert_score-own_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,15 @@ def __call__(self, sentences: Union[str, List[str]], max_len: int = _MAX_LEN) ->
sentence.lower().split()[:max_len] + [self.PAD_TOKEN] * (max_len - len(sentence.lower().split()))
for sentence in sentences
]
output_dict["input_ids"] = torch.cat([
torch.cat([self.word2vec[word] for word in sentence]).unsqueeze(0) for sentence in tokenized_sentences
])
output_dict["attention_mask"] = torch.cat([
torch.tensor([1 if word != self.PAD_TOKEN else 0 for word in sentence]).unsqueeze(0)
for sentence in tokenized_sentences
]).long()
output_dict["input_ids"] = torch.cat(
[torch.cat([self.word2vec[word] for word in sentence]).unsqueeze(0) for sentence in tokenized_sentences]
)
output_dict["attention_mask"] = torch.cat(
[
torch.tensor([1 if word != self.PAD_TOKEN else 0 for word in sentence]).unsqueeze(0)
for sentence in tokenized_sentences
]
).long()

return output_dict

Expand Down
12 changes: 7 additions & 5 deletions src/torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,11 +504,13 @@ def __new__( # type: ignore[misc]
"""Initialize task metric."""
task = ClassificationTask.from_str(task)

kwargs.update({
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
})
kwargs.update(
{
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
}
)

if task == ClassificationTask.BINARY:
return BinaryAccuracy(threshold, **kwargs)
Expand Down
12 changes: 7 additions & 5 deletions src/torchmetrics/classification/exact_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,13 @@ def __new__(
) -> Metric:
"""Initialize task metric."""
task = ClassificationTaskNoBinary.from_str(task)
kwargs.update({
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
})
kwargs.update(
{
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
}
)
if task == ClassificationTaskNoBinary.MULTICLASS:
if not isinstance(num_classes, int):
raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
Expand Down
24 changes: 14 additions & 10 deletions src/torchmetrics/classification/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,11 +1075,13 @@ def __new__(
"""Initialize task metric."""
task = ClassificationTask.from_str(task)
assert multidim_average is not None # noqa: S101 # needed for mypy
kwargs.update({
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
})
kwargs.update(
{
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
}
)
if task == ClassificationTask.BINARY:
return BinaryFBetaScore(beta, threshold, **kwargs)
if task == ClassificationTask.MULTICLASS:
Expand Down Expand Up @@ -1138,11 +1140,13 @@ def __new__(
"""Initialize task metric."""
task = ClassificationTask.from_str(task)
assert multidim_average is not None # noqa: S101 # needed for mypy
kwargs.update({
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
})
kwargs.update(
{
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
}
)
if task == ClassificationTask.BINARY:
return BinaryF1Score(threshold, **kwargs)
if task == ClassificationTask.MULTICLASS:
Expand Down
12 changes: 7 additions & 5 deletions src/torchmetrics/classification/hamming.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,11 +508,13 @@ def __new__( # type: ignore[misc]
"""Initialize task metric."""
task = ClassificationTask.from_str(task)
assert multidim_average is not None # noqa: S101 # needed for mypy
kwargs.update({
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
})
kwargs.update(
{
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
}
)
if task == ClassificationTask.BINARY:
return BinaryHammingDistance(threshold, **kwargs)
if task == ClassificationTask.MULTICLASS:
Expand Down
24 changes: 14 additions & 10 deletions src/torchmetrics/classification/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,11 +945,13 @@ def __new__(
) -> Metric:
"""Initialize task metric."""
assert multidim_average is not None # noqa: S101 # needed for mypy
kwargs.update({
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
})
kwargs.update(
{
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
}
)
task = ClassificationTask.from_str(task)
if task == ClassificationTask.BINARY:
return BinaryPrecision(threshold, **kwargs)
Expand Down Expand Up @@ -1011,11 +1013,13 @@ def __new__(
"""Initialize task metric."""
task = ClassificationTask.from_str(task)
assert multidim_average is not None # noqa: S101 # needed for mypy
kwargs.update({
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
})
kwargs.update(
{
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
}
)
if task == ClassificationTask.BINARY:
return BinaryRecall(threshold, **kwargs)
if task == ClassificationTask.MULTICLASS:
Expand Down
12 changes: 7 additions & 5 deletions src/torchmetrics/classification/specificity.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,11 +492,13 @@ def __new__( # type: ignore[misc]
"""Initialize task metric."""
task = ClassificationTask.from_str(task)
assert multidim_average is not None # noqa: S101 # needed for mypy
kwargs.update({
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
})
kwargs.update(
{
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
}
)
if task == ClassificationTask.BINARY:
return BinarySpecificity(threshold, **kwargs)
if task == ClassificationTask.MULTICLASS:
Expand Down
12 changes: 7 additions & 5 deletions src/torchmetrics/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,11 +531,13 @@ def __new__(
"""Initialize task metric."""
task = ClassificationTask.from_str(task)
assert multidim_average is not None # noqa: S101 # needed for mypy
kwargs.update({
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
})
kwargs.update(
{
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
}
)
if task == ClassificationTask.BINARY:
return BinaryStatScores(threshold, **kwargs)
if task == ClassificationTask.MULTICLASS:
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/functional/clustering/dunn_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def _dunn_index_update(data: Tensor, labels: Tensor, p: float) -> Tuple[Tensor,
torch.stack([a - b for a, b in combinations(centroids, 2)], dim=0), ord=p, dim=1
)

max_intracluster_distance = torch.stack([
torch.linalg.norm(ci - mu, ord=p, dim=1).max() for ci, mu in zip(clusters, centroids)
])
max_intracluster_distance = torch.stack(
[torch.linalg.norm(ci - mu, ord=p, dim=1).max() for ci, mu in zip(clusters, centroids)]
)

return intercluster_distance, max_intracluster_distance

Expand Down
10 changes: 6 additions & 4 deletions src/torchmetrics/functional/clustering/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,12 @@ def calculate_contingency_matrix(
num_classes_target = target_classes.size(0)

contingency = torch.sparse_coo_tensor(
torch.stack((
target_idx,
preds_idx,
)),
torch.stack(
(
target_idx,
preds_idx,
)
),
torch.ones(target_idx.shape[0], dtype=preds_idx.dtype, device=preds_idx.device),
(
num_classes_target,
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/functional/text/helper_embedding_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,9 @@ def _get_tokens_idf(self) -> Dict[int, float]:
token_counter.update(tokens)

tokens_idf: Dict[int, float] = defaultdict(self._get_tokens_idf_default_value)
tokens_idf.update({
idx: math.log((self.num_sentences + 1) / (occurrence + 1)) for idx, occurrence in token_counter.items()
})
tokens_idf.update(
{idx: math.log((self.num_sentences + 1) / (occurrence + 1)) for idx, occurrence in token_counter.items()}
)
return tokens_idf

def _get_tokens_idf_default_value(self) -> float:
Expand Down
12 changes: 6 additions & 6 deletions src/torchmetrics/wrappers/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,14 +217,14 @@ def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) ->
multitask_copy = deepcopy(self)
if prefix is not None:
prefix = self._check_arg(prefix, "prefix")
multitask_copy.task_metrics = nn.ModuleDict({
prefix + key: value for key, value in multitask_copy.task_metrics.items()
})
multitask_copy.task_metrics = nn.ModuleDict(
{prefix + key: value for key, value in multitask_copy.task_metrics.items()}
)
if postfix is not None:
postfix = self._check_arg(postfix, "postfix")
multitask_copy.task_metrics = nn.ModuleDict({
key + postfix: value for key, value in multitask_copy.task_metrics.items()
})
multitask_copy.task_metrics = nn.ModuleDict(
{key + postfix: value for key, value in multitask_copy.task_metrics.items()}
)
return multitask_copy

def plot(
Expand Down
10 changes: 6 additions & 4 deletions tests/integrations/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,12 @@ class TestModel(BoringModel):
def __init__(self) -> None:
super().__init__()
self.multitask = MultitaskWrapper({"classification": BinaryAccuracy(), "regression": MeanSquaredError()})
self.multitask_collection = MultitaskWrapper({
"classification": MetricCollection([BinaryAccuracy(), BinaryAveragePrecision()]),
"regression": MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
})
self.multitask_collection = MultitaskWrapper(
{
"classification": MetricCollection([BinaryAccuracy(), BinaryAveragePrecision()]),
"regression": MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
}
)

self.accuracy = BinaryAccuracy()
self.mse = MeanSquaredError()
Expand Down
30 changes: 18 additions & 12 deletions tests/unittests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,10 +580,12 @@ def test_classwise_wrapper_compute_group():

def test_compute_on_different_dtype():
"""Check that extraction of compute groups are robust towards difference in dtype."""
m = MetricCollection([
MulticlassConfusionMatrix(num_classes=3),
MulticlassMatthewsCorrCoef(num_classes=3),
])
m = MetricCollection(
[
MulticlassConfusionMatrix(num_classes=3),
MulticlassMatthewsCorrCoef(num_classes=3),
]
)
assert not m._groups_checked
assert m.compute_groups == {0: ["MulticlassConfusionMatrix"], 1: ["MulticlassMatthewsCorrCoef"]}
preds = torch.randn(10, 3).softmax(dim=-1)
Expand Down Expand Up @@ -625,14 +627,18 @@ def test_error_on_wrong_specified_compute_groups():
),
],
{
"macro": MetricCollection([
MulticlassAccuracy(num_classes=3, average="macro"),
MulticlassPrecision(num_classes=3, average="macro"),
]),
"micro": MetricCollection([
MulticlassAccuracy(num_classes=3, average="micro"),
MulticlassPrecision(num_classes=3, average="micro"),
]),
"macro": MetricCollection(
[
MulticlassAccuracy(num_classes=3, average="macro"),
MulticlassPrecision(num_classes=3, average="macro"),
]
),
"micro": MetricCollection(
[
MulticlassAccuracy(num_classes=3, average="micro"),
MulticlassPrecision(num_classes=3, average="micro"),
]
),
},
],
)
Expand Down
6 changes: 3 additions & 3 deletions tests/unittests/clustering/test_dunn_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def _reference_np_dunn_index(data, labels, p):
np.stack([a - b for a, b in combinations(centroids, 2)], axis=0), ord=p, axis=1
)

max_intracluster_distance = np.stack([
np.linalg.norm(ci - mu, ord=p, axis=1).max() for ci, mu in zip(clusters, centroids)
])
max_intracluster_distance = np.stack(
[np.linalg.norm(ci - mu, ord=p, axis=1).max() for ci, mu in zip(clusters, centroids)]
)

return intercluster_distance.min() / max_intracluster_distance.max()

Expand Down
24 changes: 14 additions & 10 deletions tests/unittests/detection/test_intersection.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,24 @@ def _tv_wrapper_class(preds, target, base_fn, respect_labels, iou_threshold, cla


_preds_fn = (
torch.tensor([
[296.55, 93.96, 314.97, 152.79],
[328.94, 97.05, 342.49, 122.98],
[356.62, 95.47, 372.33, 147.55],
])
torch.tensor(
[
[296.55, 93.96, 314.97, 152.79],
[328.94, 97.05, 342.49, 122.98],
[356.62, 95.47, 372.33, 147.55],
]
)
.unsqueeze(0)
.repeat(4, 1, 1)
)
_target_fn = (
torch.tensor([
[300.00, 100.00, 315.00, 150.00],
[330.00, 100.00, 350.00, 125.00],
[350.00, 100.00, 375.00, 150.00],
])
torch.tensor(
[
[300.00, 100.00, 315.00, 150.00],
[330.00, 100.00, 350.00, 125.00],
[350.00, 100.00, 375.00, 150.00],
]
)
.unsqueeze(0)
.repeat(4, 1, 1)
)
Expand Down
Loading

0 comments on commit 71ca0b2

Please sign in to comment.