Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add typing for attributes #314

Merged
merged 11 commits into from
Jul 5, 2021
2 changes: 2 additions & 0 deletions torchmetrics/audio/si_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class SI_SDR(Metric):
[1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE International Conference on Acoustics, Speech
and Signal Processing (ICASSP) 2019.
"""
sum_si_sdr: Tensor
total: Tensor

def __init__(
self,
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/audio/si_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class SI_SNR(Metric):
Separation," 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2018, pp.
696-700, doi: 10.1109/ICASSP.2018.8462116.
"""
sum_si_snr: Tensor
total: Tensor

def __init__(
self,
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/audio/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ class SNR(Metric):
[1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE International Conference on Acoustics, Speech
and Signal Processing (ICASSP) 2019.
"""
sum_snr: Tensor
total: Tensor

def __init__(
self,
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/average.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class AverageMeter(Metric):
>>> avg(values, weights)
tensor(1.2500)
"""
value: Tensor
weight: Tensor

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from torchmetrics.classification.hamming_distance import HammingDistance # noqa: F401
from torchmetrics.classification.hinge import Hinge # noqa: F401
from torchmetrics.classification.iou import IoU # noqa: F401
from torchmetrics.classification.kldivergence import KLDivergence # noqa: F401
from torchmetrics.classification.kl_divergence import KLDivergence # noqa: F401
from torchmetrics.classification.matthews_corrcoef import MatthewsCorrcoef # noqa: F401
from torchmetrics.classification.precision_recall import Precision, Recall # noqa: F401
from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ class Accuracy(StatScores):
tensor(0.6667)

"""
correct: Tensor
total: Tensor

def __init__(
self,
Expand Down
4 changes: 3 additions & 1 deletion torchmetrics/classification/auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from typing import Any, Callable, List, Optional

from torch import Tensor

Expand Down Expand Up @@ -43,6 +43,8 @@ class AUC(Metric):
Callback that performs the ``allgather`` operation on the metric state. When ``None``, DDP
will be used to perform the ``allgather``.
"""
x: List[Tensor]
y: List[Tensor]

def __init__(
self,
Expand Down
4 changes: 3 additions & 1 deletion torchmetrics/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from typing import Any, Callable, List, Optional

import torch
from torch import Tensor
Expand Down Expand Up @@ -100,6 +100,8 @@ class AUROC(Metric):
tensor(0.7778)

"""
preds: List[Tensor]
target: List[Tensor]

def __init__(
self,
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/classification/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class AveragePrecision(Metric):
[tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]

"""
preds: List[Tensor]
target: List[Tensor]

def __init__(
self,
Expand Down
3 changes: 3 additions & 0 deletions torchmetrics/classification/binned_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ class BinnedPrecisionRecallCurve(Metric):
tensor([0.0000, 0.5000, 1.0000]),
tensor([0.0000, 0.5000, 1.0000])]
"""
TPs: Tensor
FPs: Tensor
FNs: Tensor

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/classification/cohen_kappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class labels.
>>> cohenkappa(preds, target)
tensor(0.5000)
"""
confmat: Tensor

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class ConfusionMatrix(Metric):
[[1., 0.], [1., 0.]],
[[0., 1.], [0., 1.]]])
"""
confmat: Tensor

def __init__(
self,
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/classification/hamming_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class HammingDistance(Metric):
tensor(0.2500)

"""
correct: Tensor
total: Tensor

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch
from torch import Tensor

from torchmetrics.functional.classification.kldivergence import _kld_compute, _kld_update
from torchmetrics.functional.classification.kl_divergence import _kld_compute, _kld_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat

Expand Down Expand Up @@ -60,6 +60,9 @@ class KLDivergence(Metric):
>>> kldivergence(p, q)
tensor(0.0853)
"""
# canot be used because if scripting
Borda marked this conversation as resolved.
Show resolved Hide resolved
# measures: Union[List[Tensor], Tensor]
total: Tensor

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/classification/matthews_corrcoef.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class MatthewsCorrcoef(Metric):
tensor(0.5774)

"""
confmat: Tensor

def __init__(
self,
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/classification/precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ class PrecisionRecallCurve(Metric):
[tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]

"""
preds: List[Tensor]
target: List[Tensor]

def __init__(
self,
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/classification/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ class ROC(Metric):
tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])]

"""
preds: List[Tensor]
target: List[Tensor]

def __init__(
self,
Expand Down
6 changes: 6 additions & 0 deletions torchmetrics/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ class StatScores(Metric):

"""

# canot be used because if scripting
Borda marked this conversation as resolved.
Show resolved Hide resolved
# tp: Union[Tensor, List[Tensor]]
# fp: Union[Tensor, List[Tensor]]
# tn: Union[Tensor, List[Tensor]]
# fn: Union[Tensor, List[Tensor]]

def __init__(
self,
threshold: float = 0.5,
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401
from torchmetrics.functional.classification.hinge import hinge # noqa: F401
from torchmetrics.functional.classification.iou import iou # noqa: F401
from torchmetrics.functional.classification.kldivergence import kldivergence # noqa: F401
from torchmetrics.functional.classification.kl_divergence import kldivergence # noqa: F401
from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401
from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401
from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: F401
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/functional/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401
from torchmetrics.functional.classification.hinge import hinge # noqa: F401
from torchmetrics.functional.classification.iou import iou # noqa: F401
from torchmetrics.functional.classification.kldivergence import kldivergence # noqa: F401
from torchmetrics.functional.classification.kl_divergence import kldivergence # noqa: F401
from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401
from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401
from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: F401
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/image/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ class FID(Metric):
tensor(12.7202)

"""
real_features: List[Tensor]
fake_features: List[Tensor]

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/image/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class IS(Metric):
(tensor(1.0569), tensor(0.0113))

"""
features: List

def __init__(
self,
Expand Down
4 changes: 3 additions & 1 deletion torchmetrics/image/kid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -168,6 +168,8 @@ class KID(Metric):
(tensor(0.0338), tensor(0.0025))

"""
real_features: List[Tensor]
fake_features: List[Tensor]

def __init__(
self,
Expand Down
16 changes: 9 additions & 7 deletions torchmetrics/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Sequence
from typing import Any, List, Optional, Sequence

import torch
from torch import Tensor
Expand Down Expand Up @@ -51,6 +51,8 @@ class SSIM(Metric):
>>> ssim(preds, target)
tensor(0.9219)
"""
preds: List[Tensor]
target: List[Tensor]

def __init__(
self,
Expand All @@ -75,8 +77,8 @@ def __init__(
' to large memory footprint.'
)

self.add_state("y", default=[], dist_reduce_fx="cat")
self.add_state("y_pred", default=[], dist_reduce_fx="cat")
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
self.kernel_size = kernel_size
self.sigma = sigma
self.data_range = data_range
Expand All @@ -93,15 +95,15 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
target: Ground truth values
"""
preds, target = _ssim_update(preds, target)
self.y_pred.append(preds)
self.y.append(target)
self.preds.append(preds)
self.target.append(target)

def compute(self) -> Tensor:
"""
Computes explained variance over state.
"""
preds = dim_zero_cat(self.y_pred)
target = dim_zero_cat(self.y)
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _ssim_compute(
preds, target, self.kernel_size, self.sigma, self.reduction, self.data_range, self.k1, self.k2
)
4 changes: 3 additions & 1 deletion torchmetrics/regression/cosine_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from typing import Any, Callable, List, Optional

import torch
from torch import Tensor
Expand Down Expand Up @@ -60,6 +60,8 @@ class CosineSimilarity(Metric):
>>> cosine_similarity(preds, target)
tensor(0.8536)
"""
preds: List[Tensor]
target: List[Tensor]

def __init__(
self,
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/regression/mean_absolute_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class MeanAbsoluteError(Metric):
>>> mean_absolute_error(preds, target)
tensor(0.5000)
"""
sum_abs_error: Tensor
total: Tensor

def __init__(
self,
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/regression/mean_absolute_percentage_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class MeanAbsolutePercentageError(Metric):
>>> mean_abs_percentage_error(preds, target)
tensor(0.2667)
"""
sum_abs_per_error: Tensor
total: Tensor

def __init__(
self,
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/regression/mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class MeanSquaredError(Metric):
tensor(0.8750)

"""
sum_squared_error: Tensor
total: Tensor

def __init__(
self,
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/regression/mean_squared_log_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class MeanSquaredLogError(Metric):
Half precision is only support on GPU for this metric

"""
sum_squared_log_error: Tensor
total: Tensor

def __init__(
self,
Expand Down
4 changes: 3 additions & 1 deletion torchmetrics/regression/pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
from typing import Any, List, Optional

import torch
from torch import Tensor
Expand Down Expand Up @@ -56,6 +56,8 @@ class PearsonCorrcoef(Metric):
tensor(0.9849)

"""
preds: List[Tensor]
target: List[Tensor]

def __init__(
self,
Expand Down
4 changes: 4 additions & 0 deletions torchmetrics/regression/r2score.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ class R2Score(Metric):
>>> r2score(preds, target)
tensor([0.9654, 0.9082])
"""
sum_squared_error: Tensor
sum_error: Tensor
residual: Tensor
total: Tensor

def __init__(
self,
Expand Down
4 changes: 3 additions & 1 deletion torchmetrics/regression/spearman.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from typing import Any, Callable, List, Optional

import torch
from torch import Tensor
Expand Down Expand Up @@ -53,6 +53,8 @@ class SpearmanCorrcoef(Metric):
>>> spearman(preds, target)
tensor(1.0000)
"""
preds: List[Tensor]
target: List[Tensor]

def __init__(
self,
Expand Down
Loading