Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extend typing: base #327

Merged
merged 25 commits into from
Jun 29, 2021
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .deepsource.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ enabled = true

[[transformers]]
name = "autopep8"
enabled = true
enabled = false
12 changes: 0 additions & 12 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,6 @@ files = torchmetrics
disallow_untyped_defs = True
ignore_missing_imports = True

# todo: add proper typing to this module...
[mypy-torchmetrics.metric]
ignore_errors = True

# todo: add proper typing to this module...
[mypy-torchmetrics.collections]
ignore_errors = True

# todo: add proper typing to this module...
[mypy-torchmetrics.compositional]
ignore_errors = True

# todo: add proper typing to this module...
[mypy-torchmetrics.image.ssim]
ignore_errors = True
Expand Down
12 changes: 6 additions & 6 deletions torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from collections import OrderedDict
from copy import deepcopy
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Hashable, Iterable, Optional, Sequence, Tuple, Union

import torch
from torch import nn
Expand Down Expand Up @@ -103,15 +103,15 @@ def __init__(
self.postfix = self._check_arg(postfix, 'postfix')

@torch.jit.unused
def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202
def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
"""
Iteratively call forward for each metric. Positional arguments (args) will
be passed to every metric in the collection, while keyword arguments (kwargs)
will be filtered based on the signature of the individual metric.
"""
return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}

def update(self, *args, **kwargs): # pylint: disable=E0202
def update(self, *args: Any, **kwargs: Any) -> None:
"""
Iteratively call update for each metric. Positional arguments (args) will
be passed to every metric in the collection, while keyword arguments (kwargs)
Expand Down Expand Up @@ -161,7 +161,7 @@ def add_metrics(
if isinstance(metrics, Sequence):
# prepare for optional additions
metrics = list(metrics)
remain = []
remain: list = []
for m in additional_metrics:
(metrics if isinstance(m, Metric) else remain).append(m)

Expand Down Expand Up @@ -207,7 +207,7 @@ def _to_renamed_ordered_dict(self) -> OrderedDict:
od[self._set_name(k)] = v
return od

def keys(self, keep_base: bool = False):
def keys(self, keep_base: bool = False) -> Iterable[Hashable]:
r"""Return an iterable of the ModuleDict key.
Args:
keep_base: Whether to add prefix/postfix on the items collection.
Expand All @@ -231,7 +231,7 @@ def _check_arg(arg: Optional[str], name: str) -> Optional[str]:
return arg
raise ValueError(f'Expected input `{name}` to be a string, but got {type(arg)}')

def __repr__(self) -> Optional[str]:
def __repr__(self) -> str:
repr_str = super().__repr__()[:-2]
if self.prefix:
repr_str += f",\n prefix={self.prefix}{',' if self.postfix else ''}"
Expand Down
Loading