Skip to content

Commit

Permalink
Refactor Information Retrieval tests (#156)
Browse files Browse the repository at this point in the history
* init transition to standard metric interface for IR metrics

* fixed typo in dtypes checks

* removed IGNORE_IDX, refactored tests using

* added pep8 compatibility

* fixed np.ndarray to np.array

* remove lambda functions

* fixed typos with numpy dtype

* fixed typo in doc example

* fixed typo in doc examples about new indexes position

* added parameter to class testing to divide kwargs as preds and targets. Fixed typo in doc format

* fixed typo in doc example

* fixed typo with new parameter fragment_kwargs in MetricTester

* fixed typo in .cpu() conversion of non-tensor values

* improved test coverage

* improved test coverage

* added check on Tensor class to avoid calling .cpu() on non-tensor values

* improved doc and changed default values for 'empty_target_action' argument

* refactored tests lists

* formatting

* simple

* args

* format

* _sk

* fixed typo in tests

Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
  • Loading branch information
lucadiliello and Borda authored Apr 6, 2021
1 parent 6b95f87 commit 89acd4e
Show file tree
Hide file tree
Showing 22 changed files with 1,023 additions and 417 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ ENV/
env.bak/
venv.bak/

# Editor configs
.vscode/

# Spyder project settings
.spyderproject
.spyproject
Expand Down
6 changes: 3 additions & 3 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -310,14 +310,14 @@ the set of pairs ``(Q_i, D_j)`` having the same query ``Q_i``.
>>> target = torch.tensor([0, 1, 0, 1, 1])

>>> map = RetrievalMAP() # or some other retrieval metric
>>> map(indexes, preds, target)
>>> map(preds, target, indexes=indexes)
tensor(0.6667)

>>> # the previous instruction is roughly equivalent to
>>> res = []
>>> # iterate over indexes of first and second query
>>> for idx in ([0, 1], [2, 3, 4]):
... res.append(retrieval_average_precision(preds[idx], target[idx]))
>>> for indexes in ([0, 1], [2, 3, 4]):
... res.append(retrieval_average_precision(preds[indexes], target[indexes]))
>>> torch.stack(res).mean()
tensor(0.6667)

Expand Down
150 changes: 0 additions & 150 deletions tests/functional/test_retrieval.py

This file was deleted.

44 changes: 37 additions & 7 deletions tests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def _class_test(
check_batch: bool = True,
atol: float = 1e-8,
device: str = 'cpu',
fragment_kwargs: bool = False,
**kwargs_update: Any,
):
"""Utility function doing the actual comparison between lightning class metric
Expand All @@ -102,6 +103,7 @@ def _class_test(
check_batch: bool, if true will check if the metric is also correctly
calculated across devices for each batch (and not just at the end)
device: determine which device to run on, either 'cuda' or 'cpu'
fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `target` among processes
kwargs_update: Additional keyword arguments that will be passed with preds and
target when running update on the metric.
"""
Expand Down Expand Up @@ -133,14 +135,17 @@ def _class_test(
ddp_target = torch.cat([target[i + r] for r in range(worldsize)]).cpu()
ddp_kwargs_upd = {
k: torch.cat([v[i + r] for r in range(worldsize)]).cpu() if isinstance(v, Tensor) else v
for k, v in batch_kwargs_update.items()
for k, v in (kwargs_update if fragment_kwargs else batch_kwargs_update).items()
}

sk_batch_result = sk_metric(ddp_preds, ddp_target, **ddp_kwargs_upd)
_assert_allclose(batch_result, sk_batch_result, atol=atol)

elif check_batch and not metric.dist_sync_on_step:
batch_kwargs_update = {k: v.cpu() for k, v in kwargs_update.items()}
batch_kwargs_update = {
k: v.cpu() if isinstance(v, Tensor) else v
for k, v in (batch_kwargs_update if fragment_kwargs else kwargs_update).items()
}
sk_batch_result = sk_metric(preds[i].cpu(), target[i].cpu(), **batch_kwargs_update)
_assert_allclose(batch_result, sk_batch_result, atol=atol)

Expand Down Expand Up @@ -168,6 +173,7 @@ def _functional_test(
metric_args: dict = None,
atol: float = 1e-8,
device: str = 'cpu',
fragment_kwargs: bool = False,
**kwargs_update,
):
"""Utility function doing the actual comparison between lightning functional metric
Expand All @@ -180,6 +186,7 @@ def _functional_test(
sk_metric: callable function that is used for comparison
metric_args: dict with additional arguments used for class initialization
device: determine which device to run on, either 'cuda' or 'cpu'
fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `target` among processes
kwargs_update: Additional keyword arguments that will be passed with preds and
target when running update on the metric.
"""
Expand All @@ -196,7 +203,10 @@ def _functional_test(
for i in range(NUM_BATCHES):
extra_kwargs = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}
lightning_result = metric(preds[i], target[i], **extra_kwargs)
extra_kwargs = {k: v.cpu() for k, v in kwargs_update.items()}
extra_kwargs = {
k: v.cpu() if isinstance(v, Tensor) else v
for k, v in (extra_kwargs if fragment_kwargs else kwargs_update).items()
}
sk_result = sk_metric(preds[i].cpu(), target[i].cpu(), **extra_kwargs)

# assert its the same
Expand All @@ -209,6 +219,7 @@ def _assert_half_support(
preds: torch.Tensor,
target: torch.Tensor,
device: str = "cpu",
**kwargs_update
):
"""
Test if an metric can be used with half precision tensors
Expand All @@ -219,12 +230,18 @@ def _assert_half_support(
preds: torch tensor with predictions
target: torch tensor with targets
device: determine device, either "cpu" or "cuda"
kwargs_update: Additional keyword arguments that will be passed with preds and
target when running update on the metric.
"""
y_hat = preds[0].half().to(device) if preds[0].is_floating_point() else preds[0].to(device)
y = target[0].half().to(device) if target[0].is_floating_point() else target[0].to(device)
kwargs_update = {
k: (v[0].half() if v.is_floating_point() else v[0]).to(device) if isinstance(v, Tensor) else v
for k, v in kwargs_update.items()
}
metric_module = metric_module.to(device)
_assert_tensor(metric_module(y_hat, y))
_assert_tensor(metric_functional(y_hat, y))
_assert_tensor(metric_module(y_hat, y, **kwargs_update))
_assert_tensor(metric_functional(y_hat, y, **kwargs_update))


class MetricTester:
Expand Down Expand Up @@ -260,6 +277,7 @@ def run_functional_metric_test(
metric_functional: Callable,
sk_metric: Callable,
metric_args: dict = None,
fragment_kwargs: bool = False,
**kwargs_update,
):
"""Main method that should be used for testing functions. Call this inside
Expand All @@ -271,6 +289,7 @@ def run_functional_metric_test(
metric_functional: lightning metric class that should be tested
sk_metric: callable function that is used for comparison
metric_args: dict with additional arguments used for class initialization
fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `target` among processes
kwargs_update: Additional keyword arguments that will be passed with preds and
target when running update on the metric.
"""
Expand All @@ -284,6 +303,7 @@ def run_functional_metric_test(
metric_args=metric_args,
atol=self.atol,
device=device,
fragment_kwargs=fragment_kwargs,
**kwargs_update,
)

Expand All @@ -298,6 +318,7 @@ def run_class_metric_test(
metric_args: dict = None,
check_dist_sync_on_step: bool = True,
check_batch: bool = True,
fragment_kwargs: bool = False,
**kwargs_update,
):
"""Main method that should be used for testing class. Call this inside testing
Expand All @@ -316,6 +337,7 @@ def run_class_metric_test(
calculated per batch per device (and not just at the end)
check_batch: bool, if true will check if the metric is also correctly
calculated across devices for each batch (and not just at the end)
fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `target` among processes
kwargs_update: Additional keyword arguments that will be passed with preds and
target when running update on the metric.
"""
Expand All @@ -337,6 +359,7 @@ def run_class_metric_test(
check_dist_sync_on_step=check_dist_sync_on_step,
check_batch=check_batch,
atol=self.atol,
fragment_kwargs=fragment_kwargs,
**kwargs_update,
),
[(rank, self.poolSize) for rank in range(self.poolSize)],
Expand All @@ -357,6 +380,7 @@ def run_class_metric_test(
check_batch=check_batch,
atol=self.atol,
device=device,
fragment_kwargs=fragment_kwargs,
**kwargs_update,
)

Expand All @@ -367,6 +391,7 @@ def run_precision_test_cpu(
metric_module: Metric,
metric_functional: Callable,
metric_args: dict = {},
**kwargs_update,
):
"""Test if an metric can be used with half precision tensors on cpu
Args:
Expand All @@ -375,9 +400,11 @@ def run_precision_test_cpu(
metric_module: the metric module to test
metric_functional: the metric functional to test
metric_args: dict with additional arguments used for class initialization
kwargs_update: Additional keyword arguments that will be passed with preds and
target when running update on the metric.
"""
_assert_half_support(
metric_module(**metric_args), partial(metric_functional, **metric_args), preds, target, device="cpu"
metric_module(**metric_args), metric_functional, preds, target, device="cpu", **kwargs_update
)

def run_precision_test_gpu(
Expand All @@ -387,6 +414,7 @@ def run_precision_test_gpu(
metric_module: Metric,
metric_functional: Callable,
metric_args: dict = {},
**kwargs_update,
):
"""Test if an metric can be used with half precision tensors on gpu
Args:
Expand All @@ -395,9 +423,11 @@ def run_precision_test_gpu(
metric_module: the metric module to test
metric_functional: the metric functional to test
metric_args: dict with additional arguments used for class initialization
kwargs_update: Additional keyword arguments that will be passed with preds and
target when running update on the metric.
"""
_assert_half_support(
metric_module(**metric_args), partial(metric_functional, **metric_args), preds, target, device="cuda"
metric_module(**metric_args), metric_functional, preds, target, device="cuda", **kwargs_update
)


Expand Down
6 changes: 3 additions & 3 deletions tests/regression/test_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
))


def _sk_metric(preds, target, data_range, multichannel):
def _sk_ssim(preds, target, data_range, multichannel):
c, h, w = preds.shape[-3:]
sk_preds = preds.view(-1, c, h, w).permute(0, 2, 3, 1).numpy()
sk_target = target.view(-1, c, h, w).permute(0, 2, 3, 1).numpy()
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_ssim(self, preds, target, multichannel, ddp, dist_sync_on_step):
preds,
target,
SSIM,
partial(_sk_metric, data_range=1.0, multichannel=multichannel),
partial(_sk_ssim, data_range=1.0, multichannel=multichannel),
metric_args={"data_range": 1.0},
dist_sync_on_step=dist_sync_on_step,
)
Expand All @@ -87,7 +87,7 @@ def test_ssim_functional(self, preds, target, multichannel):
preds,
target,
ssim,
partial(_sk_metric, data_range=1.0, multichannel=multichannel),
partial(_sk_ssim, data_range=1.0, multichannel=multichannel),
metric_args={"data_range": 1.0},
)

Expand Down
Loading

0 comments on commit 89acd4e

Please sign in to comment.