Skip to content

Commit

Permalink
typing
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Jul 5, 2021
1 parent f6d40fd commit 7be8d75
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 6 deletions.
2 changes: 1 addition & 1 deletion torchmetrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from torchmetrics.classification.hamming_distance import HammingDistance # noqa: F401
from torchmetrics.classification.hinge import Hinge # noqa: F401
from torchmetrics.classification.iou import IoU # noqa: F401
from torchmetrics.classification.kldivergence import KLDivergence # noqa: F401
from torchmetrics.classification.kl_divergence import KLDivergence # noqa: F401
from torchmetrics.classification.matthews_corrcoef import MatthewsCorrcoef # noqa: F401
from torchmetrics.classification.precision_recall import Precision, Recall # noqa: F401
from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, Optional

import torch
from torch import Tensor

from torchmetrics.functional.classification.kldivergence import _kld_compute, _kld_update
from torchmetrics.functional.classification.kl_divergence import _kld_compute, _kld_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat

Expand Down Expand Up @@ -60,6 +60,8 @@ class KLDivergence(Metric):
>>> kldivergence(p, q)
tensor(0.0853)
"""
# canot be used because if scripting
# measures: Union[List[Tensor], Tensor]
total: Tensor

def __init__(
Expand Down
8 changes: 7 additions & 1 deletion torchmetrics/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Tuple, Union, List
from typing import Any, Callable, Optional, Tuple

import torch
from torch import Tensor
Expand Down Expand Up @@ -130,6 +130,12 @@ class StatScores(Metric):
"""

# canot be used because if scripting
# tp: Union[Tensor, List[Tensor]]
# fp: Union[Tensor, List[Tensor]]
# tn: Union[Tensor, List[Tensor]]
# fn: Union[Tensor, List[Tensor]]

def __init__(
self,
threshold: float = 0.5,
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401
from torchmetrics.functional.classification.hinge import hinge # noqa: F401
from torchmetrics.functional.classification.iou import iou # noqa: F401
from torchmetrics.functional.classification.kldivergence import kldivergence # noqa: F401
from torchmetrics.functional.classification.kl_divergence import kldivergence # noqa: F401
from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401
from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401
from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: F401
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/functional/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401
from torchmetrics.functional.classification.hinge import hinge # noqa: F401
from torchmetrics.functional.classification.iou import iou # noqa: F401
from torchmetrics.functional.classification.kldivergence import kldivergence # noqa: F401
from torchmetrics.functional.classification.kl_divergence import kldivergence # noqa: F401
from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401
from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401
from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: F401
Expand Down

0 comments on commit 7be8d75

Please sign in to comment.