Commit ecb6f2e

Merge branch 'master' into bugfix/map_with_custom_threshold

Borda authored Apr 29, 2022
2 parents 0f20366 + 3a141ae
Showing 35 changed files with 717 additions and 108 deletions.
CHANGELOG.md (4 changes: 3 additions & 1 deletion)
@@ -11,7 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

--
+- Added `RetrievalPrecisionRecallCurve` to retrieval package ([#951](https://github.com/PyTorchLightning/metrics/pull/951))
+
+- Added `RetrievalRecallAtFixedPrecision` to retrieval package ([#951](https://github.com/PyTorchLightning/metrics/pull/951))


-
docs/source/retrieval/precision_recall_curve.rst (22 changes: 22 additions & 0 deletions)
@@ -0,0 +1,22 @@
.. customcarditem::
   :header: Precision Recall Curve
   :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/text_classification.svg
   :tags: Retrieval

.. include:: ../links.rst

######################
Precision Recall Curve
######################

Module Interface
________________

.. autoclass:: torchmetrics.RetrievalPrecisionRecallCurve
    :noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.retrieval_precision_recall_curve
    :noindex:
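
A minimal usage sketch of the newly documented metric follows. Since this page only shows the docs stub, the details are assumptions: like the other torchmetrics retrieval metrics, `update` should take an `indexes` tensor that groups predictions into queries, and `compute` should return precisions, recalls and the matching top-k values; the generated API reference is authoritative.

import torch
from torchmetrics import RetrievalPrecisionRecallCurve

indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])  # query id of each prediction
preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])  # predicted relevance scores
target = torch.tensor([False, False, True, False, True, False, True])  # true relevance

metric = RetrievalPrecisionRecallCurve(max_k=4)
metric.update(preds, target, indexes=indexes)
precisions, recalls, top_k = metric.compute()  # one (precision, recall) point per k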
tests/bases/test_metric.py (7 changes: 4 additions & 3 deletions)
@@ -20,7 +20,8 @@
import psutil
import pytest
import torch
-from torch import Tensor, nn, tensor
+from torch import Tensor, tensor
+from torch.nn import Module

from tests.helpers import seed_all
from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
@@ -245,7 +246,7 @@ def test_load_state_dict(tmpdir):
def test_child_metric_state_dict():
    """test that child metric states will be added to parent state dict."""

-    class TestModule(nn.Module):
+    class TestModule(Module):
        def __init__(self):
            super().__init__()
            self.metric = DummyMetric()
@@ -346,7 +347,7 @@ def test_forward_and_compute_to_device(metric_class):
def test_device_if_child_module(metric_class):
    """Test that if a metric is a child module all values get moved to the correct device."""

-    class TestModule(nn.Module):
+    class TestModule(Module):
        def __init__(self):
            super().__init__()
            self.metric = metric_class()
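The `TestModule` pattern above checks that a metric registered as a child module contributes its states to the parent's state_dict. A condensed sketch of that behaviour; the `DummySum` class below is illustrative, not part of this diff, and note that only states added with `persistent=True` are guaranteed to land in the state dict:

from torch import tensor
from torch.nn import Module
from torchmetrics import Metric

class DummySum(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("x", tensor(0.0), dist_reduce_fx="sum", persistent=True)

    def update(self, value):
        self.x += value

    def compute(self):
        return self.x

class Wrapper(Module):
    def __init__(self):
        super().__init__()
        self.metric = DummySum()

print(list(Wrapper().state_dict()))  # expect the metric state under the "metric." prefix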
tests/detection/test_map.py (67 changes: 34 additions & 33 deletions)
@@ -16,6 +16,7 @@

import pytest
import torch
+from torch import Tensor

from tests.helpers.testers import MetricTester
from torchmetrics.detection.mean_ap import MeanAveragePrecision
@@ -27,19 +28,19 @@
    preds=[
        [
            dict(
-                boxes=torch.Tensor([[258.15, 41.29, 606.41, 285.07]]),
-                scores=torch.Tensor([0.236]),
+                boxes=Tensor([[258.15, 41.29, 606.41, 285.07]]),
+                scores=Tensor([0.236]),
                labels=torch.IntTensor([4]),
            ),  # coco image id 42
            dict(
-                boxes=torch.Tensor([[61.00, 22.75, 565.00, 632.42], [12.66, 3.32, 281.26, 275.23]]),
-                scores=torch.Tensor([0.318, 0.726]),
+                boxes=Tensor([[61.00, 22.75, 565.00, 632.42], [12.66, 3.32, 281.26, 275.23]]),
+                scores=Tensor([0.318, 0.726]),
                labels=torch.IntTensor([3, 2]),
            ),  # coco image id 73
        ],
        [
            dict(
-                boxes=torch.Tensor(
+                boxes=Tensor(
                    [
                        [87.87, 276.25, 384.29, 379.43],
                        [0.00, 3.66, 142.15, 316.06],
@@ -50,24 +51,24 @@
                        [276.11, 103.84, 291.44, 150.72],
                    ]
                ),
-                scores=torch.Tensor([0.546, 0.3, 0.407, 0.611, 0.335, 0.805, 0.953]),
+                scores=Tensor([0.546, 0.3, 0.407, 0.611, 0.335, 0.805, 0.953]),
                labels=torch.IntTensor([4, 1, 0, 0, 0, 0, 0]),
            ),  # coco image id 74
            dict(
-                boxes=torch.Tensor([[0.00, 2.87, 601.00, 421.52]]),
-                scores=torch.Tensor([0.699]),
+                boxes=Tensor([[0.00, 2.87, 601.00, 421.52]]),
+                scores=Tensor([0.699]),
                labels=torch.IntTensor([5]),
            ),  # coco image id 133
        ],
    ],
    target=[
        [
            dict(
-                boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]),
+                boxes=Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]),
                labels=torch.IntTensor([4]),
            ),  # coco image id 42
            dict(
-                boxes=torch.Tensor(
+                boxes=Tensor(
                    [
                        [13.00, 22.75, 548.98, 632.42],
                        [1.66, 3.32, 270.26, 275.23],
@@ -78,7 +79,7 @@
        ],
        [
            dict(
-                boxes=torch.Tensor(
+                boxes=Tensor(
                    [
                        [61.87, 276.25, 358.29, 379.43],
                        [2.75, 3.66, 162.15, 316.06],
@@ -92,7 +93,7 @@
                labels=torch.IntTensor([4, 1, 0, 0, 0, 0, 0]),
            ),  # coco image id 74
            dict(
-                boxes=torch.Tensor([[13.99, 2.87, 640.00, 421.52]]),
+                boxes=Tensor([[13.99, 2.87, 640.00, 421.52]]),
                labels=torch.IntTensor([5]),
            ),  # coco image id 133
        ],
@@ -104,29 +105,29 @@
    preds=[
        [
            dict(
-                boxes=torch.Tensor([[258.0, 41.0, 606.0, 285.0]]),
-                scores=torch.Tensor([0.536]),
+                boxes=Tensor([[258.0, 41.0, 606.0, 285.0]]),
+                scores=Tensor([0.536]),
                labels=torch.IntTensor([0]),
            ),
        ],
        [
            dict(
-                boxes=torch.Tensor([[258.0, 41.0, 606.0, 285.0]]),
-                scores=torch.Tensor([0.536]),
+                boxes=Tensor([[258.0, 41.0, 606.0, 285.0]]),
+                scores=Tensor([0.536]),
                labels=torch.IntTensor([0]),
            )
        ],
    ],
    target=[
        [
            dict(
-                boxes=torch.Tensor([[214.0, 41.0, 562.0, 285.0]]),
+                boxes=Tensor([[214.0, 41.0, 562.0, 285.0]]),
                labels=torch.IntTensor([0]),
            )
        ],
        [
            dict(
-                boxes=torch.Tensor([]),
+                boxes=Tensor([]),
                labels=torch.IntTensor([]),
            )
        ],
@@ -196,20 +197,20 @@ def _compare_fn(preds, target) -> dict:
    Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.900
    """
    return {
-        "map": torch.Tensor([0.706]),
-        "map_50": torch.Tensor([0.901]),
-        "map_75": torch.Tensor([0.846]),
-        "map_small": torch.Tensor([0.689]),
-        "map_medium": torch.Tensor([0.800]),
-        "map_large": torch.Tensor([0.701]),
-        "mar_1": torch.Tensor([0.592]),
-        "mar_10": torch.Tensor([0.716]),
-        "mar_100": torch.Tensor([0.716]),
-        "mar_small": torch.Tensor([0.767]),
-        "mar_medium": torch.Tensor([0.800]),
-        "mar_large": torch.Tensor([0.700]),
-        "map_per_class": torch.Tensor([0.725, 0.800, 0.454, -1.000, 0.650, 0.900]),
-        "mar_100_per_class": torch.Tensor([0.780, 0.800, 0.450, -1.000, 0.650, 0.900]),
+        "map": Tensor([0.706]),
+        "map_50": Tensor([0.901]),
+        "map_75": Tensor([0.846]),
+        "map_small": Tensor([0.689]),
+        "map_medium": Tensor([0.800]),
+        "map_large": Tensor([0.701]),
+        "mar_1": Tensor([0.592]),
+        "mar_10": Tensor([0.716]),
+        "mar_100": Tensor([0.716]),
+        "mar_small": Tensor([0.767]),
+        "mar_medium": Tensor([0.800]),
+        "mar_large": Tensor([0.700]),
+        "map_per_class": Tensor([0.725, 0.800, 0.454, -1.000, 0.650, 0.900]),
+        "mar_100_per_class": Tensor([0.780, 0.800, 0.450, -1.000, 0.650, 0.900]),
    }


@@ -260,7 +261,7 @@ def test_empty_preds():

    metric.update(
        [
-            dict(boxes=torch.Tensor([]), scores=torch.Tensor([]), labels=torch.IntTensor([])),
+            dict(boxes=Tensor([]), scores=torch.Tensor([]), labels=torch.IntTensor([])),
        ],
        [
            dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=torch.IntTensor([4])),
dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=torch.IntTensor([4])),
Expand Down
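The fixtures above feed MeanAveragePrecision with one dict per image, holding `boxes` in absolute xyxy pixel coordinates plus `scores` and integer `labels`. A self-contained sketch of that update/compute cycle, using values taken from the fixtures (MeanAveragePrecision needed pycocotools at the time of this commit):

import torch
from torch import Tensor
from torchmetrics.detection.mean_ap import MeanAveragePrecision

preds = [
    dict(
        boxes=Tensor([[258.0, 41.0, 606.0, 285.0]]),  # xyxy, absolute pixels
        scores=Tensor([0.536]),
        labels=torch.IntTensor([0]),
    )
]
target = [
    dict(
        boxes=Tensor([[214.0, 41.0, 562.0, 285.0]]),
        labels=torch.IntTensor([0]),
    )
]

metric = MeanAveragePrecision()
metric.update(preds, target)
result = metric.compute()  # dict with "map", "map_50", "mar_100", ... as in _compare_fn
print(result["map"])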
tests/helpers/testers.py (8 changes: 4 additions & 4 deletions)
@@ -178,7 +178,7 @@ def _class_test(
        batch_result = metric(preds[i], target[i], **batch_kwargs_update)

        if metric.dist_sync_on_step and check_dist_sync_on_step and rank == 0:
-            if isinstance(preds, torch.Tensor):
+            if isinstance(preds, Tensor):
                ddp_preds = torch.cat([preds[i + r] for r in range(worldsize)]).cpu()
                ddp_target = torch.cat([target[i + r] for r in range(worldsize)]).cpu()
            else:
@@ -201,8 +201,8 @@
                k: v.cpu() if isinstance(v, Tensor) else v
                for k, v in (batch_kwargs_update if fragment_kwargs else kwargs_update).items()
            }
-            preds_ = preds[i].cpu() if isinstance(preds, torch.Tensor) else preds[i]
-            target_ = target[i].cpu() if isinstance(target, torch.Tensor) else target[i]
+            preds_ = preds[i].cpu() if isinstance(preds, Tensor) else preds[i]
+            target_ = target[i].cpu() if isinstance(target, Tensor) else target[i]
            sk_batch_result = sk_metric(preds_, target_, **batch_kwargs_update)
            if isinstance(batch_result, dict):
                for key in batch_result.keys():
@@ -221,7 +221,7 @@
    else:
        _assert_tensor(result)

-    if isinstance(preds, torch.Tensor):
+    if isinstance(preds, Tensor):
        total_preds = torch.cat([preds[i] for i in range(num_batches)]).cpu()
        total_target = torch.cat([target[i] for i in range(num_batches)]).cpu()
    else:
tests/image/test_fid.py (3 changes: 2 additions & 1 deletion)
@@ -16,6 +16,7 @@
import pytest
import torch
from scipy.linalg import sqrtm as scipy_sqrtm
+from torch.nn import Module
from torch.utils.data import Dataset

from torchmetrics.image.fid import FrechetInceptionDistance, sqrtm
@@ -44,7 +45,7 @@ def generate_cov(n):
def test_no_train():
    """Assert that metric never leaves evaluation mode."""

-    class MyModel(torch.nn.Module):
+    class MyModel(Module):
        def __init__(self):
            super().__init__()
            self.metric = FrechetInceptionDistance()
tests/image/test_inception.py (3 changes: 2 additions & 1 deletion)
@@ -15,6 +15,7 @@

import pytest
import torch
+from torch.nn import Module
from torch.utils.data import Dataset

from torchmetrics.image.inception import InceptionScore
@@ -27,7 +28,7 @@
def test_no_train():
    """Assert that metric never leaves evaluation mode."""

-    class MyModel(torch.nn.Module):
+    class MyModel(Module):
        def __init__(self):
            super().__init__()
            self.metric = InceptionScore()
tests/image/test_kid.py (3 changes: 2 additions & 1 deletion)
@@ -15,6 +15,7 @@

import pytest
import torch
+from torch.nn import Module
from torch.utils.data import Dataset

from torchmetrics.image.kid import KernelInceptionDistance
@@ -27,7 +28,7 @@
def test_no_train():
    """Assert that metric never leaves evaluation mode."""

-    class MyModel(torch.nn.Module):
+    class MyModel(Module):
        def __init__(self):
            super().__init__()
            self.metric = KernelInceptionDistance()
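The three image-metric edits above share the same `test_no_train` pattern: wrap the metric in a Module, switch the wrapper to training mode, and assert that the metric's frozen feature extractor stays in evaluation mode. A condensed sketch for FID (the `inception` attribute name follows the torchmetrics implementation of the era and is an assumption here; instantiating the metric downloads weights and requires torch-fidelity):

from torch.nn import Module
from torchmetrics.image.fid import FrechetInceptionDistance

class MyModel(Module):
    def __init__(self):
        super().__init__()
        self.metric = FrechetInceptionDistance()

    def forward(self, x):
        return x

model = MyModel()
model.train()  # the wrapper enters training mode ...
assert model.training
assert not model.metric.inception.training  # ... but the metric backbone must stay in eval mode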
(Diffs for the remaining 27 of the 35 changed files are not shown.)
