Added adaptive_k argument to IR Precision metric #910

Merged
10 commits merged on Mar 28, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -44,6 +44,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added negative `ignore_index` for the Accuracy metric ([#362](https://github.com/PyTorchLightning/metrics/pull/362))


- Added `adaptive_k` for the `RetrievalPrecision` metric ([#910](https://github.com/PyTorchLightning/metrics/pull/910))


### Changed

- Made `num_classes` in `jaccard_index` a required argument ([#853](https://github.com/PyTorchLightning/metrics/pull/853))
19 changes: 19 additions & 0 deletions tests/retrieval/helpers.py
@@ -27,6 +27,7 @@
from tests.retrieval.inputs import _input_retrieval_scores_empty as _irs_empty
from tests.retrieval.inputs import _input_retrieval_scores_extra as _irs_extra
from tests.retrieval.inputs import _input_retrieval_scores_float_target as _irs_float_tgt
from tests.retrieval.inputs import _input_retrieval_scores_for_adaptive_k as _irs_adpt_k
from tests.retrieval.inputs import _input_retrieval_scores_int_target as _irs_int_tgt
from tests.retrieval.inputs import _input_retrieval_scores_mismatching_sizes as _irs_mis_sz
from tests.retrieval.inputs import _input_retrieval_scores_mismatching_sizes_func as _irs_mis_sz_fn
@@ -162,6 +163,14 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict:
],
)

_errors_test_functional_metric_parameters_adaptive_k = dict(
argnames="preds,target,message,metric_args",
argvalues=[
(_irs.preds, _irs.target, "`adaptive_k` has to be a boolean", dict(adaptive_k=10)),
(_irs.preds, _irs.target, "`adaptive_k` has to be a boolean", dict(adaptive_k=None)),
],
)

_errors_test_class_metric_parameters_no_pos_target = dict(
argnames="indexes,preds,target,message,metric_args",
argvalues=[
@@ -302,6 +311,15 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict:
argnames="indexes,preds,target,message,metric_args",
argvalues=[
(_irs.index, _irs.preds, _irs.target, "`k` has to be a positive integer or None", dict(k=-10)),
(_irs.index, _irs.preds, _irs.target, "`k` has to be a positive integer or None", dict(k=4.0)),
],
)

_errors_test_class_metric_parameters_adaptive_k = dict(
argnames="indexes,preds,target,message,metric_args",
argvalues=[
(_irs.index, _irs.preds, _irs.target, "`adaptive_k` has to be a boolean", dict(adaptive_k=10)),
(_irs.index, _irs.preds, _irs.target, "`adaptive_k` has to be a boolean", dict(adaptive_k=None)),
],
)

@@ -311,6 +329,7 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict:
(_irs.indexes, _irs.preds, _irs.target),
(_irs_extra.indexes, _irs_extra.preds, _irs_extra.target),
(_irs_no_tgt.indexes, _irs_no_tgt.preds, _irs_no_tgt.target),
(_irs_adpt_k.indexes, _irs_adpt_k.preds, _irs_adpt_k.target),
],
)

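For context, these parameter dicts are unpacked straight into `pytest.mark.parametrize`, and `_concat_tests` (whose body is not shown in this diff) presumably merges the `argvalues` of dicts that share the same `argnames`. A minimal sketch of that pattern, with a hypothetical merge helper standing in for the real `_concat_tests`:

```python
import pytest


# Hypothetical stand-in for tests/retrieval/helpers.py::_concat_tests -- the real helper
# is not shown here; the assumption is that it merges the ``argvalues`` of several
# parameter dicts that share identical ``argnames``.
def _concat_tests_sketch(*tests: dict) -> dict:
    assert len({t["argnames"] for t in tests}) == 1, "all dicts must share the same argnames"
    return dict(
        argnames=tests[0]["argnames"],
        argvalues=[case for t in tests for case in t["argvalues"]],
    )


_errors_a = dict(argnames="x,expected", argvalues=[(1, 2)])
_errors_b = dict(argnames="x,expected", argvalues=[(3, 4)])


# The dicts are expanded with ** into the decorator, exactly as test_precision.py does.
@pytest.mark.parametrize(**_concat_tests_sketch(_errors_a, _errors_b))
def test_sketch(x, expected):
    assert x + 1 == expected
```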
6 changes: 6 additions & 0 deletions tests/retrieval/inputs.py
@@ -26,6 +26,12 @@
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)),
)

_input_retrieval_scores_for_adaptive_k = Input(
indexes=torch.randint(high=NUM_BATCHES * BATCH_SIZE // 2, size=(NUM_BATCHES, BATCH_SIZE)),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)),
)

_input_retrieval_scores_extra = Input(
indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM),
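A note on why this fixture exercises `adaptive_k`: drawing `indexes` from a range as large as `NUM_BATCHES * BATCH_SIZE // 2` spreads the samples over many query ids, so most queries end up with only a few documents, which is exactly the case where `k` needs clamping. A rough standalone sketch of that effect (the constants below are placeholders, not the values from the test helpers):

```python
import torch

# Placeholder sizes, standing in for the constants defined in the test helpers.
NUM_BATCHES, BATCH_SIZE = 4, 32

torch.manual_seed(0)
indexes = torch.randint(high=NUM_BATCHES * BATCH_SIZE // 2, size=(NUM_BATCHES, BATCH_SIZE))

# Count how many documents each query id ends up with.
_, counts = indexes.flatten().unique(return_counts=True)
print("docs per query: min =", counts.min().item(), "max =", counts.max().item())
# With this many possible query ids, most queries hold fewer documents than k=4 or k=10,
# which is exactly the situation adaptive_k is meant to handle.
```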
22 changes: 17 additions & 5 deletions tests/retrieval/test_precision.py
@@ -22,9 +22,11 @@
_default_metric_class_input_arguments,
_default_metric_class_input_arguments_ignore_index,
_default_metric_functional_input_arguments,
_errors_test_class_metric_parameters_adaptive_k,
_errors_test_class_metric_parameters_default,
_errors_test_class_metric_parameters_k,
_errors_test_class_metric_parameters_no_pos_target,
_errors_test_functional_metric_parameters_adaptive_k,
_errors_test_functional_metric_parameters_default,
_errors_test_functional_metric_parameters_k,
)
@@ -34,7 +36,7 @@
seed_all(42)


def _precision_at_k(target: np.ndarray, preds: np.ndarray, k: int = None):
def _precision_at_k(target: np.ndarray, preds: np.ndarray, k: int = None, adaptive_k: bool = False):
"""Didn't find a reliable implementation of Precision in Information Retrieval, so, reimplementing here.

A good explanation can be found
@@ -43,7 +45,7 @@ def _precision_at_k(target: np.ndarray, preds: np.ndarray, k: int = None):
assert target.shape == preds.shape
assert len(target.shape) == 1 # works only with single dimension inputs

if k is None:
if k is None or (adaptive_k and k > len(preds)):
k = len(preds)

if target.sum() > 0:
@@ -59,6 +61,7 @@ class TestPrecision(RetrievalMetricTester):
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@pytest.mark.parametrize("adaptive_k", [False, True])
@pytest.mark.parametrize(**_default_metric_class_input_arguments)
def test_class_metric(
self,
@@ -70,8 +73,11 @@ def test_class_metric(
empty_target_action: str,
ignore_index: int,
k: int,
adaptive_k: bool,
):
metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=ignore_index)
metric_args = dict(
empty_target_action=empty_target_action, k=k, ignore_index=ignore_index, adaptive_k=adaptive_k
)

self.run_class_metric_test(
ddp=ddp,
@@ -88,6 +94,7 @@ def test_class_metric_ignore_index(
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@pytest.mark.parametrize("adaptive_k", [False, True])
@pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index)
def test_class_metric_ignore_index(
self,
@@ -98,8 +105,9 @@ def test_class_metric_ignore_index(
dist_sync_on_step: bool,
empty_target_action: str,
k: int,
adaptive_k: bool,
):
metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=-100)
metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=-100, adaptive_k=adaptive_k)

self.run_class_metric_test(
ddp=ddp,
@@ -114,14 +122,16 @@ def test_class_metric_ignore_index(

@pytest.mark.parametrize(**_default_metric_functional_input_arguments)
@pytest.mark.parametrize("k", [None, 1, 4, 10])
def test_functional_metric(self, preds: Tensor, target: Tensor, k: int):
@pytest.mark.parametrize("adaptive_k", [False, True])
def test_functional_metric(self, preds: Tensor, target: Tensor, k: int, adaptive_k: bool):
self.run_functional_metric_test(
preds=preds,
target=target,
metric_functional=retrieval_precision,
sk_metric=_precision_at_k,
metric_args={},
k=k,
adaptive_k=adaptive_k,
)

@pytest.mark.parametrize(**_default_metric_class_input_arguments)
@@ -149,6 +159,7 @@ def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
_errors_test_class_metric_parameters_default,
_errors_test_class_metric_parameters_no_pos_target,
_errors_test_class_metric_parameters_k,
_errors_test_class_metric_parameters_adaptive_k,
)
)
def test_arguments_class_metric(
@@ -169,6 +180,7 @@ def test_arguments_class_metric(
**_concat_tests(
_errors_test_functional_metric_parameters_default,
_errors_test_functional_metric_parameters_k,
_errors_test_functional_metric_parameters_adaptive_k,
)
)
def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, message: str, metric_args: dict):
2 changes: 1 addition & 1 deletion torchmetrics/functional/retrieval/fall_out.py
@@ -53,7 +53,7 @@ def retrieval_fall_out(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
if not (isinstance(k, int) and k > 0):
raise ValueError("`k` has to be a positive integer or None")

target = 1 - target
target = 1 - target # we want to compute the probability of getting a non-relevant doc among all non-relevant docs

if not target.sum():
return tensor(0.0, device=preds.device)
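For readers unfamiliar with fall-out, the new comment is the whole trick: complementing `target` turns "relevant" labels into "non-relevant" ones, after which the metric is just the share of all non-relevant documents that show up in the top `k`. A small hand-rolled sketch of that idea (not the library implementation):

```python
import torch


def fall_out_sketch(preds: torch.Tensor, target: torch.Tensor, k: int) -> torch.Tensor:
    """Fraction of all non-relevant documents that appear among the top-k predictions."""
    non_relevant = 1 - target.long()  # flip labels so that 1 marks a non-relevant doc
    if non_relevant.sum() == 0:
        return torch.tensor(0.0)
    top_k = preds.topk(min(k, preds.numel())).indices
    return non_relevant[top_k].sum().float() / non_relevant.sum()


preds = torch.tensor([0.2, 0.3, 0.5])
target = torch.tensor([True, False, True])
print(fall_out_sketch(preds, target, k=2))  # the single non-relevant doc is retrieved -> 1.0
```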
12 changes: 9 additions & 3 deletions torchmetrics/functional/retrieval/precision.py
@@ -18,7 +18,7 @@
from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def retrieval_precision(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
def retrieval_precision(preds: Tensor, target: Tensor, k: Optional[int] = None, adaptive_k: bool = False) -> Tensor:
"""Computes the precision metric (for information retrieval). Precision is the fraction of relevant documents
among all the retrieved documents.

@@ -30,13 +30,16 @@ def retrieval_precision(preds: Tensor, target: Tensor, k: Optional[int] = None)
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant or not.
k: consider only the top k elements (default: `None`, which considers them all)
adaptive_k: adjust `k` to `min(k, number of documents)` for each query

Returns:
a single-value tensor with the precision (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``.

Raises:
ValueError:
If ``k`` parameter is not `None` or an integer larger than 0
If ``k`` is not `None` or an integer larger than 0.
ValueError:
If ``adaptive_k`` is not boolean.

Example:
>>> preds = tensor([0.2, 0.3, 0.5])
@@ -46,7 +49,10 @@ def retrieval_precision(preds: Tensor, target: Tensor, k: Optional[int] = None)
"""
preds, target = _check_retrieval_functional_inputs(preds, target)

if k is None:
if not isinstance(adaptive_k, bool):
raise ValueError("`adaptive_k` has to be a boolean")

if k is None or (adaptive_k and k > preds.shape[-1]):
k = preds.shape[-1]

if not (isinstance(k, int) and k > 0):
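Putting the new argument together: when `k` exceeds the number of documents in a query, clamping `k` changes the denominator, which is the whole point of `adaptive_k`. A small usage sketch against the docstring values (the exact numbers follow the reference `_precision_at_k` used in the tests and are illustrative):

```python
import torch
from torchmetrics.functional import retrieval_precision

preds = torch.tensor([0.2, 0.3, 0.5])
target = torch.tensor([True, False, True])

# k is larger than the number of documents: both relevant docs are retrieved either way,
# but the denominator changes once k is clamped.
print(retrieval_precision(preds, target, k=10))                   # 2 / 10 = 0.2
print(retrieval_precision(preds, target, k=10, adaptive_k=True))  # k clamped to 3 -> 2 / 3
```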
11 changes: 9 additions & 2 deletions torchmetrics/retrieval/precision.py
@@ -47,6 +47,7 @@ class RetrievalPrecision(RetrievalMetric):
ignore_index:
Ignore predictions where the target is equal to this number.
k: consider only the top k elements for each query (default: `None`, which considers them all)
adaptive_k: adjust `k` to `min(k, number of documents)` for each query
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.

@@ -62,7 +63,9 @@ class RetrievalPrecision(RetrievalMetric):
ValueError:
If ``ignore_index`` is not `None` or an integer.
ValueError:
If ``k`` parameter is not `None` or an integer larger than 0.
If ``k`` is not `None` or an integer larger than 0.
ValueError:
If ``adaptive_k`` is not boolean.

Example:
>>> from torchmetrics import RetrievalPrecision
@@ -81,6 +84,7 @@ def __init__(
empty_target_action: str = "neg",
ignore_index: Optional[int] = None,
k: Optional[int] = None,
adaptive_k: bool = False,
compute_on_step: Optional[bool] = None,
**kwargs: Dict[str, Any],
) -> None:
@@ -93,7 +97,10 @@

if (k is not None) and not (isinstance(k, int) and k > 0):
raise ValueError("`k` has to be a positive integer or None")
if not isinstance(adaptive_k, bool):
raise ValueError("`adaptive_k` has to be a boolean")
self.k = k
self.adaptive_k = adaptive_k

def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
return retrieval_precision(preds, target, k=self.k)
return retrieval_precision(preds, target, k=self.k, adaptive_k=self.adaptive_k)
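And the same behaviour through the module interface, reusing the grouping example from the class docstring; this is a sketch rather than a doctest, so treat the printed value as approximate:

```python
import torch
from torchmetrics import RetrievalPrecision

indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
target = torch.tensor([False, False, True, False, True, False, True])

# Query 0 holds only 3 documents, so adaptive_k clamps its k from 4 down to 3.
p4 = RetrievalPrecision(k=4, adaptive_k=True)
print(p4(preds, target, indexes=indexes))  # mean of 1/3 and 2/4, roughly 0.4167
```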