Skip to content

Commit

Permalink
[pre-commit.ci] pre-commit suggestions (#1688)
Browse files Browse the repository at this point in the history
* [pre-commit.ci] pre-commit suggestions

updates:
- [github.com/psf/black: 23.1.0 → 23.3.0](psf/black@23.1.0...23.3.0)
- [github.com/charliermarsh/ruff-pre-commit: v0.0.255 → v0.0.260](astral-sh/ruff-pre-commit@v0.0.255...v0.0.260)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* S301

* S310

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
  • Loading branch information
3 people committed Apr 4, 2023
1 parent 2509448 commit fd8fa6f
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 11 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ repos:
require_serial: false

- repo: https://github.com/psf/black
rev: 23.1.0
rev: 23.3.0
hooks:
- id: black
name: Format code
Expand Down Expand Up @@ -87,7 +87,7 @@ repos:
- flake8-bandit

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.255
rev: v0.0.260
hooks:
- id: ruff
args: ["--fix"]
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ ignore = [
"D107", # Missing docstring in `__init__`
"ANN101", # Missing type annotation for `self` in method
"ANN102", # Missing type annotation for `cls` in classmethod
"S301", # `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue # todo
"S310", # Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected. # todo
]
# Exclude a variety of commonly ignored directories.
exclude = [
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def _init_compute_groups(self) -> None:
simply initialize each metric in the collection as its own group
"""
if isinstance(self._enable_compute_groups, list):
self._groups = {i: k for i, k in enumerate(self._enable_compute_groups)}
self._groups = dict(enumerate(self._enable_compute_groups))
for v in self._groups.values():
for metric in v:
if metric not in self:
Expand Down
11 changes: 4 additions & 7 deletions src/torchmetrics/text/chrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,9 @@ def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]:

def _convert_states_to_dicts(self) -> _DICT_STATES_TYPES:
"""Convert global metric states to the n-gram dictionaries to be passed in ``_chrf_score_update``."""
n_grams_dicts: Dict[str, Dict[int, Tensor]] = {
name: n_gram_dict
for name, n_gram_dict in zip(
_DICT_STATES_NAMES, _prepare_n_grams_dicts(self.n_char_order, self.n_word_order)
)
}
n_grams_dicts: Dict[str, Dict[int, Tensor]] = dict(
zip(_DICT_STATES_NAMES, _prepare_n_grams_dicts(self.n_char_order, self.n_word_order))
)

for (n_gram_level, n_gram_order), text in self._get_text_n_gram_iterator():
for n in range(1, n_gram_order + 1):
Expand All @@ -184,7 +181,7 @@ def _convert_states_to_dicts(self) -> _DICT_STATES_TYPES:

def _update_states_from_dicts(self, n_grams_dicts_tuple: _DICT_STATES_TYPES) -> None:
"""Update global metric states based on the n-gram dictionaries calculated on the current batch."""
n_grams_dicts = {name: n_gram_dict for name, n_gram_dict, in zip(_DICT_STATES_NAMES, n_grams_dicts_tuple)}
n_grams_dicts = dict(zip(_DICT_STATES_NAMES, n_grams_dicts_tuple))
for (n_gram_level, n_gram_order), text in self._get_text_n_gram_iterator():
for n in range(1, n_gram_order + 1):
dict_name = self._get_dict_name(text, n_gram_level)
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/wrappers/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def plot(
>>> fig_, ax_ = tracker.plot() # plot all epochs
"""
val = val if val is not None else [val for val in self.compute_all()]
val = val if val is not None else list(self.compute_all())
fig, ax = plot_single_or_multi_val(
val,
ax=ax,
Expand Down

0 comments on commit fd8fa6f

Please sign in to comment.