Fixing typing #1215

Merged — 10 commits, Sep 13, 2022
3 changes: 1 addition & 2 deletions .github/workflows/code-format.yml
@@ -2,8 +2,7 @@ name: Code formatting

# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows
on: # Trigger the workflow on push or pull request, but only for the master branch
push:
branches: [master, "release/*"]
push: {}
pull_request:
branches: [master, "release/*"]

13 changes: 9 additions & 4 deletions src/torchmetrics/classification/accuracy.py
@@ -26,6 +26,7 @@
_subset_accuracy_compute,
_subset_accuracy_update,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import AverageMethod, DataType

from torchmetrics.classification.stat_scores import ( # isort:skip
@@ -461,7 +462,7 @@ def __new__(
cls,
threshold: float = 0.5,
num_classes: Optional[int] = None,
average: Optional[str] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
@@ -472,16 +473,20 @@
multidim_average: Optional[Literal["global", "samplewise"]] = "global",
validate_args: bool = True,
**kwargs: Any,
) -> None:
) -> Metric:
if task is not None:
assert multidim_average is not None
kwargs.update(
dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
)
if task == "binary":
return BinaryAccuracy(threshold, **kwargs)
if task == "multiclass":
return MulticlassAccuracy(num_classes, average, top_k, **kwargs)
assert isinstance(num_classes, int)
assert isinstance(top_k, int)
return MulticlassAccuracy(num_classes, top_k, average, **kwargs)
if task == "multilabel":
assert isinstance(num_labels, int)
return MultilabelAccuracy(num_labels, threshold, average, **kwargs)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
@@ -492,7 +497,7 @@ def __init__(
self,
threshold: float = 0.5,
num_classes: Optional[int] = None,
average: Optional[str] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
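As a quick check of the dispatch this hunk tightens, here is a minimal usage sketch (assuming a torchmetrics build that ships this task-based `__new__`; the tensors are illustrative). Note that the new `isinstance` asserts mean `num_classes` and `top_k` must be passed as plain ints for `task="multiclass"`.

```python
import torch
from torchmetrics import Accuracy
from torchmetrics.classification import MulticlassAccuracy

# The factory's __new__ now advertises that it returns a Metric subclass, and
# the multiclass branch forwards (num_classes, top_k, average) in the
# positional order MulticlassAccuracy expects.
acc = Accuracy(task="multiclass", num_classes=3, top_k=1, average="macro")
assert isinstance(acc, MulticlassAccuracy)

preds = torch.tensor([0, 2, 1, 2])
target = torch.tensor([0, 1, 1, 2])
print(acc(preds, target))  # macro-averaged accuracy over the 3 classes
```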
9 changes: 6 additions & 3 deletions src/torchmetrics/classification/auroc.py
@@ -149,6 +149,7 @@ class MulticlassAUROC(MulticlassPrecisionRecallCurve):
- ``macro``: Calculate score for each class and average them
- ``weighted``: Calculates score for each class and computes weighted average using their support
- ``"none"`` or ``None``: Calculates score for each class and applies no reduction

thresholds:
Can be one of:

@@ -408,22 +409,24 @@ def __new__(
cls,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
average: Optional[str] = "macro",
average: Optional[Literal["macro", "weighted", "none"]] = "macro",
max_fpr: Optional[float] = None,
task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
num_labels: Optional[int] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
**kwargs: Any,
) -> None:
) -> Metric:
if task is not None:
kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args))
if task == "binary":
return BinaryAUROC(max_fpr, **kwargs)
if task == "multiclass":
assert isinstance(num_classes, int)
return MulticlassAUROC(num_classes, average, **kwargs)
if task == "multilabel":
assert isinstance(num_labels, int)
return MultilabelAUROC(num_labels, average, **kwargs)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
@@ -434,7 +437,7 @@ def __init__(
self,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
average: Optional[str] = "macro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
max_fpr: Optional[float] = None,
task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
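The added asserts make a missing `num_classes` fail at construction time rather than deep inside the metric; a small sketch under the same version assumption as above:

```python
import torch
from torchmetrics import AUROC

auroc = AUROC(task="multiclass", num_classes=3, average="macro")

preds = torch.softmax(torch.randn(8, 3), dim=-1)   # illustrative scores
target = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1])
print(auroc(preds, target))

# Omitting num_classes for task="multiclass" would now trip the new
# `assert isinstance(num_classes, int)` immediately:
# AUROC(task="multiclass")  # AssertionError
```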
8 changes: 5 additions & 3 deletions src/torchmetrics/classification/average_precision.py
@@ -388,21 +388,23 @@ def __new__(
cls,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
average: Optional[str] = "macro",
average: Optional[Literal["macro", "weighted", "none"]] = "macro",
task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
num_labels: Optional[int] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
**kwargs: Any,
) -> None:
) -> Metric:
if task is not None:
kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args))
if task == "binary":
return BinaryAveragePrecision(**kwargs)
if task == "multiclass":
assert isinstance(num_classes, int)
return MulticlassAveragePrecision(num_classes, average, **kwargs)
if task == "multilabel":
assert isinstance(num_labels, int)
return MultilabelAveragePrecision(num_labels, average, **kwargs)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
@@ -413,7 +415,7 @@ def __init__(
self,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
average: Optional[str] = "macro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
num_labels: Optional[int] = None,
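The narrowed `Literal` type for `average` is mainly a static-checking aid; a hypothetical sketch of what it buys (assuming mypy or a similar checker is run over calls like these, and the same torchmetrics build as above):

```python
import torch
from torchmetrics import AveragePrecision

# "macro" is one of the admitted literals for `average`
ap = AveragePrecision(task="multiclass", num_classes=3, average="macro")
print(ap(torch.softmax(torch.randn(6, 3), dim=-1), torch.tensor([0, 1, 2, 0, 1, 2])))

# A checker would now flag a call such as
#   AveragePrecision(task="multiclass", num_classes=3, average="median")
# because "median" is not in Literal["macro", "weighted", "none"].
```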
3 changes: 2 additions & 1 deletion src/torchmetrics/classification/calibration_error.py
@@ -274,12 +274,13 @@ def __new__(
ignore_index: Optional[int] = None,
validate_args: bool = True,
**kwargs: Any,
) -> None:
) -> Metric:
if task is not None:
kwargs.update(dict(n_bins=n_bins, norm=norm, ignore_index=ignore_index, validate_args=validate_args))
if task == "binary":
return BinaryCalibrationError(**kwargs)
if task == "multiclass":
assert isinstance(num_classes, int)
return MulticlassCalibrationError(num_classes, **kwargs)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
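For this factory the binning options travel through `kwargs` to the concrete class; a short sketch under the same version assumption (inputs illustrative):

```python
import torch
from torchmetrics import CalibrationError

# n_bins and norm are forwarded via kwargs.update(...) to
# MulticlassCalibrationError; num_classes must be an int for task="multiclass".
ece = CalibrationError(task="multiclass", num_classes=3, n_bins=10, norm="l1")

preds = torch.softmax(torch.randn(16, 3), dim=-1)   # illustrative probabilities
target = torch.randint(0, 3, (16,))
print(ece(preds, target))
```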
3 changes: 2 additions & 1 deletion src/torchmetrics/classification/cohen_kappa.py
@@ -245,12 +245,13 @@ def __new__(
ignore_index: Optional[int] = None,
validate_args: bool = True,
**kwargs: Any,
) -> None:
) -> Metric:
if task is not None:
kwargs.update(dict(weights=weights, ignore_index=ignore_index, validate_args=validate_args))
if task == "binary":
return BinaryCohenKappa(threshold, **kwargs)
if task == "multiclass":
assert isinstance(num_classes, int)
return MulticlassCohenKappa(num_classes, **kwargs)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
4 changes: 3 additions & 1 deletion src/torchmetrics/classification/confusion_matrix.py
@@ -401,14 +401,16 @@ def __new__(
ignore_index: Optional[int] = None,
validate_args: bool = True,
**kwargs: Any,
) -> None:
) -> Metric:
if task is not None:
kwargs.update(dict(normalize=normalize, ignore_index=ignore_index, validate_args=validate_args))
if task == "binary":
return BinaryConfusionMatrix(threshold, **kwargs)
if task == "multiclass":
assert isinstance(num_classes, int)
return MulticlassConfusionMatrix(num_classes, **kwargs)
if task == "multilabel":
assert isinstance(num_labels, int)
return MultilabelConfusionMatrix(num_labels, threshold, **kwargs)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
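A minimal sketch of the three-way dispatch for the confusion matrix (same assumptions as the earlier examples; inputs are illustrative):

```python
import torch
from torchmetrics import ConfusionMatrix

# binary: threshold is passed positionally to BinaryConfusionMatrix
binary_cm = ConfusionMatrix(task="binary", threshold=0.5)
print(binary_cm(torch.tensor([0.2, 0.8, 0.6]), torch.tensor([0, 1, 0])))  # 2x2 matrix

# multilabel: num_labels must now be an int, checked by the added assert
multilabel_cm = ConfusionMatrix(task="multilabel", num_labels=2)
preds = torch.tensor([[0.9, 0.1], [0.3, 0.7]])
target = torch.tensor([[1, 0], [0, 1]])
print(multilabel_cm(preds, target))  # one 2x2 matrix per label
```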
4 changes: 2 additions & 2 deletions src/torchmetrics/classification/dice.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
from typing import Any, Literal, Optional

from torch import Tensor

@@ -124,7 +124,7 @@ def __init__(
zero_division: int = 0,
num_classes: Optional[int] = None,
threshold: float = 0.5,
average: Optional[str] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
mdmc_average: Optional[str] = "global",
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
25 changes: 17 additions & 8 deletions src/torchmetrics/classification/f_beta.py
@@ -30,6 +30,7 @@
_multiclass_fbeta_score_arg_validation,
_multilabel_fbeta_score_arg_validation,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import AverageMethod


@@ -822,7 +823,7 @@ def __new__(
num_classes: Optional[int] = None,
beta: float = 1.0,
threshold: float = 0.5,
average: Optional[str] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
@@ -832,16 +833,20 @@
multidim_average: Optional[Literal["global", "samplewise"]] = "global",
validate_args: bool = True,
**kwargs: Any,
) -> None:
) -> Metric:
if task is not None:
assert multidim_average is not None
kwargs.update(
dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
)
if task == "binary":
return BinaryFBetaScore(beta, threshold, **kwargs)
if task == "multiclass":
return MulticlassFBetaScore(beta, num_classes, average, top_k, **kwargs)
assert isinstance(num_classes, int)
assert isinstance(top_k, int)
return MulticlassFBetaScore(beta, num_classes, top_k, average, **kwargs)
if task == "multilabel":
assert isinstance(num_labels, int)
return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
@@ -853,7 +858,7 @@ def __init__(
num_classes: Optional[int] = None,
beta: float = 1.0,
threshold: float = 0.5,
average: Optional[str] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
@@ -983,7 +988,7 @@ def __new__(
cls,
num_classes: Optional[int] = None,
threshold: float = 0.5,
average: Optional[str] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
@@ -993,16 +998,20 @@
multidim_average: Optional[Literal["global", "samplewise"]] = "global",
validate_args: bool = True,
**kwargs: Any,
) -> None:
) -> Metric:
if task is not None:
assert multidim_average is not None
kwargs.update(
dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
)
if task == "binary":
return BinaryF1Score(threshold, **kwargs)
if task == "multiclass":
return MulticlassF1Score(num_classes, average, top_k, **kwargs)
assert isinstance(num_classes, int)
assert isinstance(top_k, int)
return MulticlassF1Score(num_classes, top_k, average, **kwargs)
if task == "multilabel":
assert isinstance(num_labels, int)
return MultilabelF1Score(num_labels, threshold, average, **kwargs)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
@@ -1013,7 +1022,7 @@ def __init__(
self,
num_classes: Optional[int] = None,
threshold: float = 0.5,
average: Optional[str] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
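The same pattern applies to both FBetaScore and F1Score; a combined sketch (version assumption as above, data illustrative):

```python
import torch
from torchmetrics import F1Score, FBetaScore
from torchmetrics.classification import MulticlassF1Score

# multiclass: (num_classes, top_k, average) in the corrected positional order
f1 = F1Score(task="multiclass", num_classes=3, top_k=1, average="macro")
assert isinstance(f1, MulticlassF1Score)
print(f1(torch.tensor([0, 1, 2, 1]), torch.tensor([0, 2, 2, 1])))

# binary: beta and threshold are forwarded to BinaryFBetaScore
fbeta = FBetaScore(task="binary", beta=0.5, threshold=0.5)
print(fbeta(torch.tensor([0.8, 0.2, 0.6, 0.4]), torch.tensor([1, 0, 1, 1])))
```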
10 changes: 7 additions & 3 deletions src/torchmetrics/classification/hamming.py
@@ -368,22 +368,26 @@ def __new__(
task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
num_classes: Optional[int] = None,
num_labels: Optional[int] = None,
average: Optional[str] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
multidim_average: Optional[Literal["global", "samplewise"]] = "global",
top_k: Optional[int] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
**kwargs: Any,
) -> None:
) -> Metric:
if task is not None:
assert multidim_average is not None
kwargs.update(
dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
)
if task == "binary":
return BinaryHammingDistance(threshold, **kwargs)
if task == "multiclass":
return MulticlassHammingDistance(num_classes, average, top_k, **kwargs)
assert isinstance(num_classes, int)
assert isinstance(top_k, int)
return MulticlassHammingDistance(num_classes, top_k, average, **kwargs)
if task == "multilabel":
assert isinstance(num_labels, int)
return MultilabelHammingDistance(num_labels, threshold, average, **kwargs)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
6 changes: 4 additions & 2 deletions src/torchmetrics/classification/hinge.py
@@ -284,18 +284,20 @@ class HingeLoss(Metric):
def __new__(
cls,
squared: bool = False,
multiclass_mode: Optional[Union[str, MulticlassMode]] = None,
multiclass_mode: Literal["crammer-singer", "one-vs-all"] = None,
task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
num_classes: Optional[int] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
**kwargs: Any,
) -> None:
) -> Metric:
if task is not None:
kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args))
if task == "binary":
return BinaryHingeLoss(squared, **kwargs)
if task == "multiclass":
assert isinstance(num_classes, int)
assert multiclass_mode is not None
return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
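Finally, for the hinge loss the multiclass branch now insists on both `num_classes` and `multiclass_mode` being supplied; a closing sketch under the same assumptions (inputs illustrative):

```python
import torch
from torchmetrics import HingeLoss

# Both arguments are required by the new asserts for task="multiclass".
hinge = HingeLoss(task="multiclass", num_classes=3, multiclass_mode="crammer-singer")

preds = torch.randn(4, 3)            # illustrative, unnormalised decision scores
target = torch.tensor([0, 2, 1, 2])
print(hinge(preds, target))
```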