ruff: src public method #1502

Merged: 2 commits, Feb 13, 2023
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ unfixable = ["F401"]
"__init__.py" = ["D100"]
"src/**" = [
"D101", # todo # Missing docstring in public class
"D102", # todo # Missing docstring in public method
"D103", # todo # Missing docstring in public function
"D105", # todo # Missing docstring in magic method
"D205", # todo # 1 blank line required between summary line and description
Expand Down
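Context for the pyproject.toml change: removing "D102" from the src/** per-file ignore list makes ruff enforce pydocstyle rule D102 (Missing docstring in public method) across src/, which is what forces the one-line docstrings added in every file below. A minimal illustration (hypothetical class, not from this repo) of what the rule flags and what satisfies it:

class MyMetric:
    def compute(self):  # ruff reports D102: missing docstring in public method
        ...

    def update(self) -> None:
        """Update metric state."""  # a one-line docstring is enough to satisfy D102
        ...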
1 change: 1 addition & 0 deletions src/torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ def __new__(
validate_args: bool = True,
**kwargs: Any,
) -> Metric:
"""Initialize task metric."""
task = ClassificationTask.from_str(task)
kwargs.update(
{"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
Expand Down
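The "Initialize task metric." docstrings throughout this PR land on the __new__ methods of wrapper classes such as Accuracy, which never instantiate themselves but dispatch on the task argument to a task-specific metric. A simplified sketch of that pattern (the class bodies here are stand-ins, not the torchmetrics implementations):

from enum import Enum
from typing import Any


class ClassificationTask(Enum):
    """Enumeration of supported classification tasks."""

    BINARY = "binary"
    MULTICLASS = "multiclass"

    @classmethod
    def from_str(cls, value: str) -> "ClassificationTask":
        """Resolve a user-supplied task string to an enum member."""
        return cls(value)


class BinaryAccuracy:
    """Stand-in for the task-specific binary metric."""

    def __init__(self, **kwargs: Any) -> None:
        """Store keyword configuration."""
        self.kwargs = kwargs


class MulticlassAccuracy:
    """Stand-in for the task-specific multiclass metric."""

    def __init__(self, **kwargs: Any) -> None:
        """Store keyword configuration."""
        self.kwargs = kwargs


class Accuracy:
    """Wrapper that dispatches to a task-specific metric."""

    def __new__(cls, task: str, **kwargs: Any) -> object:
        """Initialize task metric."""
        if ClassificationTask.from_str(task) == ClassificationTask.BINARY:
            return BinaryAccuracy(**kwargs)
        return MulticlassAccuracy(**kwargs)

Accuracy(task="binary") thus returns a BinaryAccuracy instance, so the docstring on __new__ is the only place D102 can be satisfied on the wrapper itself.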
4 changes: 4 additions & 0 deletions src/torchmetrics/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
self.max_fpr = max_fpr

def compute(self) -> Tensor:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _binary_auroc_compute(state, self.thresholds, self.max_fpr)

Expand Down Expand Up @@ -207,6 +208,7 @@ def __init__(
self.validate_args = validate_args

def compute(self) -> Tensor:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _multiclass_auroc_compute(state, self.num_classes, self.average, self.thresholds)

Expand Down Expand Up @@ -308,6 +310,7 @@ def __init__(
self.validate_args = validate_args

def compute(self) -> Tensor:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _multilabel_auroc_compute(state, self.num_labels, self.average, self.thresholds, self.ignore_index)

Expand Down Expand Up @@ -353,6 +356,7 @@ def __new__(
validate_args: bool = True,
**kwargs: Any,
) -> Metric:
"""Initialize task metric."""
task = ClassificationTask.from_str(task)
kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args})
if task == ClassificationTask.BINARY:
Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/classification/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class BinaryAveragePrecision(BinaryPrecisionRecallCurve):
full_state_update: bool = False

def compute(self) -> Tensor:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _binary_average_precision_compute(state, self.thresholds)

Expand Down Expand Up @@ -201,6 +202,7 @@ def __init__(
self.validate_args = validate_args

def compute(self) -> Tensor:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _multiclass_average_precision_compute(state, self.num_classes, self.average, self.thresholds)

Expand Down Expand Up @@ -307,6 +309,7 @@ def __init__(
self.validate_args = validate_args

def compute(self) -> Tensor:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _multilabel_average_precision_compute(
state, self.num_labels, self.average, self.thresholds, self.ignore_index
Expand Down Expand Up @@ -357,6 +360,7 @@ def __new__(
validate_args: bool = True,
**kwargs: Any,
) -> Metric:
"""Initialize task metric."""
task = ClassificationTask.from_str(task)
kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args})
if task == ClassificationTask.BINARY:
Expand Down
5 changes: 5 additions & 0 deletions src/torchmetrics/classification/calibration_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(
self.add_state("accuracies", [], dist_reduce_fx="cat")

def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""Update metric states with predictions and targets."""
if self.validate_args:
_binary_calibration_error_tensor_validation(preds, target, self.ignore_index)
preds, target = _binary_confusion_matrix_format(
Expand All @@ -124,6 +125,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
self.accuracies.append(accuracies)

def compute(self) -> Tensor:
"""Compute metric."""
confidences = dim_zero_cat(self.confidences)
accuracies = dim_zero_cat(self.accuracies)
return _ce_compute(confidences, accuracies, self.n_bins, norm=self.norm)
Expand Down Expand Up @@ -217,6 +219,7 @@ def __init__(
self.add_state("accuracies", [], dist_reduce_fx="cat")

def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""Update metric states with predictions and targets."""
if self.validate_args:
_multiclass_calibration_error_tensor_validation(preds, target, self.num_classes, self.ignore_index)
preds, target = _multiclass_confusion_matrix_format(
Expand All @@ -227,6 +230,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
self.accuracies.append(accuracies)

def compute(self) -> Tensor:
"""Compute metric."""
confidences = dim_zero_cat(self.confidences)
accuracies = dim_zero_cat(self.accuracies)
return _ce_compute(confidences, accuracies, self.n_bins, norm=self.norm)
Expand Down Expand Up @@ -268,6 +272,7 @@ def __new__(
validate_args: bool = True,
**kwargs: Any,
) -> Metric:
"""Initialize task metric."""
task = ClassificationTaskNoMultilabel.from_str(task)
kwargs.update({"n_bins": n_bins, "norm": norm, "ignore_index": ignore_index, "validate_args": validate_args})
if task == ClassificationTaskNoMultilabel.BINARY:
Expand Down
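The paired "Update metric states..." / "Compute metric." docstrings added here reflect the two-phase lifecycle these classes share: update() appends per-batch tensors to list states registered with add_state, and compute() concatenates them (dim_zero_cat) and reduces. A self-contained toy version of that lifecycle (hand-rolled for illustration, assuming nothing from torchmetrics):

import torch
from torch import Tensor


class MeanConfidence:
    """Toy metric that averages all predictions seen so far."""

    def __init__(self) -> None:
        """Initialize an empty list state, analogous to add_state(..., [])."""
        self.confidences: list[Tensor] = []

    def update(self, preds: Tensor) -> None:
        """Update metric state with one batch of predictions."""
        self.confidences.append(preds.detach())

    def compute(self) -> Tensor:
        """Compute the metric over all accumulated batches."""
        return torch.cat(self.confidences).mean()


metric = MeanConfidence()
metric.update(torch.tensor([0.9, 0.8]))
metric.update(torch.tensor([0.7]))
print(metric.compute())  # tensor(0.8000)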
3 changes: 3 additions & 0 deletions src/torchmetrics/classification/cohen_kappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(
self.validate_args = validate_args

def compute(self) -> Tensor:
"""Compute metric."""
return _cohen_kappa_reduce(self.confmat, self.weights)


Expand Down Expand Up @@ -184,6 +185,7 @@ def __init__(
self.validate_args = validate_args

def compute(self) -> Tensor:
"""Compute metric."""
return _cohen_kappa_reduce(self.confmat, self.weights)


Expand Down Expand Up @@ -222,6 +224,7 @@ def __new__(
validate_args: bool = True,
**kwargs: Any,
) -> Metric:
"""Initialize task metric."""
task = ClassificationTaskNoMultilabel.from_str(task)
kwargs.update({"weights": weights, "ignore_index": ignore_index, "validate_args": validate_args})
if task == ClassificationTaskNoMultilabel.BINARY:
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ def __new__(
validate_args: bool = True,
**kwargs: Any,
) -> Metric:
"""Initialize task metric."""
task = ClassificationTask.from_str(task)
kwargs.update({"normalize": normalize, "ignore_index": ignore_index, "validate_args": validate_args})
if task == ClassificationTask.BINARY:
Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/classification/exact_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
)

def update(self, preds, target) -> None:
"""Update metric states with predictions and targets."""
if self.validate_args:
_multiclass_stat_scores_tensor_validation(
preds, target, self.num_classes, self.multidim_average, self.ignore_index
Expand All @@ -132,6 +133,7 @@ def update(self, preds, target) -> None:
self.total += total

def compute(self) -> Tensor:
"""Compute metric."""
correct = dim_zero_cat(self.correct) if isinstance(self.correct, list) else self.correct
return _exact_match_reduce(correct, self.total)

Expand Down Expand Up @@ -250,6 +252,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
self.total += total

def compute(self) -> Tensor:
"""Compute metric."""
correct = dim_zero_cat(self.correct) if isinstance(self.correct, list) else self.correct
return _exact_match_reduce(correct, self.total)

Expand Down Expand Up @@ -289,6 +292,7 @@ def __new__(
validate_args: bool = True,
**kwargs: Any,
) -> Metric:
"""Initialize task metric."""
task = ClassificationTaskNoBinary.from_str(task)
kwargs.update(
{"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
Expand Down
5 changes: 5 additions & 0 deletions src/torchmetrics/classification/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def __init__(
self.beta = beta

def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _fbeta_reduce(tp, fp, tn, fn, self.beta, average="binary", multidim_average=self.multidim_average)

Expand Down Expand Up @@ -245,6 +246,7 @@ def __init__(
self.beta = beta

def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _fbeta_reduce(tp, fp, tn, fn, self.beta, average=self.average, multidim_average=self.multidim_average)

Expand Down Expand Up @@ -369,6 +371,7 @@ def __init__(
self.beta = beta

def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _fbeta_reduce(tp, fp, tn, fn, self.beta, average=self.average, multidim_average=self.multidim_average)

Expand Down Expand Up @@ -726,6 +729,7 @@ def __new__(
validate_args: bool = True,
**kwargs: Any,
) -> Metric:
"""Initialize task metric."""
assert multidim_average is not None
kwargs.update(
{"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
Expand Down Expand Up @@ -777,6 +781,7 @@ def __new__(
validate_args: bool = True,
**kwargs: Any,
) -> Metric:
"""Initialize task metric."""
task = ClassificationTask.from_str(task)
assert multidim_average is not None
kwargs.update(
Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/classification/hamming.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class BinaryHammingDistance(BinaryStatScores):
full_state_update: bool = False

def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _hamming_distance_reduce(tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average)

Expand Down Expand Up @@ -199,6 +200,7 @@ class MulticlassHammingDistance(MulticlassStatScores):
full_state_update: bool = False

def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _hamming_distance_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)

Expand Down Expand Up @@ -302,6 +304,7 @@ class MultilabelHammingDistance(MultilabelStatScores):
full_state_update: bool = False

def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _hamming_distance_reduce(
tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
Expand Down Expand Up @@ -345,6 +348,7 @@ def __new__(
validate_args: bool = True,
**kwargs: Any,
) -> Metric:
"""Initialize task metric."""
task = ClassificationTask.from_str(task)
assert multidim_average is not None
kwargs.update(
Expand Down
5 changes: 5 additions & 0 deletions src/torchmetrics/classification/hinge.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""Update metric state."""
if self.validate_args:
_binary_hinge_loss_tensor_validation(preds, target, self.ignore_index)
preds, target = _binary_confusion_matrix_format(
Expand All @@ -109,6 +110,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
self.total += total

def compute(self) -> Tensor:
"""Compute metric."""
return _hinge_loss_compute(self.measures, self.total)


Expand Down Expand Up @@ -203,6 +205,7 @@ def __init__(
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""Update metric state."""
if self.validate_args:
_multiclass_hinge_loss_tensor_validation(preds, target, self.num_classes, self.ignore_index)
preds, target = _multiclass_confusion_matrix_format(preds, target, self.ignore_index, convert_to_labels=False)
Expand All @@ -211,6 +214,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
self.total += total

def compute(self) -> Tensor:
"""Compute metric."""
return _hinge_loss_compute(self.measures, self.total)


Expand Down Expand Up @@ -253,6 +257,7 @@ def __new__(
validate_args: bool = True,
**kwargs: Any,
) -> Metric:
"""Initialize task metric."""
task = ClassificationTaskNoMultilabel.from_str(task)
kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
if task == ClassificationTaskNoMultilabel.BINARY:
Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/classification/jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
)

def compute(self) -> Tensor:
"""Compute metric."""
return _jaccard_index_reduce(self.confmat, average="binary")


Expand Down Expand Up @@ -172,6 +173,7 @@ def __init__(
self.average = average

def compute(self) -> Tensor:
"""Compute metric."""
return _jaccard_index_reduce(self.confmat, average=self.average)


Expand Down Expand Up @@ -259,6 +261,7 @@ def __init__(
self.average = average

def compute(self) -> Tensor:
"""Compute metric."""
return _jaccard_index_reduce(self.confmat, average=self.average)


Expand Down Expand Up @@ -296,6 +299,7 @@ def __new__(
validate_args: bool = True,
**kwargs: Any,
) -> Metric:
"""Initialize task metric."""
task = ClassificationTask.from_str(task)
kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
if task == ClassificationTask.BINARY:
Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/classification/matthews_corrcoef.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
super().__init__(threshold, ignore_index, normalize=None, validate_args=validate_args, **kwargs)

def compute(self) -> Tensor:
"""Compute metric."""
return _matthews_corrcoef_reduce(self.confmat)


Expand Down Expand Up @@ -144,6 +145,7 @@ def __init__(
super().__init__(num_classes, ignore_index, normalize=None, validate_args=validate_args, **kwargs)

def compute(self) -> Tensor:
"""Compute metric."""
return _matthews_corrcoef_reduce(self.confmat)


Expand Down Expand Up @@ -207,6 +209,7 @@ def __init__(
super().__init__(num_labels, threshold, ignore_index, normalize=None, validate_args=validate_args, **kwargs)

def compute(self) -> Tensor:
"""Compute metric."""
return _matthews_corrcoef_reduce(self.confmat)


Expand Down Expand Up @@ -238,6 +241,7 @@ def __new__(
validate_args: bool = True,
**kwargs: Any,
) -> Metric:
"""Initialize task metric."""
task = ClassificationTask.from_str(task)
kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
if task == ClassificationTask.BINARY:
Expand Down