diff --git a/torchmetrics/classification/__init__.py b/torchmetrics/classification/__init__.py index 36ebd20781b..023b1ed59e5 100644 --- a/torchmetrics/classification/__init__.py +++ b/torchmetrics/classification/__init__.py @@ -24,7 +24,7 @@ from torchmetrics.classification.hamming_distance import HammingDistance # noqa: F401 from torchmetrics.classification.hinge import Hinge # noqa: F401 from torchmetrics.classification.iou import IoU # noqa: F401 -from torchmetrics.classification.kldivergence import KLDivergence # noqa: F401 +from torchmetrics.classification.kl_divergence import KLDivergence # noqa: F401 from torchmetrics.classification.matthews_corrcoef import MatthewsCorrcoef # noqa: F401 from torchmetrics.classification.precision_recall import Precision, Recall # noqa: F401 from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401 diff --git a/torchmetrics/classification/kldivergence.py b/torchmetrics/classification/kl_divergence.py similarity index 95% rename from torchmetrics/classification/kldivergence.py rename to torchmetrics/classification/kl_divergence.py index 08bec92962a..c8935d60ae2 100644 --- a/torchmetrics/classification/kldivergence.py +++ b/torchmetrics/classification/kl_divergence.py @@ -11,12 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, Optional import torch from torch import Tensor -from torchmetrics.functional.classification.kldivergence import _kld_compute, _kld_update +from torchmetrics.functional.classification.kl_divergence import _kld_compute, _kld_update from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat @@ -60,6 +60,8 @@ class KLDivergence(Metric): >>> kldivergence(p, q) tensor(0.0853) """ + # canot be used because if scripting + # measures: Union[List[Tensor], Tensor] total: Tensor def __init__( diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py index 5a1ec84bc6f..87dab7927ca 100644 --- a/torchmetrics/classification/stat_scores.py +++ b/torchmetrics/classification/stat_scores.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Tuple, Union, List +from typing import Any, Callable, Optional, Tuple import torch from torch import Tensor @@ -130,6 +130,12 @@ class StatScores(Metric): """ + # canot be used because if scripting + # tp: Union[Tensor, List[Tensor]] + # fp: Union[Tensor, List[Tensor]] + # tn: Union[Tensor, List[Tensor]] + # fn: Union[Tensor, List[Tensor]] + def __init__( self, threshold: float = 0.5, diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 8074ebab775..631ef7e4246 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -25,7 +25,7 @@ from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401 from torchmetrics.functional.classification.hinge import hinge # noqa: F401 from torchmetrics.functional.classification.iou import iou # noqa: F401 -from torchmetrics.functional.classification.kldivergence import kldivergence # noqa: F401 +from torchmetrics.functional.classification.kl_divergence import kldivergence # noqa: F401 from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401 from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401 from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: F401 diff --git a/torchmetrics/functional/classification/__init__.py b/torchmetrics/functional/classification/__init__.py index 205f67b8b67..d8319601f23 100644 --- a/torchmetrics/functional/classification/__init__.py +++ b/torchmetrics/functional/classification/__init__.py @@ -22,7 +22,7 @@ from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401 from torchmetrics.functional.classification.hinge import hinge # noqa: F401 from torchmetrics.functional.classification.iou import iou # noqa: F401 -from torchmetrics.functional.classification.kldivergence import kldivergence # noqa: F401 +from torchmetrics.functional.classification.kl_divergence import kldivergence # noqa: F401 from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401 from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401 from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: F401 diff --git a/torchmetrics/functional/classification/kldivergence.py b/torchmetrics/functional/classification/kl_divergence.py similarity index 100% rename from torchmetrics/functional/classification/kldivergence.py rename to torchmetrics/functional/classification/kl_divergence.py