Skip to content

Commit

Permalink
Merge branch 'master' into nisqa
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Oct 29, 2024
2 parents 64195d7 + 0a19c47 commit 9f9f28e
Show file tree
Hide file tree
Showing 40 changed files with 693 additions and 160 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci-integrate.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ jobs:

- name: source cashing
uses: ./.github/actions/pull-caches
with:
requires: ${{ matrix.requires }}
- name: set oldest if/only for integrations
if: matrix.requires == 'oldest'
run: python .github/assistant.py set-oldest-versions --req_files='["requirements/_integrate.txt"]'
Expand Down
9 changes: 4 additions & 5 deletions .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ jobs:
strategy:
fail-fast: false
matrix:
os: ["ubuntu-20.04"]
python-version: ["3.9"]
os: ["ubuntu-22.04"]
python-version: ["3.10"]
pytorch-version:
- "2.0.1"
- "2.1.2"
Expand All @@ -42,9 +42,8 @@ jobs:
- "2.5.0"
include:
# cover additional python and PT combinations
- { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.0.1" }
- { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.2.2" }
- { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.3.1" }
- { os: "ubuntu-20.04", python-version: "3.8", pytorch-version: "2.0.1", requires: "oldest" }
- { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.4.1" }
- { os: "ubuntu-22.04", python-version: "3.12", pytorch-version: "2.5.0" }
# standard mac machine, not the M1
- { os: "macOS-13", python-version: "3.10", pytorch-version: "2.0.1" }
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/docs-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ jobs:
- name: source cashing
uses: ./.github/actions/pull-caches
with:
requires: ${{ matrix.requires }}
pytorch-version: ${{ matrix.pytorch-version }}
pypi-dir: ${{ env.PYPI_CACHE }}

Expand Down
21 changes: 19 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,40 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added a new audio metric `NISQA` ([#2792](https://github.com/PyTorchLightning/metrics/pull/2792))


- Added `Dice` metric to segmentation metrics ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725))


### Changed

-
- Changed naming and input order arguments in `KLDivergence` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800))

### Deprecated

- Deprecated Dice from classification metrics ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725))


### Removed

- Changed minimum supported Pytorch version to 2.0 ([#2671](https://github.com/Lightning-AI/torchmetrics/pull/2671))


- Removed `num_outputs` in `R2Score` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800))


### Fixed

- Changing `_modules` dict type in Pytorch 2.5 preventing to fail collections metrics ([#2793](https://github.com/Lightning-AI/torchmetrics/pull/2793))
- Fixed iou scores in detection for either empty predictions/targets leading to wrong scores ([#2805](https://github.com/Lightning-AI/torchmetrics/pull/2805))


---

## [1.5.1] - 2024-10-22

### Fixed

- Changing `_modules` dict type in Pytorch 2.5 preventing to fail collections metrics ([#2793](https://github.com/Lightning-AI/torchmetrics/pull/2793))


## [1.5.0] - 2024-10-18

### Added
Expand Down
22 changes: 22 additions & 0 deletions docs/source/segmentation/dice.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
.. customcarditem::
:header: Score
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Segmentation

.. include:: ../links.rst

##########
Dice Score
##########

Module Interface
________________

.. autoclass:: torchmetrics.segmentation.DiceScore
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.segmentation.dice_score
:noindex:
2 changes: 1 addition & 1 deletion requirements/_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ fire ==0.7.*
cloudpickle >1.3, <=3.1.0
scikit-learn ==1.2.*; python_version < "3.9"
scikit-learn ==1.5.*; python_version > "3.8" # we do not use `> =` because of oldest replcement
cachier ==3.0.1
cachier ==3.1.2
1 change: 1 addition & 0 deletions requirements/audio.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# this need to be the same as used inside speechmetrics
pesq >=0.0.4, <0.0.5
numpy <2.0 # strict, for compatibility reasons
pystoi >=0.4.0, <0.5.0
torchaudio >=2.0.1, <2.6.0
gammatone >=1.0.0, <1.1.0
Expand Down
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

numpy >1.20.0, <2.0 # strict, for compatibility reasons
numpy >1.20.0
packaging >17.1
torch >=2.0.0, <2.6.0
typing-extensions; python_version < '3.9'
Expand Down
2 changes: 1 addition & 1 deletion requirements/multimodal.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

transformers >=4.42.3, <4.46.0
transformers >=4.42.3, <4.47.0
piq <=0.8.0
3 changes: 2 additions & 1 deletion requirements/segmentation_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

scipy >1.0.0, <1.15.0
monai ==1.4.0
monai ==1.3.2 ; python_version < "3.9"
monai ==1.4.0 ; python_version > "3.8"
2 changes: 1 addition & 1 deletion requirements/text.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
nltk >3.8.1, <=3.9.1
tqdm <4.67.0
regex >=2021.9.24, <=2024.9.11
transformers >4.4.0, <4.46.0
transformers >4.4.0, <4.47.0
mecab-python3 >=1.0.6, <1.1.0
ipadic >=1.0.0, <1.1.0
sentencepiece >=0.2.0, <0.3.0
2 changes: 1 addition & 1 deletion requirements/typing.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mypy ==1.11.2
mypy ==1.13.0
torch ==2.5.0

types-PyYAML
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,5 +245,6 @@ def _prepare_extras(skip_pattern: str = "^_", skip_files: Tuple[str] = ("base.tx
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
)
2 changes: 1 addition & 1 deletion src/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,5 @@ def collect(self) -> GeneratorExit:
def pytest_collect_file(parent: Path, path: Path) -> Optional[DoctestModule]:
"""Collect doctests and add the reset_random_seed fixture."""
if path.ext == ".py":
return DoctestModule.from_parent(parent, fspath=path)
return DoctestModule.from_parent(parent, path=Path(path))
return None
15 changes: 15 additions & 0 deletions src/torchmetrics/classification/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torchmetrics.functional.classification.dice import _dice_compute
from torchmetrics.functional.classification.stat_scores import _stat_scores_update
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
Expand Down Expand Up @@ -114,6 +115,12 @@ class Dice(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
.. warning::
The ``dice`` metrics is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will
be removed in v1.7.0. Please instead consider using ``f1score`` metric from the classification subpackage as it
provides the same functionality. Additionally, we are going to re-add the ``dice`` metric in the segmentation
domain in v1.6.0 with slight modifications to functionality.
Raises:
ValueError:
If ``average`` is none of ``"micro"``, ``"macro"``, ``"samples"``, ``"none"``, ``None``.
Expand Down Expand Up @@ -155,6 +162,14 @@ def __init__(
multiclass: Optional[bool] = None,
**kwargs: Any,
) -> None:
rank_zero_warn(
"The `dice` metrics is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and"
" will removed in v1.7.0. Please instead consider using `f1score` metric from the classification subpackage"
" as it provides the same functionality. Additionally, we are going to re-add the `dice` metric in the"
" segmentation domain in v1.6.0 with slight modifications to functionality.",
DeprecationWarning,
)

super().__init__(**kwargs)
allowed_average = ("micro", "macro", "samples", "none", None)
if average not in allowed_average:
Expand Down
28 changes: 26 additions & 2 deletions src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,30 @@
__doctest_skip__ = ["MetricCollection.plot", "MetricCollection.plot_all"]


def _remove_prefix(string: str, prefix: str) -> str:
"""Patch for older version with missing method `removeprefix`.
>>> _remove_prefix("prefix_string", "prefix_")
'string'
>>> _remove_prefix("not_prefix_string", "prefix_")
'not_prefix_string'
"""
return string[len(prefix) :] if string.startswith(prefix) else string


def _remove_suffix(string: str, suffix: str) -> str:
"""Patch for older version with missing method `removesuffix`.
>>> _remove_suffix("string_suffix", "_suffix")
'string'
>>> _remove_suffix("string_suffix_missing", "_suffix")
'string_suffix_missing'
"""
return string[: -len(suffix)] if string.endswith(suffix) else string


class MetricCollection(ModuleDict):
"""MetricCollection class can be used to chain metrics that have the same call pattern into one single class.
Expand Down Expand Up @@ -558,9 +582,9 @@ def __getitem__(self, key: str, copy_state: bool = True) -> Metric:
"""
self._compute_groups_create_state_ref(copy_state)
if self.prefix:
key = key.removeprefix(self.prefix)
key = _remove_prefix(key, self.prefix)
if self.postfix:
key = key.removesuffix(self.postfix)
key = _remove_suffix(key, self.postfix)
return self._modules[key]

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/detection/_mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,9 +849,9 @@ def __calculate_recall_precision_scores(

inds = torch.searchsorted(rc, rec_thresholds.to(rc.device), right=False)
num_inds = inds.argmax() if inds.max() >= tp_len else num_rec_thrs
inds = inds[:num_inds] # type: ignore[misc]
prec[:num_inds] = pr[inds] # type: ignore[misc]
score[:num_inds] = det_scores_sorted[inds] # type: ignore[misc]
inds = inds[:num_inds]
prec[:num_inds] = pr[inds]
score[:num_inds] = det_scores_sorted[inds]
precision[idx, :, idx_cls, idx_bbox_area, idx_max_det_thresholds] = prec
scores[idx, :, idx_cls, idx_bbox_area, idx_max_det_thresholds] = score

Expand Down
13 changes: 8 additions & 5 deletions src/torchmetrics/detection/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,17 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]
"""Update state with predictions and targets."""
_input_validator(preds, target, ignore_score=True)

for p, t in zip(preds, target):
det_boxes = self._get_safe_item_values(p["boxes"])
gt_boxes = self._get_safe_item_values(t["boxes"])
self.groundtruth_labels.append(t["labels"])
for p_i, t_i in zip(preds, target):
det_boxes = self._get_safe_item_values(p_i["boxes"])
gt_boxes = self._get_safe_item_values(t_i["boxes"])
self.groundtruth_labels.append(t_i["labels"])

iou_matrix = self._iou_update_fn(det_boxes, gt_boxes, self.iou_threshold, self._invalid_val) # N x M
if self.respect_labels:
label_eq = p["labels"].unsqueeze(1) == t["labels"].unsqueeze(0) # N x M
if det_boxes.numel() > 0 and gt_boxes.numel() > 0:
label_eq = p_i["labels"].unsqueeze(1) == t_i["labels"].unsqueeze(0) # N x M
else:
label_eq = torch.eye(iou_matrix.shape[0], dtype=bool, device=iou_matrix.device) # type: ignore[call-overload]
iou_matrix[~label_eq] = self._invalid_val
self.iou_matrix.append(iou_matrix)

Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/audio/pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def permutation_invariant_training(
metric_of_ps = metric_func(ppreds, ptarget)
metric_of_ps = torch.mean(metric_of_ps.reshape(batch_size, len(perms), -1), dim=-1)
# find the best metric and best permutation
best_metric, best_indexes = eval_op(metric_of_ps, dim=1) # type: ignore[call-overload]
best_metric, best_indexes = eval_op(metric_of_ps, dim=1)
best_indexes = best_indexes.detach()
best_perm = perms[best_indexes, :]
return best_metric, best_perm
Expand Down
15 changes: 15 additions & 0 deletions src/torchmetrics/functional/classification/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from torch import Tensor

from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores, _stat_scores_update
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.checks import _input_squeeze
from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod

Expand Down Expand Up @@ -150,6 +151,12 @@ def dice(
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be.
.. warning::
The ``dice`` metrics is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will
be removed in v1.7.0. Please instead consider using ``f1score`` metric from the classification subpackage as it
provides the same functionality. Additionally, we are going to re-add the ``dice`` metric in the segmentation
domain in v1.6.0 with slight modifications to functionality.
Return:
The shape of the returned tensor depends on the ``average`` parameter
Expand All @@ -174,6 +181,14 @@ def dice(
tensor(0.2500)
"""
rank_zero_warn(
"The `dice` metrics is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will"
" removed in v1.7.0. Please instead consider using `f1score` metric from the classification subpackage as it"
" provides the same functionality. Additionally, we are going to re-add the `dice` metric in the segmentation"
" domain in v1.6.0 with slight modifications to functionality.",
DeprecationWarning,
)

allowed_average = ("micro", "macro", "weighted", "samples", "none", None)
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
Expand Down
5 changes: 5 additions & 0 deletions src/torchmetrics/functional/detection/ciou.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def _ciou_update(

from torchvision.ops import complete_box_iou

if preds.numel() == 0: # if no boxes are predicted
return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32)
if target.numel() == 0: # if no boxes are true
return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32)

iou = complete_box_iou(preds, target)
if iou_threshold is not None:
iou[iou < iou_threshold] = replacement_val
Expand Down
5 changes: 5 additions & 0 deletions src/torchmetrics/functional/detection/diou.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def _diou_update(

from torchvision.ops import distance_box_iou

if preds.numel() == 0: # if no boxes are predicted
return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32)
if target.numel() == 0: # if no boxes are true
return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32)

iou = distance_box_iou(preds, target)
if iou_threshold is not None:
iou[iou < iou_threshold] = replacement_val
Expand Down
5 changes: 5 additions & 0 deletions src/torchmetrics/functional/detection/giou.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def _giou_update(

from torchvision.ops import generalized_box_iou

if preds.numel() == 0: # if no boxes are predicted
return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32)
if target.numel() == 0: # if no boxes are true
return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32)

iou = generalized_box_iou(preds, target)
if iou_threshold is not None:
iou[iou < iou_threshold] = replacement_val
Expand Down
5 changes: 5 additions & 0 deletions src/torchmetrics/functional/detection/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def _iou_update(

from torchvision.ops import box_iou

if preds.numel() == 0: # if no boxes are predicted
return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32)
if target.numel() == 0: # if no boxes are true
return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32)

iou = box_iou(preds, target)
if iou_threshold is not None:
iou[iou < iou_threshold] = replacement_val
Expand Down
Loading

0 comments on commit 9f9f28e

Please sign in to comment.