|
1 |
| -from typing import Callable, Union |
2 |
| - |
3 |
| -import torch |
4 |
| - |
5 |
| -from ignite.metrics import EpochMetric |
6 |
| - |
7 |
| - |
8 |
def average_precision_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float:
    """Compute average precision over accumulated predictions and targets.

    Args:
        y_preds: accumulated prediction scores (probabilities or confidences).
        y_targets: accumulated ground-truth labels (0s and 1s).

    Returns:
        the average precision score as computed by scikit-learn.
    """
    # Imported lazily so the module stays importable without scikit-learn;
    # the class constructor verifies availability up front.
    from sklearn.metrics import average_precision_score

    # sklearn operates on numpy arrays, so move tensors to host memory first.
    return average_precision_score(y_targets.cpu().numpy(), y_preds.cpu().numpy())
14 |
| - |
15 |
| - |
16 |
class AveragePrecision(EpochMetric):
    """Computes Average Precision accumulating predictions and the ground-truth during an epoch
    and applying `sklearn.metrics.average_precision_score <https://scikit-learn.org/stable/modules/generated/
    sklearn.metrics.average_precision_score.html#sklearn.metrics.average_precision_score>`_ .

    Args:
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
        check_compute_fn: Default False. If True, `average_precision_score
            <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html
            #sklearn.metrics.average_precision_score>`_ is run on the first batch of data to ensure there are
            no issues. User will be warned in case there are any issues computing the function.
        device: optional device specification for internal storage.

    Note:
        AveragePrecision expects y to be comprised of 0's and 1's. y_pred must either be probability estimates or
        confidence values. To apply an activation to y_pred, use output_transform as shown below:

        .. code-block:: python

            def activated_output_transform(output):
                y_pred, y = output
                y_pred = torch.softmax(y_pred, dim=1)
                return y_pred, y
            avg_precision = AveragePrecision(activated_output_transform)

    Examples:

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            y_pred = torch.tensor([[0.79, 0.21], [0.30, 0.70], [0.46, 0.54], [0.16, 0.84]])
            y_true = torch.tensor([[1, 1], [1, 1], [0, 1], [0, 1]])

            avg_precision = AveragePrecision()
            avg_precision.attach(default_evaluator, 'average_precision')
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics['average_precision'])

        .. testoutput::

            0.9166...

    """

    def __init__(
        self,
        output_transform: Callable = lambda x: x,
        check_compute_fn: bool = False,
        device: Union[str, torch.device] = torch.device("cpu"),
    ):
        # Fail fast at construction time if scikit-learn is missing, rather
        # than at the end of the first epoch when compute() runs.
        try:
            from sklearn.metrics import average_precision_score  # noqa: F401
        except ImportError:
            raise ModuleNotFoundError("This contrib module requires scikit-learn to be installed.")

        super().__init__(
            average_precision_compute_fn,
            output_transform=output_transform,
            check_compute_fn=check_compute_fn,
            device=device,
        )
"""``ignite.contrib.metrics.average_precision`` was moved to ``ignite.metrics.average_precision``.

Note:
    Please refer to :mod:`~ignite.metrics.average_precision`.
"""

import warnings

# Version in which this deprecated alias module is scheduled for removal;
# set to "" / None to omit the version from the warning text.
removed_in = "0.6.0"
deprecation_warning = (
    f"{__file__} has been moved to /ignite/metrics/average_precision.py"
    + (f" and will be removed in version {removed_in}" if removed_in else "")
    + ".\n Please refer to the documentation for more details."
)
# Emit the warning before the re-export so users see the deprecation notice
# as soon as this module is loaded; the import is deliberately placed below.
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
from ignite.metrics.average_precision import AveragePrecision  # noqa: E402

__all__ = [
    "AveragePrecision",
]
0 commit comments