From e94c4badbeb1620f92cde2c394ae9f8bca2c1ea7 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Wed, 17 Mar 2021 11:50:25 +0100
Subject: [PATCH 01/36] add bootstrapping

---
 CHANGELOG.md                           |  2 +
 docs/source/references/modules.rst     | 13 +++-
 torchmetrics/__init__.py               |  1 +
 torchmetrics/wrappers/__init__.py      | 14 ++++
 torchmetrics/wrappers/bootstrapping.py | 98 ++++++++++++++++++++++++++
 5 files changed, 127 insertions(+), 1 deletion(-)
 create mode 100644 torchmetrics/wrappers/__init__.py
 create mode 100644 torchmetrics/wrappers/bootstrapping.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 784174932a5..baed9e5d428 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added `CohenKappa` metric ([#69](https://github.com/PyTorchLightning/metrics/pull/69))
 
+- Added `BootStrapper` to easily calculate confidence intervals for metrics ([]())
+
 ### Changed
 
 - Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))
diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst
index ce4930c4931..1b85b5f0fed 100644
--- a/docs/source/references/modules.rst
+++ b/docs/source/references/modules.rst
@@ -256,4 +256,15 @@ R2Score
 ~~~~~~~
 
 .. autoclass:: torchmetrics.R2Score
-    :noindex:
\ No newline at end of file
+    :noindex:
+
+********
+Wrappers
+********
+
+Modular wrapper metrics are not metrics in themselves; instead they take in other metrics and alter
+the internal logic of the base metric.
+
+.. autoclass:: torchmetrics.BootStrapper
+    :noindex:
+
diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py
index 125af29392c..2d0e489e4f6 100644
--- a/torchmetrics/__init__.py
+++ b/torchmetrics/__init__.py
@@ -69,3 +69,4 @@
     MeanSquaredLogError,
     R2Score,
 )
+from torchmetrics.wrappers import BootStrapper  # noqa: F401
diff --git a/torchmetrics/wrappers/__init__.py b/torchmetrics/wrappers/__init__.py
new file mode 100644
index 00000000000..7e6b7d4da94
--- /dev/null
+++ b/torchmetrics/wrappers/__init__.py
@@ -0,0 +1,14 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from torchmetrics.wrappers.bootstrapping import BootStrapper  # noqa: F401
\ No newline at end of file
diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py
new file mode 100644
index 00000000000..f23d755cdeb
--- /dev/null
+++ b/torchmetrics/wrappers/bootstrapping.py
@@ -0,0 +1,98 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Union, List +from copy import deepcopy + +import torch +from torch import nn + +from torchmetrics.metric import Metric +from torchmetrics.utilities import apply_to_collection + + +def _bootstrap_sampler(tensor, size: Optional[int] = None): + """ """ + if size is None: + size = tensor.shape[0] + idx = torch.multinomial( + torch.ones(tensor.shape[0], device=tensor.device), + num_samples=size, + replacement=True + ) + return tensor[idx] + + +class BootStrapper(Metric): + def __init__(self, base_metric: Metric, + num_bootstraps: int = 10): + """ + Use to turn a metric into a bootstrapped metric that can automate the process of getting confidence + intervals for metric values. This wrapper class basically keeps multiple copies of the same base metric + in memory and whenever ``update`` or ``forward`` is called, all input tensors are resampled + (with replacement) along the first dimension. + + .. note:: Different from all other metrics, bootstrapped metrics has additional + arguments in its ``compute`` method determining what should be returned. + + Example:: + >>> from torchmetrics.wrappers import BootStrapper + >>> from torchmetrics import Accuracy + >>> _ = torch.manual_seed(0) + >>> base_metric = Accuracy() + >>> bootstrap = BootStrapper(base_metric, num_bootstraps=20) + >>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,))) + >>> output = bootstrap.compute(mean=True, std=True) + >>> mean, std = output + >>> print(mean, std) + tensor(0.4950) tensor(0.1677) + + """ + super().__init__() + self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_bootstraps)]) + self.num_bootstraps = num_bootstraps + + def update(self, *args, **kwargs): + """ Updates the state of the base metric. Any tensor passed in will be bootstrapped + along dimension 0 + """ + for idx in range(self.num_bootstraps): + args = apply_to_collection(args, torch.Tensor, _bootstrap_sampler) + kwargs = apply_to_collection(kwargs, torch.Tensor, _bootstrap_sampler) + self.metrics[idx].update(*args, **kwargs) + + def compute( + self, + mean: bool = True, + std: bool = True, + quantile: Optional[Union[float, torch.Tensor]] = None, + raw: bool = False + ) -> List[torch.Tensor]: + """ Computes the metric value. 
+        Args:
+            mean: if `True` return the mean of the bootstraps
+            std: if `True` return the standard deviation of the bootstraps
+            quantile: if given, returns the quantile of the bootstraps
+            raw: if `True`, return all bootstrapped values
+        """
+        computed_vals = torch.stack([m.compute() for m in self.metrics], dim=0)
+        output = []
+        if mean:
+            output.append(computed_vals.mean(dim=0))
+        if std:
+            output.append(computed_vals.std(dim=0))
+        if quantile is not None:
+            output.append(torch.quantile(computed_vals, quantile))
+        if raw:
+            output.append(computed_vals)
+        return output

From 96967e4dcb883d2535d094114cfe1e1d7987b0bd Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Wed, 17 Mar 2021 15:44:57 +0100
Subject: [PATCH 02/36] tests

---
 tests/helpers/testers.py               |  8 +++-
 tests/wrappers/__init__.py             |  0
 tests/wrappers/test_bootstrapping.py   | 64 ++++++++++++++++++++++++++
 torchmetrics/wrappers/bootstrapping.py | 38 +++++++++++++--
 4 files changed, 105 insertions(+), 5 deletions(-)
 create mode 100644 tests/wrappers/__init__.py
 create mode 100644 tests/wrappers/test_bootstrapping.py

diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py
index 4834edd5448..233538610d5 100644
--- a/tests/helpers/testers.py
+++ b/tests/helpers/testers.py
@@ -104,7 +104,11 @@ def _class_test(
         calculated across devices for each batch (and not just at the end)
     """
     # Instantiate lightning metric
-    metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args)
+    metric = metric_class(
+        compute_on_step=check_dist_sync_on_step or check_batch,
+        dist_sync_on_step=dist_sync_on_step,
+        **metric_args
+    )
 
     # verify metrics work after being loaded from pickled state
     pickled_metric = pickle.dumps(metric)
@@ -120,6 +124,8 @@ def _class_test(
             sk_batch_result = sk_metric(ddp_preds, ddp_target)
             # assert for dist_sync_on_step
             if check_dist_sync_on_step:
+                import pdb
+                pdb.set_trace()
                 _assert_allclose(batch_result, sk_batch_result, atol=atol)
             else:
                 sk_batch_result = sk_metric(preds[i], target[i])
diff --git a/tests/wrappers/__init__.py b/tests/wrappers/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/wrappers/test_bootstrapping.py b/tests/wrappers/test_bootstrapping.py
new file mode 100644
index 00000000000..6475e30cc22
--- /dev/null
+++ b/tests/wrappers/test_bootstrapping.py
@@ -0,0 +1,64 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
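
As an aside, the percentile-bootstrap recipe that this new test approximates can be sketched independently of the tester harness. This is a minimal illustration only, assuming numpy and scikit-learn are available; all variable names below are illustrative and not part of the patch:

    import numpy as np
    from sklearn.metrics import precision_score

    rng = np.random.default_rng(0)
    preds = rng.integers(0, 10, 320)
    target = rng.integers(0, 10, 320)

    scores = []
    for _ in range(100):
        # draw row indices uniformly with replacement, then score the resample
        idx = rng.integers(0, len(preds), len(preds))
        scores.append(precision_score(target[idx], preds[idx], average='micro'))

    scores = np.asarray(scores)
    print(scores.mean(), scores.std(ddof=1))   # bootstrap mean / std
    print(np.quantile(scores, [0.05, 0.95]))   # 90% confidence interval
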
+import pytest +from functools import partial + +import numpy as np +import torch + +from tests.helpers.testers import MetricTester + +from sklearn.metrics import precision_score, recall_score + +from torchmetrics.classification import Precision, Recall +from torchmetrics.wrappers.bootstrapping import BootStrapper + +_ = torch.manual_seed(0) + +_preds = torch.randint(10, (10, 32)) +_target = torch.randint(10, (10, 32)) + +def _sk_bootstrap(preds, target, func, num_bootstraps=10): + preds = preds.numpy() + target = target.numpy() + + scores = [ ] + for i in range(num_bootstraps): + idx = torch.multinomial(torch.ones(preds.shape[0]), num_samples=preds.shape[0], replacement=True) + print('numpy', idx) + preds_idx = preds[idx] + target_idx = target[idx] + scores.append(func(target_idx, preds_idx, average='micro')) + scores = np.stack(scores) + return [scores.mean(), scores.std()] + +@pytest.mark.parametrize("metric, sk_metric", [ + [Precision(), precision_score], + [Recall(), recall_score], +]) +class TestBootStrapper(MetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_bootstrapper(self, metric, sk_metric, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp, + _preds, + _target, + metric_class=partial(BootStrapper, base_metric=metric), + sk_metric=partial(_sk_bootstrap, func=sk_metric), + dist_sync_on_step=dist_sync_on_step, + ) + + \ No newline at end of file diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py index f23d755cdeb..84f02cbb0a9 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union, List +from typing import Any, Optional, Union, List, Callable from copy import deepcopy import torch @@ -21,7 +21,7 @@ from torchmetrics.utilities import apply_to_collection -def _bootstrap_sampler(tensor, size: Optional[int] = None): +def _bootstrap_sampler(tensor: torch.Tensor, size: Optional[int] = None) -> torch.Tensor: """ """ if size is None: size = tensor.shape[0] @@ -30,12 +30,18 @@ def _bootstrap_sampler(tensor, size: Optional[int] = None): num_samples=size, replacement=True ) + print('pytorch', idx) return tensor[idx] class BootStrapper(Metric): def __init__(self, base_metric: Metric, - num_bootstraps: int = 10): + num_bootstraps: int = 10, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None + ) -> None: """ Use to turn a metric into a bootstrapped metric that can automate the process of getting confidence intervals for metric values. This wrapper class basically keeps multiple copies of the same base metric @@ -45,6 +51,21 @@ def __init__(self, base_metric: Metric, .. note:: Different from all other metrics, bootstrapped metrics has additional arguments in its ``compute`` method determining what should be returned. + Args: + base_metric: + base metric class to wrap + compute_on_step: + Forward only calls ``update()`` and return ``None`` if this is set to ``False``. + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step + process_group: + Specify the process group on which synchronization is called. 
+ default: ``None`` (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When ``None``, DDP + will be used to perform the allgather. + Example:: >>> from torchmetrics.wrappers import BootStrapper >>> from torchmetrics import Accuracy @@ -58,7 +79,16 @@ def __init__(self, base_metric: Metric, tensor(0.4950) tensor(0.1677) """ - super().__init__() + super().__init__( + compute_on_step, + dist_sync_on_step, + process_group, + dist_sync_fn + ) + if not isinstance(base_metric, Metric): + raise ValueError("Expected base metric to be an instance of torchmetrics.Metric" + f" but received {base_metric}") + self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_bootstraps)]) self.num_bootstraps = num_bootstraps From 93c30a8a66e2afae32abb3a3ad801cd1212a98d2 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 17 Mar 2021 17:13:28 +0100 Subject: [PATCH 03/36] pep8 --- tests/wrappers/test_bootstrapping.py | 112 ++++++++++++++++--------- torchmetrics/utilities/__init__.py | 1 + torchmetrics/utilities/imports.py | 58 +++++++++++++ torchmetrics/wrappers/__init__.py | 2 +- torchmetrics/wrappers/bootstrapping.py | 85 ++++++++++++------- 5 files changed, 189 insertions(+), 69 deletions(-) create mode 100644 torchmetrics/utilities/imports.py diff --git a/tests/wrappers/test_bootstrapping.py b/tests/wrappers/test_bootstrapping.py index 6475e30cc22..4e633acc61e 100644 --- a/tests/wrappers/test_bootstrapping.py +++ b/tests/wrappers/test_bootstrapping.py @@ -11,54 +11,86 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import pytest -from functools import partial - import numpy as np +import pytest import torch - -from tests.helpers.testers import MetricTester - from sklearn.metrics import precision_score, recall_score from torchmetrics.classification import Precision, Recall -from torchmetrics.wrappers.bootstrapping import BootStrapper - -_ = torch.manual_seed(0) +from torchmetrics.utilities import _TORCH_GREATER_EQUAL_1_7, apply_to_collection +from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler _preds = torch.randint(10, (10, 32)) _target = torch.randint(10, (10, 32)) -def _sk_bootstrap(preds, target, func, num_bootstraps=10): - preds = preds.numpy() - target = target.numpy() - - scores = [ ] - for i in range(num_bootstraps): - idx = torch.multinomial(torch.ones(preds.shape[0]), num_samples=preds.shape[0], replacement=True) - print('numpy', idx) - preds_idx = preds[idx] - target_idx = target[idx] - scores.append(func(target_idx, preds_idx, average='micro')) - scores = np.stack(scores) - return [scores.mean(), scores.std()] - -@pytest.mark.parametrize("metric, sk_metric", [ - [Precision(), precision_score], - [Recall(), recall_score], -]) -class TestBootStrapper(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_bootstrapper(self, metric, sk_metric, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp, - _preds, - _target, - metric_class=partial(BootStrapper, base_metric=metric), - sk_metric=partial(_sk_bootstrap, func=sk_metric), - dist_sync_on_step=dist_sync_on_step, + +class TestBootStrapper(BootStrapper): + """ For testing purpose, we subclass the bootstrapper class so we can get the exact permutation + the class is creating + """ + + def update(self, *args, **kwargs): + 
self.out = [] + for idx in range(self.num_bootstraps): + new_args = apply_to_collection(args, torch.Tensor, _bootstrap_sampler, generator=self.generator) + new_kwargs = apply_to_collection(kwargs, torch.Tensor, _bootstrap_sampler, generator=self.generator) + self.metrics[idx].update(*new_args, **new_kwargs) + self.out.append(new_args) + + +def test_bootstrap_sampler(): + """ make sure that the bootstrap sampler works as intended """ + old_samples = torch.randn(5, 2) + + # make sure that the new samples are only made up of old samples + new_samples = _bootstrap_sampler(old_samples) + for ns in new_samples: + assert ns in old_samples + + # make sure some samples are also sampled twice + found_one = False + for os in old_samples: + cond = os == new_samples + print(cond.sum()) + if cond.sum() > 2: + found_one = True + assert found_one, "resampling did not work because no samples were sampled twice" + + +@pytest.mark.parametrize( + "metric, sk_metric", [[Precision(average='micro'), precision_score], [Recall(average='micro'), recall_score]] +) +def test_bootstrap(metric, sk_metric): + """ Test that the different bootstraps gets updated as we expected and that the compute method works """ + bootstrapper = TestBootStrapper(metric) + + collected_preds = [[] for _ in range(10)] + collected_target = [[] for _ in range(10)] + for p, t in zip(_preds, _target): + bootstrapper.update(p, t) + + for i, o in enumerate(bootstrapper.out): + + collected_preds[i].append(o[0]) + collected_target[i].append(o[1]) + + collected_preds = [torch.cat(cp) for cp in collected_preds] + collected_target = [torch.cat(ct) for ct in collected_target] + + sk_scores = [sk_metric(ct, cp, average='micro') for ct, cp in zip(collected_target, collected_preds)] + + # quantile only avaible for pytorch v1.7 and forward + if _TORCH_GREATER_EQUAL_1_7: + pl_mean, pl_std, pl_quantile, pl_raw = bootstrapper.compute( + mean=True, std=True, quantile=torch.tensor([0.05, 0.95]), raw=True ) + assert np.allclose(pl_quantile[0], np.quantile(sk_scores, 0.05)) + assert np.allclose(pl_quantile[1], np.quantile(sk_scores, 0.95)) + else: + pl_mean, pl_std, pl_raw = bootstrapper.compute(mean=True, std=True, raw=True) - \ No newline at end of file + assert np.allclose(pl_mean, np.mean(sk_scores)) + import pdb + pdb.set_trace() + assert np.allclose(pl_std, np.std(sk_scores, ddof=1)) + assert np.allclose(pl_raw, sk_scores) diff --git a/torchmetrics/utilities/__init__.py b/torchmetrics/utilities/__init__.py index dff18c0f389..7b0dfd6d950 100644 --- a/torchmetrics/utilities/__init__.py +++ b/torchmetrics/utilities/__init__.py @@ -1,3 +1,4 @@ from torchmetrics.utilities.data import apply_to_collection # noqa: F401 from torchmetrics.utilities.distributed import class_reduce, reduce # noqa: F401 +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_7 # noqa: F401 from torchmetrics.utilities.prints import rank_zero_debug, rank_zero_info, rank_zero_warn # noqa: F401 diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py new file mode 100644 index 00000000000..3ef2bbda28a --- /dev/null +++ b/torchmetrics/utilities/imports.py @@ -0,0 +1,58 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Import utilities""" +import importlib +import operator +from distutils.version import LooseVersion +from importlib.util import find_spec + +from pkg_resources import DistributionNotFound + + +def _module_available(module_path: str) -> bool: + """ + Check if a path is available in your environment + >>> _module_available('os') + True + >>> _module_available('bla.bla') + False + """ + try: + return find_spec(module_path) is not None + except AttributeError: + # Python 3.6 + return False + except ModuleNotFoundError: + # Python 3.7+ + return False + + +def _compare_version(package: str, op, version) -> bool: + """Compare package version with some requirements + >>> _compare_version("torch", operator.ge, "0.1") + True + """ + if not _module_available(package): + return False + try: + pkg = importlib.import_module(package) + assert hasattr(pkg, '__version__') + pkg_version = pkg.__version__ + return op(pkg_version, LooseVersion(version)) + except DistributionNotFound: + return False + + +_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0") +_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0") diff --git a/torchmetrics/wrappers/__init__.py b/torchmetrics/wrappers/__init__.py index 7e6b7d4da94..4f506ea4da3 100644 --- a/torchmetrics/wrappers/__init__.py +++ b/torchmetrics/wrappers/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401 \ No newline at end of file +from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401 diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py index 84f02cbb0a9..43cbfbb5d8b 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -11,49 +11,69 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union, List, Callable from copy import deepcopy +from typing import Any, Callable, List, Optional, Union import torch from torch import nn from torchmetrics.metric import Metric -from torchmetrics.utilities import apply_to_collection +from torchmetrics.utilities import _TORCH_GREATER_EQUAL_1_7, apply_to_collection -def _bootstrap_sampler(tensor: torch.Tensor, size: Optional[int] = None) -> torch.Tensor: - """ """ +def _bootstrap_sampler( + tensor: torch.Tensor, + size: Optional[int] = None, + generator: Optional[torch.Generator] = None +) -> torch.Tensor: + """ Resample a tensor along its first dimension with replacement + Args: + tensor: tensor to resample + size: number of samples in new tensor. 
Defaults to the same size as the input tensor
+        generator: an instance of ``torch.Generator`` that controls the sampling
+
+    Returns:
+        resampled tensor
+
+    """
     if size is None:
         size = tensor.shape[0]
     idx = torch.multinomial(
         torch.ones(tensor.shape[0], device=tensor.device),
         num_samples=size,
-        replacement=True
+        replacement=True,
+        generator=generator
     )
-    print('pytorch', idx)
     return tensor[idx]
 
 
 class BootStrapper(Metric):
-    def __init__(self, base_metric: Metric,
-                 num_bootstraps: int = 10,
-                 compute_on_step: bool = True,
-                 dist_sync_on_step: bool = False,
-                 process_group: Optional[Any] = None,
-                 dist_sync_fn: Callable = None
+    def __init__(
+        self,
+        base_metric: Metric,
+        num_bootstraps: int = 10,
+        generator: Optional[torch.Generator] = None,
+        compute_on_step: bool = True,
+        dist_sync_on_step: bool = False,
+        process_group: Optional[Any] = None,
+        dist_sync_fn: Callable = None
     ) -> None:
-        """
+        """
         Use to turn a metric into a bootstrapped metric that can automate the process of getting confidence
         intervals for metric values. This wrapper class basically keeps multiple copies of the same base metric
         in memory and whenever ``update`` or ``forward`` is called, all input tensors are resampled
         (with replacement) along the first dimension.
-
+
         .. note:: Different from all other metrics, bootstrapped metrics has additional
             arguments in its ``compute`` method determining what should be returned.
-
+
         Args:
-            base_metric:
+            base_metric:
                 base metric class to wrap
+            num_bootstraps:
+                number of copies to make of the base metric for bootstrapping
+            generator:
+                A pytorch random number generator for the bootstrap sampler
             compute_on_step:
                 Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
             dist_sync_on_step:
                 Synchronize metric state across processes at each ``forward()``
                 before returning the value at the step
             process_group:
                 Specify the process group on which synchronization is called.
                 default: ``None`` (which selects the entire world)
             dist_sync_fn:
                 Callback that performs the allgather operation on the metric state. When ``None``, DDP
                 will be used to perform the allgather.
-
+
         Example::
             >>> from torchmetrics.wrappers import BootStrapper
             >>> from torchmetrics import Accuracy
             >>> _ = torch.manual_seed(0)
             >>> base_metric = Accuracy()
             >>> bootstrap = BootStrapper(base_metric, num_bootstraps=20)
             >>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,)))
             >>> output = bootstrap.compute(mean=True, std=True)
             >>> mean, std = output
             >>> print(mean, std)
             tensor(0.4950) tensor(0.1677)
-
+
         """
         super().__init__(
             compute_on_step,
             dist_sync_on_step,
             process_group,
             dist_sync_fn
         )
         if not isinstance(base_metric, Metric):
             raise ValueError("Expected base metric to be an instance of torchmetrics.Metric"
-                             f" but received {base_metric}")
+                             f" but received {base_metric}")
 
         self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_bootstraps)])
         self.num_bootstraps = num_bootstraps
-
+
+        if generator is not None and not isinstance(generator, torch.Generator):
+            raise ValueError("Expected argument ``generator`` to be an instance of ``torch.Generator``"
+                             f" but received {generator}")
+        self.generator = generator
+
     def update(self, *args, **kwargs):
         """ Updates the state of the base metric. Any tensor passed in will be bootstrapped
             along dimension 0
         """
         for idx in range(self.num_bootstraps):
-            args = apply_to_collection(args, torch.Tensor, _bootstrap_sampler)
-            kwargs = apply_to_collection(kwargs, torch.Tensor, _bootstrap_sampler)
-            self.metrics[idx].update(*args, **kwargs)
+            new_args = apply_to_collection(args, torch.Tensor, _bootstrap_sampler, generator=self.generator)
+            new_kwargs = apply_to_collection(kwargs, torch.Tensor, _bootstrap_sampler, generator=self.generator)
+            self.metrics[idx].update(*new_args, **new_kwargs)
 
     def compute(
-        self,
-        mean: bool = True,
-        std: bool = True,
+        self,
+        mean: bool = True,
+        std: bool = True,
         quantile: Optional[Union[float, torch.Tensor]] = None,
         raw: bool = False
-    ) -> List[torch.Tensor]:
+    ) -> List[torch.Tensor]:
         """ Computes the metric value.
         Args:
             mean: if `True` return the mean of the bootstraps
             std: if `True` return the standard deviation of the bootstraps
-            quantile: if given, returns the quantile of the bootstraps
+            quantile:
+                if given, returns the quantile of the bootstraps. Can only be used with pytorch version
+                1.7 or higher
             raw: if `True`, return all bootstrapped values
         """
         computed_vals = torch.stack([m.compute() for m in self.metrics], dim=0)
         output = []
@@ -122,6 +149,8 @@ def compute(
         if std:
             output.append(computed_vals.std(dim=0))
         if quantile is not None:
+            if not _TORCH_GREATER_EQUAL_1_7:
+                raise ValueError('quantile argument can only be used with pytorch v1.7 or higher')
             output.append(torch.quantile(computed_vals, quantile))
         if raw:
             output.append(computed_vals)
         return output
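
Patch 03 gates ``torch.quantile`` behind the new ``_TORCH_GREATER_EQUAL_1_7`` flag. The same guard pattern can be sketched in user code; this is an illustration only, relying on the flag re-exported from ``torchmetrics.utilities`` above:

    import torch
    from torchmetrics.utilities import _TORCH_GREATER_EQUAL_1_7

    values = torch.rand(1000)
    if _TORCH_GREATER_EQUAL_1_7:
        ci = torch.quantile(values, torch.tensor([0.05, 0.95]))
    else:
        # rough fallback for older torch: index into the sorted values
        q = torch.tensor([0.05, 0.95])
        ci = values.sort().values[(q * (values.numel() - 1)).long()]
    print(ci)
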
From d785e6ce09354be865aa168db2cfcf1b5effd0a0 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Thu, 18 Mar 2021 10:57:58 +0100
Subject: [PATCH 04/36] move args to init

---
 torchmetrics/wrappers/bootstrapping.py | 61 ++++++++++++++------------
 1 file changed, 32 insertions(+), 29 deletions(-)

diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py
index 43cbfbb5d8b..e0959f97514 100644
--- a/torchmetrics/wrappers/bootstrapping.py
+++ b/torchmetrics/wrappers/bootstrapping.py
@@ -52,6 +52,10 @@ def __init__(
         self,
         base_metric: Metric,
         num_bootstraps: int = 10,
+        mean: bool = True,
+        std: bool = True,
+        quantile: Optional[Union[float, torch.Tensor]] = None,
+        raw: bool = False,
         generator: Optional[torch.Generator] = None,
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
@@ -64,14 +68,20 @@ def __init__(
         in memory and whenever ``update`` or ``forward`` is called, all input tensors are resampled
         (with replacement) along the first dimension.
 
-        .. note:: Different from all other metrics, bootstrapped metrics has additional
-            arguments in its ``compute`` method determining what should be returned.
-
         Args:
             base_metric:
                 base metric class to wrap
             num_bootstraps:
                 number of copies to make of the base metric for bootstrapping
+            mean:
+                if ``True`` return the mean of the bootstraps
+            std:
+                if ``True`` return the standard deviation of the bootstraps
+            quantile:
+                if given, returns the quantile of the bootstraps. Can only be used with
+                pytorch version 1.7 or higher
+            raw:
+                if ``True``, return all bootstrapped values
             generator:
                 A pytorch random number generator for the bootstrap sampler
             compute_on_step:
                 Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
             dist_sync_on_step:
                 Synchronize metric state across processes at each ``forward()``
                 before returning the value at the step
             process_group:
                 Specify the process group on which synchronization is called.
                 default: ``None`` (which selects the entire world)
             dist_sync_fn:
                 Callback that performs the allgather operation on the metric state. When ``None``, DDP
                 will be used to perform the allgather.
 
         Example::
             >>> from torchmetrics.wrappers import BootStrapper
             >>> from torchmetrics import Accuracy
-            >>> _ = torch.manual_seed(0)
+            >>> generator = torch.manual_seed(0)
             >>> base_metric = Accuracy()
-            >>> bootstrap = BootStrapper(base_metric, num_bootstraps=20)
+            >>> bootstrap = BootStrapper(base_metric, num_bootstraps=20, generator=generator)
             >>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,)))
-            >>> output = bootstrap.compute(mean=True, std=True)
+            >>> output = bootstrap.compute()
             >>> mean, std = output
             >>> print(mean, std)
-            tensor(0.4950) tensor(0.1677)
+            tensor(0.2175) tensor(0.0950)
 
         """
         super().__init__(
             compute_on_step,
             dist_sync_on_step,
             process_group,
             dist_sync_fn
         )
         if not isinstance(base_metric, Metric):
             raise ValueError("Expected base metric to be an instance of torchmetrics.Metric"
                              f" but received {base_metric}")
 
         self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_bootstraps)])
         self.num_bootstraps = num_bootstraps
 
+        self.mean = mean
+        self.std = std
+        if quantile is not None and not _TORCH_GREATER_EQUAL_1_7:
+            raise ValueError('quantile argument can only be used with pytorch v1.7 or higher')
+        self.quantile = quantile
+        self.raw = raw
+
         if generator is not None and not isinstance(generator, torch.Generator):
             raise ValueError("Expected argument ``generator`` to be an instance of ``torch.Generator``"
                              f" but received {generator}")
         self.generator = generator
 
-    def compute(
-        self,
-        mean: bool = True,
-        std: bool = True,
-        quantile: Optional[Union[float, torch.Tensor]] = None,
-        raw: bool = False
-    ) -> List[torch.Tensor]:
-        """ Computes the metric value.
-        Args:
-            mean: if `True` return the mean of the bootstraps
-            std: if `True` return the standard deviation of the bootstraps
-            quantile:
-                if given, returns the quantile of the bootstraps. Can only be used with pytorch version
-                1.7 or higher
-            raw: if `True`, return all bootstrapped values
+    def compute(self) -> List[torch.Tensor]:
+        """ Computes the bootstrapped metric values. Always returns a list of tensors, but the content of
+            the list depends on how the class was initialized
         """
         computed_vals = torch.stack([m.compute() for m in self.metrics], dim=0)
         output = []
-        if mean:
+        if self.mean:
             output.append(computed_vals.mean(dim=0))
-        if std:
+        if self.std:
             output.append(computed_vals.std(dim=0))
-        if quantile is not None:
-            if not _TORCH_GREATER_EQUAL_1_7:
-                raise ValueError('quantile argument can only be used with pytorch v1.7 or higher')
-            output.append(torch.quantile(computed_vals, quantile))
-        if raw:
+        if self.quantile is not None:
+            output.append(torch.quantile(computed_vals, self.quantile))
+        if self.raw:
             output.append(computed_vals)
         return output
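
With patch 04 the output selection moves from ``compute`` to the constructor, so callers configure the wrapper once and ``compute`` takes no arguments. A usage sketch mirroring the updated doctest above:

    import torch
    from torchmetrics import Accuracy
    from torchmetrics.wrappers import BootStrapper

    generator = torch.manual_seed(0)
    bootstrap = BootStrapper(Accuracy(), num_bootstraps=20, mean=True, std=True, generator=generator)
    bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,)))
    mean, std = bootstrap.compute()  # list contents follow the constructor flags
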
From 0d23b7c305109c12f051fa26737f528821067cde Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Thu, 18 Mar 2021 11:03:34 +0100
Subject: [PATCH 05/36] fix tests

---
 CHANGELOG.md                           | 3 ++-
 torchmetrics/wrappers/bootstrapping.py | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index baed9e5d428..93b82272b79 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,7 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added `CohenKappa` metric ([#69](https://github.com/PyTorchLightning/metrics/pull/69))
 
-- Added `BootStrapper` to easily calculate confidence intervals for metrics ([]())
+- Added `BootStrapper` to easily calculate confidence intervals for metrics ([#101](https://github.com/PyTorchLightning/metrics/pull/101))
+
 
 ### Changed
 
diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py
index e0959f97514..9870517a684 100644
--- a/torchmetrics/wrappers/bootstrapping.py
+++ b/torchmetrics/wrappers/bootstrapping.py
@@ -134,7 +134,7 @@ def __init__(
                              f" but received {generator}")
         self.generator = generator
 
-    def update(self, *args, **kwargs):
+    def update(self, *args: Any, **kwargs: Any):
         """ Updates the state of the base metric.
Any tensor passed in will be bootstrapped along dimension 0 """ From 7a08934b95ffffc639bf8e543aa12b6890cbe50d Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 18 Mar 2021 11:08:55 +0100 Subject: [PATCH 06/36] fix tests --- tests/wrappers/test_bootstrapping.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/wrappers/test_bootstrapping.py b/tests/wrappers/test_bootstrapping.py index 4e633acc61e..f6b50da177d 100644 --- a/tests/wrappers/test_bootstrapping.py +++ b/tests/wrappers/test_bootstrapping.py @@ -62,7 +62,10 @@ def test_bootstrap_sampler(): ) def test_bootstrap(metric, sk_metric): """ Test that the different bootstraps gets updated as we expected and that the compute method works """ - bootstrapper = TestBootStrapper(metric) + if _TORCH_GREATER_EQUAL_1_7: + bootstrapper = TestBootStrapper(metric, mean=True, std=True, quantile=torch.tensor([0.05, 0.95]), raw=True) + else: + bootstrapper = TestBootStrapper(metric, mean=True, std=True, raw=True) collected_preds = [[] for _ in range(10)] collected_target = [[] for _ in range(10)] @@ -79,18 +82,15 @@ def test_bootstrap(metric, sk_metric): sk_scores = [sk_metric(ct, cp, average='micro') for ct, cp in zip(collected_target, collected_preds)] + output = bootstrapper.compute() # quantile only avaible for pytorch v1.7 and forward if _TORCH_GREATER_EQUAL_1_7: - pl_mean, pl_std, pl_quantile, pl_raw = bootstrapper.compute( - mean=True, std=True, quantile=torch.tensor([0.05, 0.95]), raw=True - ) + pl_mean, pl_std, pl_quantile, pl_raw = output assert np.allclose(pl_quantile[0], np.quantile(sk_scores, 0.05)) assert np.allclose(pl_quantile[1], np.quantile(sk_scores, 0.95)) else: - pl_mean, pl_std, pl_raw = bootstrapper.compute(mean=True, std=True, raw=True) + pl_mean, pl_std, pl_raw = output assert np.allclose(pl_mean, np.mean(sk_scores)) - import pdb - pdb.set_trace() assert np.allclose(pl_std, np.std(sk_scores, ddof=1)) assert np.allclose(pl_raw, sk_scores) From d1c0482bf4252f86860a2d5f36611f9ebc3d74b8 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 18 Mar 2021 11:54:28 +0100 Subject: [PATCH 07/36] mypy --- torchmetrics/wrappers/bootstrapping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py index 9870517a684..5dc6bf193af 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -134,7 +134,7 @@ def __init__( f"but received {generator}") self.generator = generator - def update(self, *args: Any, **kwargs: Any): + def update(self, *args: Any, **kwargs: Any) -> None: """ Updates the state of the base metric. 
Any tensor passed in will be bootstrapped along dimension 0 """ From 25b38adb5be93943c5b6d327e5efed5a8f0290db Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 18 Mar 2021 13:34:53 +0100 Subject: [PATCH 08/36] remove pdb --- tests/helpers/testers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 233538610d5..8d8ec21ac22 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -124,8 +124,6 @@ def _class_test( sk_batch_result = sk_metric(ddp_preds, ddp_target) # assert for dist_sync_on_step if check_dist_sync_on_step: - import pdb - pdb.set_trace() _assert_allclose(batch_result, sk_batch_result, atol=atol) else: sk_batch_result = sk_metric(preds[i], target[i]) From beac9e04a85fa1eb4b774a8850c0b45a15624c37 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 17 Mar 2021 11:50:25 +0100 Subject: [PATCH 09/36] add bootstrapping --- CHANGELOG.md | 4 ++ docs/source/references/modules.rst | 13 +++- torchmetrics/__init__.py | 1 + torchmetrics/wrappers/__init__.py | 14 ++++ torchmetrics/wrappers/bootstrapping.py | 98 ++++++++++++++++++++++++++ 5 files changed, 129 insertions(+), 1 deletion(-) create mode 100644 torchmetrics/wrappers/__init__.py create mode 100644 torchmetrics/wrappers/bootstrapping.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6692332fbcf..dbb39334548 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114)) + +- Added `BootStrapper` to easely calculate confidence intervals for metrics ([]()) + + ### Changed - Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68)) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index cd7fe83ea57..21a13cec812 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -269,4 +269,15 @@ R2Score ~~~~~~~ .. autoclass:: torchmetrics.R2Score - :noindex: \ No newline at end of file + :noindex: + +******** +Wrappers +******** + +Modular wrapper metrics are not metrics in themself, but instead the take in other metrics and alter +the internal logic of the base metric. + +.. autoclass:: torchmetrics.BootStrapper + :noindex: + diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 2e385e5635e..193660c8537 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -49,3 +49,4 @@ R2Score, ) from torchmetrics.retrieval import RetrievalMAP # noqa: F401 E402 +from torchmetrics.wrappers import BootStrapper # noqa: F401 diff --git a/torchmetrics/wrappers/__init__.py b/torchmetrics/wrappers/__init__.py new file mode 100644 index 00000000000..7e6b7d4da94 --- /dev/null +++ b/torchmetrics/wrappers/__init__.py @@ -0,0 +1,14 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401 \ No newline at end of file diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py new file mode 100644 index 00000000000..f23d755cdeb --- /dev/null +++ b/torchmetrics/wrappers/bootstrapping.py @@ -0,0 +1,98 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Union, List +from copy import deepcopy + +import torch +from torch import nn + +from torchmetrics.metric import Metric +from torchmetrics.utilities import apply_to_collection + + +def _bootstrap_sampler(tensor, size: Optional[int] = None): + """ """ + if size is None: + size = tensor.shape[0] + idx = torch.multinomial( + torch.ones(tensor.shape[0], device=tensor.device), + num_samples=size, + replacement=True + ) + return tensor[idx] + + +class BootStrapper(Metric): + def __init__(self, base_metric: Metric, + num_bootstraps: int = 10): + """ + Use to turn a metric into a bootstrapped metric that can automate the process of getting confidence + intervals for metric values. This wrapper class basically keeps multiple copies of the same base metric + in memory and whenever ``update`` or ``forward`` is called, all input tensors are resampled + (with replacement) along the first dimension. + + .. note:: Different from all other metrics, bootstrapped metrics has additional + arguments in its ``compute`` method determining what should be returned. + + Example:: + >>> from torchmetrics.wrappers import BootStrapper + >>> from torchmetrics import Accuracy + >>> _ = torch.manual_seed(0) + >>> base_metric = Accuracy() + >>> bootstrap = BootStrapper(base_metric, num_bootstraps=20) + >>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,))) + >>> output = bootstrap.compute(mean=True, std=True) + >>> mean, std = output + >>> print(mean, std) + tensor(0.4950) tensor(0.1677) + + """ + super().__init__() + self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_bootstraps)]) + self.num_bootstraps = num_bootstraps + + def update(self, *args, **kwargs): + """ Updates the state of the base metric. Any tensor passed in will be bootstrapped + along dimension 0 + """ + for idx in range(self.num_bootstraps): + args = apply_to_collection(args, torch.Tensor, _bootstrap_sampler) + kwargs = apply_to_collection(kwargs, torch.Tensor, _bootstrap_sampler) + self.metrics[idx].update(*args, **kwargs) + + def compute( + self, + mean: bool = True, + std: bool = True, + quantile: Optional[Union[float, torch.Tensor]] = None, + raw: bool = False + ) -> List[torch.Tensor]: + """ Computes the metric value. 
+ Args: + mean: if `True` return the mean of the bootstraps + std: if `True` return the standard diviation of the bootstraps + quantile: if given, returns the quantile of the bootstraps + raw: if `True`, return all bootstrapped values + """ + computed_vals = torch.stack([m.compute() for m in self.metrics], dim=0) + output = [] + if mean: + output.append(computed_vals.mean(dim=0)) + if std: + output.append(computed_vals.std(dim=0)) + if quantile is not None: + output.append(torch.quantile(computed_vals, quantile)) + if raw: + output.append(computed_vals) + return output From 807a4e2fb7b043f1ffae61ae2c2f47adc1ced7ec Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 17 Mar 2021 15:44:57 +0100 Subject: [PATCH 10/36] tests --- tests/helpers/testers.py | 8 +++- tests/wrappers/__init__.py | 0 tests/wrappers/test_bootstrapping.py | 64 ++++++++++++++++++++++++++ torchmetrics/wrappers/bootstrapping.py | 38 +++++++++++++-- 4 files changed, 105 insertions(+), 5 deletions(-) create mode 100644 tests/wrappers/__init__.py create mode 100644 tests/wrappers/test_bootstrapping.py diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 149f300f0e0..6c6b4eb88dc 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -107,7 +107,11 @@ def _class_test( if not metric_args: metric_args = {} # Instanciate lightning metric - metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args) + metric = metric_class( + compute_on_step=check_dist_sync_on_step or check_batch, + dist_sync_on_step=dist_sync_on_step, + **metric_args + ) # verify metrics work after being loaded from pickled state pickled_metric = pickle.dumps(metric) @@ -123,6 +127,8 @@ def _class_test( sk_batch_result = sk_metric(ddp_preds, ddp_target) # assert for dist_sync_on_step if check_dist_sync_on_step: + import pdb + pdb.set_trace() _assert_allclose(batch_result, sk_batch_result, atol=atol) else: sk_batch_result = sk_metric(preds[i], target[i]) diff --git a/tests/wrappers/__init__.py b/tests/wrappers/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/wrappers/test_bootstrapping.py b/tests/wrappers/test_bootstrapping.py new file mode 100644 index 00000000000..6475e30cc22 --- /dev/null +++ b/tests/wrappers/test_bootstrapping.py @@ -0,0 +1,64 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
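
Relatedly, seeding a dedicated ``torch.Generator`` makes multinomial resampling deterministic across runs, which is what the ``generator`` argument introduced later in this series relies on. A small self-contained sketch:

    import torch

    g1 = torch.Generator().manual_seed(42)
    g2 = torch.Generator().manual_seed(42)
    weights = torch.ones(10)
    idx1 = torch.multinomial(weights, 10, replacement=True, generator=g1)
    idx2 = torch.multinomial(weights, 10, replacement=True, generator=g2)
    assert torch.equal(idx1, idx2)  # same seed -> same bootstrap indices
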
+import pytest +from functools import partial + +import numpy as np +import torch + +from tests.helpers.testers import MetricTester + +from sklearn.metrics import precision_score, recall_score + +from torchmetrics.classification import Precision, Recall +from torchmetrics.wrappers.bootstrapping import BootStrapper + +_ = torch.manual_seed(0) + +_preds = torch.randint(10, (10, 32)) +_target = torch.randint(10, (10, 32)) + +def _sk_bootstrap(preds, target, func, num_bootstraps=10): + preds = preds.numpy() + target = target.numpy() + + scores = [ ] + for i in range(num_bootstraps): + idx = torch.multinomial(torch.ones(preds.shape[0]), num_samples=preds.shape[0], replacement=True) + print('numpy', idx) + preds_idx = preds[idx] + target_idx = target[idx] + scores.append(func(target_idx, preds_idx, average='micro')) + scores = np.stack(scores) + return [scores.mean(), scores.std()] + +@pytest.mark.parametrize("metric, sk_metric", [ + [Precision(), precision_score], + [Recall(), recall_score], +]) +class TestBootStrapper(MetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_bootstrapper(self, metric, sk_metric, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp, + _preds, + _target, + metric_class=partial(BootStrapper, base_metric=metric), + sk_metric=partial(_sk_bootstrap, func=sk_metric), + dist_sync_on_step=dist_sync_on_step, + ) + + \ No newline at end of file diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py index f23d755cdeb..84f02cbb0a9 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union, List +from typing import Any, Optional, Union, List, Callable from copy import deepcopy import torch @@ -21,7 +21,7 @@ from torchmetrics.utilities import apply_to_collection -def _bootstrap_sampler(tensor, size: Optional[int] = None): +def _bootstrap_sampler(tensor: torch.Tensor, size: Optional[int] = None) -> torch.Tensor: """ """ if size is None: size = tensor.shape[0] @@ -30,12 +30,18 @@ def _bootstrap_sampler(tensor, size: Optional[int] = None): num_samples=size, replacement=True ) + print('pytorch', idx) return tensor[idx] class BootStrapper(Metric): def __init__(self, base_metric: Metric, - num_bootstraps: int = 10): + num_bootstraps: int = 10, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None + ) -> None: """ Use to turn a metric into a bootstrapped metric that can automate the process of getting confidence intervals for metric values. This wrapper class basically keeps multiple copies of the same base metric @@ -45,6 +51,21 @@ def __init__(self, base_metric: Metric, .. note:: Different from all other metrics, bootstrapped metrics has additional arguments in its ``compute`` method determining what should be returned. + Args: + base_metric: + base metric class to wrap + compute_on_step: + Forward only calls ``update()`` and return ``None`` if this is set to ``False``. + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step + process_group: + Specify the process group on which synchronization is called. 
+ default: ``None`` (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When ``None``, DDP + will be used to perform the allgather. + Example:: >>> from torchmetrics.wrappers import BootStrapper >>> from torchmetrics import Accuracy @@ -58,7 +79,16 @@ def __init__(self, base_metric: Metric, tensor(0.4950) tensor(0.1677) """ - super().__init__() + super().__init__( + compute_on_step, + dist_sync_on_step, + process_group, + dist_sync_fn + ) + if not isinstance(base_metric, Metric): + raise ValueError("Expected base metric to be an instance of torchmetrics.Metric" + f" but received {base_metric}") + self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_bootstraps)]) self.num_bootstraps = num_bootstraps From 7caabbe646637d1b13f962e674bae3bf6d47d6d5 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 17 Mar 2021 17:13:28 +0100 Subject: [PATCH 11/36] pep8 --- tests/wrappers/test_bootstrapping.py | 112 ++++++++++++++++--------- torchmetrics/utilities/__init__.py | 1 + torchmetrics/utilities/imports.py | 57 ++++++++++++- torchmetrics/wrappers/__init__.py | 2 +- torchmetrics/wrappers/bootstrapping.py | 85 ++++++++++++------- 5 files changed, 187 insertions(+), 70 deletions(-) diff --git a/tests/wrappers/test_bootstrapping.py b/tests/wrappers/test_bootstrapping.py index 6475e30cc22..4e633acc61e 100644 --- a/tests/wrappers/test_bootstrapping.py +++ b/tests/wrappers/test_bootstrapping.py @@ -11,54 +11,86 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import pytest -from functools import partial - import numpy as np +import pytest import torch - -from tests.helpers.testers import MetricTester - from sklearn.metrics import precision_score, recall_score from torchmetrics.classification import Precision, Recall -from torchmetrics.wrappers.bootstrapping import BootStrapper - -_ = torch.manual_seed(0) +from torchmetrics.utilities import _TORCH_GREATER_EQUAL_1_7, apply_to_collection +from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler _preds = torch.randint(10, (10, 32)) _target = torch.randint(10, (10, 32)) -def _sk_bootstrap(preds, target, func, num_bootstraps=10): - preds = preds.numpy() - target = target.numpy() - - scores = [ ] - for i in range(num_bootstraps): - idx = torch.multinomial(torch.ones(preds.shape[0]), num_samples=preds.shape[0], replacement=True) - print('numpy', idx) - preds_idx = preds[idx] - target_idx = target[idx] - scores.append(func(target_idx, preds_idx, average='micro')) - scores = np.stack(scores) - return [scores.mean(), scores.std()] - -@pytest.mark.parametrize("metric, sk_metric", [ - [Precision(), precision_score], - [Recall(), recall_score], -]) -class TestBootStrapper(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_bootstrapper(self, metric, sk_metric, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp, - _preds, - _target, - metric_class=partial(BootStrapper, base_metric=metric), - sk_metric=partial(_sk_bootstrap, func=sk_metric), - dist_sync_on_step=dist_sync_on_step, + +class TestBootStrapper(BootStrapper): + """ For testing purpose, we subclass the bootstrapper class so we can get the exact permutation + the class is creating + """ + + def update(self, *args, **kwargs): + self.out = [] + for idx in 
range(self.num_bootstraps): + new_args = apply_to_collection(args, torch.Tensor, _bootstrap_sampler, generator=self.generator) + new_kwargs = apply_to_collection(kwargs, torch.Tensor, _bootstrap_sampler, generator=self.generator) + self.metrics[idx].update(*new_args, **new_kwargs) + self.out.append(new_args) + + +def test_bootstrap_sampler(): + """ make sure that the bootstrap sampler works as intended """ + old_samples = torch.randn(5, 2) + + # make sure that the new samples are only made up of old samples + new_samples = _bootstrap_sampler(old_samples) + for ns in new_samples: + assert ns in old_samples + + # make sure some samples are also sampled twice + found_one = False + for os in old_samples: + cond = os == new_samples + print(cond.sum()) + if cond.sum() > 2: + found_one = True + assert found_one, "resampling did not work because no samples were sampled twice" + + +@pytest.mark.parametrize( + "metric, sk_metric", [[Precision(average='micro'), precision_score], [Recall(average='micro'), recall_score]] +) +def test_bootstrap(metric, sk_metric): + """ Test that the different bootstraps gets updated as we expected and that the compute method works """ + bootstrapper = TestBootStrapper(metric) + + collected_preds = [[] for _ in range(10)] + collected_target = [[] for _ in range(10)] + for p, t in zip(_preds, _target): + bootstrapper.update(p, t) + + for i, o in enumerate(bootstrapper.out): + + collected_preds[i].append(o[0]) + collected_target[i].append(o[1]) + + collected_preds = [torch.cat(cp) for cp in collected_preds] + collected_target = [torch.cat(ct) for ct in collected_target] + + sk_scores = [sk_metric(ct, cp, average='micro') for ct, cp in zip(collected_target, collected_preds)] + + # quantile only avaible for pytorch v1.7 and forward + if _TORCH_GREATER_EQUAL_1_7: + pl_mean, pl_std, pl_quantile, pl_raw = bootstrapper.compute( + mean=True, std=True, quantile=torch.tensor([0.05, 0.95]), raw=True ) + assert np.allclose(pl_quantile[0], np.quantile(sk_scores, 0.05)) + assert np.allclose(pl_quantile[1], np.quantile(sk_scores, 0.95)) + else: + pl_mean, pl_std, pl_raw = bootstrapper.compute(mean=True, std=True, raw=True) - \ No newline at end of file + assert np.allclose(pl_mean, np.mean(sk_scores)) + import pdb + pdb.set_trace() + assert np.allclose(pl_std, np.std(sk_scores, ddof=1)) + assert np.allclose(pl_raw, sk_scores) diff --git a/torchmetrics/utilities/__init__.py b/torchmetrics/utilities/__init__.py index dff18c0f389..7b0dfd6d950 100644 --- a/torchmetrics/utilities/__init__.py +++ b/torchmetrics/utilities/__init__.py @@ -1,3 +1,4 @@ from torchmetrics.utilities.data import apply_to_collection # noqa: F401 from torchmetrics.utilities.distributed import class_reduce, reduce # noqa: F401 +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_7 # noqa: F401 from torchmetrics.utilities.prints import rank_zero_debug, rank_zero_info, rank_zero_warn # noqa: F401 diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py index 9aa15dc8e82..ae2c0144b67 100644 --- a/torchmetrics/utilities/imports.py +++ b/torchmetrics/utilities/imports.py @@ -1,6 +1,61 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Import utilities""" +import importlib +import operator from distutils.version import LooseVersion +from importlib.util import find_spec -import torch +from pkg_resources import DistributionNotFound + + +def _module_available(module_path: str) -> bool: + """ + Check if a path is available in your environment + >>> _module_available('os') + True + >>> _module_available('bla.bla') + False + """ + try: + return find_spec(module_path) is not None + except AttributeError: + # Python 3.6 + return False + except ModuleNotFoundError: + # Python 3.7+ + return False + + +def _compare_version(package: str, op, version) -> bool: + """Compare package version with some requirements + >>> _compare_version("torch", operator.ge, "0.1") + True + """ + if not _module_available(package): + return False + try: + pkg = importlib.import_module(package) + assert hasattr(pkg, '__version__') + pkg_version = pkg.__version__ + return op(pkg_version, LooseVersion(version)) + except DistributionNotFound: + return False + + +_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0") +_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0") _TORCH_LOWER_1_4 = LooseVersion(torch.__version__) < LooseVersion("1.4.0") _TORCH_LOWER_1_5 = LooseVersion(torch.__version__) < LooseVersion("1.5.0") diff --git a/torchmetrics/wrappers/__init__.py b/torchmetrics/wrappers/__init__.py index 7e6b7d4da94..4f506ea4da3 100644 --- a/torchmetrics/wrappers/__init__.py +++ b/torchmetrics/wrappers/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401 \ No newline at end of file +from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401 diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py index 84f02cbb0a9..43cbfbb5d8b 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -11,49 +11,69 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union, List, Callable from copy import deepcopy +from typing import Any, Callable, List, Optional, Union import torch from torch import nn from torchmetrics.metric import Metric -from torchmetrics.utilities import apply_to_collection +from torchmetrics.utilities import _TORCH_GREATER_EQUAL_1_7, apply_to_collection -def _bootstrap_sampler(tensor: torch.Tensor, size: Optional[int] = None) -> torch.Tensor: - """ """ +def _bootstrap_sampler( + tensor: torch.Tensor, + size: Optional[int] = None, + generator: Optional[torch.Generator] = None +) -> torch.Tensor: + """ Resample a tensor along its first dimension with replacement + Args: + tensor: tensor to resample + size: number of samples in new tensor. 
Defauls to same size as input tensor + generator: a instance of ``torch.Generator`` that controls the sampling + + Returns: + resampled tensor + + """ if size is None: size = tensor.shape[0] idx = torch.multinomial( torch.ones(tensor.shape[0], device=tensor.device), num_samples=size, - replacement=True + replacement=True, + generator=generator ) - print('pytorch', idx) return tensor[idx] class BootStrapper(Metric): - def __init__(self, base_metric: Metric, - num_bootstraps: int = 10, - compute_on_step: bool = True, - dist_sync_on_step: bool = False, - process_group: Optional[Any] = None, - dist_sync_fn: Callable = None + def __init__( + self, + base_metric: Metric, + num_bootstraps: int = 10, + generator: Optional[torch.Generator] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None ) -> None: - """ + """ Use to turn a metric into a bootstrapped metric that can automate the process of getting confidence intervals for metric values. This wrapper class basically keeps multiple copies of the same base metric in memory and whenever ``update`` or ``forward`` is called, all input tensors are resampled (with replacement) along the first dimension. - + .. note:: Different from all other metrics, bootstrapped metrics has additional arguments in its ``compute`` method determining what should be returned. - + Args: - base_metric: + base_metric: base metric class to wrap + num_bootstraps: + number of copies to make of the base metric for bootstrapping + generator: + A pytorch random number generator for the bootstrap sampler compute_on_step: Forward only calls ``update()`` and return ``None`` if this is set to ``False``. dist_sync_on_step: @@ -65,7 +85,7 @@ def __init__(self, base_metric: Metric, dist_sync_fn: Callback that performs the allgather operation on the metric state. When ``None``, DDP will be used to perform the allgather. - + Example:: >>> from torchmetrics.wrappers import BootStrapper >>> from torchmetrics import Accuracy @@ -77,7 +97,7 @@ def __init__(self, base_metric: Metric, >>> mean, std = output >>> print(mean, std) tensor(0.4950) tensor(0.1677) - + """ super().__init__( compute_on_step, @@ -87,32 +107,39 @@ def __init__(self, base_metric: Metric, ) if not isinstance(base_metric, Metric): raise ValueError("Expected base metric to be an instance of torchmetrics.Metric" - f" but received {base_metric}") + f" but received {base_metric}") self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_bootstraps)]) self.num_bootstraps = num_bootstraps - + + if generator is not None and not isinstance(generator, torch.Generator): + raise ValueError("Expected argument ``generator`` to be an instance of ``torch.Generator``" + f"but received {generator}") + self.generator = generator + def update(self, *args, **kwargs): """ Updates the state of the base metric. 
Any tensor passed in will be bootstrapped along dimension 0 """ for idx in range(self.num_bootstraps): - args = apply_to_collection(args, torch.Tensor, _bootstrap_sampler) - kwargs = apply_to_collection(kwargs, torch.Tensor, _bootstrap_sampler) - self.metrics[idx].update(*args, **kwargs) + new_args = apply_to_collection(args, torch.Tensor, _bootstrap_sampler, generator=self.generator) + new_kwargs = apply_to_collection(kwargs, torch.Tensor, _bootstrap_sampler, generator=self.generator) + self.metrics[idx].update(*new_args, **new_kwargs) def compute( - self, - mean: bool = True, - std: bool = True, + self, + mean: bool = True, + std: bool = True, quantile: Optional[Union[float, torch.Tensor]] = None, raw: bool = False - ) -> List[torch.Tensor]: + ) -> List[torch.Tensor]: """ Computes the metric value. Args: mean: if `True` return the mean of the bootstraps std: if `True` return the standard deviation of the bootstraps - quantile: if given, returns the quantile of the bootstraps + quantile: + if given, returns the quantile of the bootstraps. Can only be used when pytorch version + 1.6 or higher raw: if `True`, return all bootstrapped values """ computed_vals = torch.stack([m.compute() for m in self.metrics], dim=0) @@ -122,6 +149,8 @@ def compute( if std: output.append(computed_vals.std(dim=0)) if quantile is not None: + if not _TORCH_GREATER_EQUAL_1_7: + raise ValueError('quantile argument can only be used with pytorch v1.6 or higher') output.append(torch.quantile(computed_vals, quantile)) if raw: output.append(computed_vals) From 853614ccbc8f721b7a656c961ba75e23a32dc983 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 18 Mar 2021 10:57:58 +0100 Subject: [PATCH 12/36] move args to init --- torchmetrics/wrappers/bootstrapping.py | 61 ++++++++++++++------------ 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/torchmetrics/wrappers/bootstrapping.py index 43cbfbb5d8b..e0959f97514 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -52,6 +52,10 @@ def __init__( self, base_metric: Metric, num_bootstraps: int = 10, + mean: bool = True, + std: bool = True, + quantile: Optional[Union[float, torch.Tensor]] = None, + raw: bool = False, generator: Optional[torch.Generator] = None, compute_on_step: bool = True, dist_sync_on_step: bool = False, @@ -64,14 +68,20 @@ def __init__( in memory and whenever ``update`` or ``forward`` is called, all input tensors are resampled (with replacement) along the first dimension. - .. note:: Different from all other metrics, bootstrapped metrics has additional - arguments in its ``compute`` method determining what should be returned. - Args: base_metric: base metric class to wrap num_bootstraps: number of copies to make of the base metric for bootstrapping + mean: + if ``True`` return the mean of the bootstraps + std: + if ``True`` return the standard deviation of the bootstraps + quantile: + if given, returns the quantile of the bootstraps.
Can only be used with + pytorch version 1.7 or higher + raw: + if ``True``, return all bootstrapped values generator: A pytorch random number generator for the bootstrap sampler compute_on_step: Forward only calls ``update()`` and return ``None`` if this is set to ``False``. dist_sync_on_step: @@ -89,14 +99,14 @@ def __init__( Example:: >>> from torchmetrics.wrappers import BootStrapper >>> from torchmetrics import Accuracy - >>> _ = torch.manual_seed(0) + >>> generator = torch.manual_seed(0) >>> base_metric = Accuracy() - >>> bootstrap = BootStrapper(base_metric, num_bootstraps=20) + >>> bootstrap = BootStrapper(base_metric, num_bootstraps=20, generator=generator) >>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,))) - >>> output = bootstrap.compute(mean=True, std=True) + >>> output = bootstrap.compute() >>> mean, std = output >>> print(mean, std) - tensor(0.4950) tensor(0.1677) + tensor(0.2175) tensor(0.0950) """ super().__init__( @@ -112,6 +122,13 @@ def __init__( self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_bootstraps)]) self.num_bootstraps = num_bootstraps + self.mean = mean + self.std = std + if quantile is not None and not _TORCH_GREATER_EQUAL_1_7: + raise ValueError('quantile argument can only be used with pytorch v1.6 or higher') + self.quantile = quantile + self.raw = raw + if generator is not None and not isinstance(generator, torch.Generator): raise ValueError("Expected argument ``generator`` to be an instance of ``torch.Generator``" f"but received {generator}") self.generator = generator @@ -126,32 +143,18 @@ def update(self, *args, **kwargs): new_args = apply_to_collection(args, torch.Tensor, _bootstrap_sampler, generator=self.generator) new_kwargs = apply_to_collection(kwargs, torch.Tensor, _bootstrap_sampler, generator=self.generator) self.metrics[idx].update(*new_args, **new_kwargs) - def compute( - self, - mean: bool = True, - std: bool = True, quantile: Optional[Union[float, torch.Tensor]] = None, raw: bool = False - ) -> List[torch.Tensor]: - """ Computes the metric value. - Args: - mean: if `True` return the mean of the bootstraps - std: if `True` return the standard deviation of the bootstraps - quantile: - if given, returns the quantile of the bootstraps. Can only be used when pytorch version - 1.6 or higher - raw: if `True`, return all bootstrapped values + def compute(self) -> List[torch.Tensor]: + """ Computes the bootstrapped metric values.
Allways returns a list of tensors, but the content of + the list depends on how the class was initialized """ computed_vals = torch.stack([m.compute() for m in self.metrics], dim=0) output = [] - if mean: + if self.mean: output.append(computed_vals.mean(dim=0)) - if std: + if self.std: output.append(computed_vals.std(dim=0)) - if quantile is not None: - if not _TORCH_GREATER_EQUAL_1_7: - raise ValueError('quantile argument can only be used with pytorch v1.6 or higher') - output.append(torch.quantile(computed_vals, quantile)) - if raw: + if self.quantile is not None: + output.append(torch.quantile(computed_vals, self.quantile)) + if self.raw: output.append(computed_vals) return output From 58980d6808cfa895840beec24fe2efe508b46011 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 18 Mar 2021 11:03:34 +0100 Subject: [PATCH 13/36] fix tests --- CHANGELOG.md | 2 +- torchmetrics/wrappers/bootstrapping.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md index dbb39334548..44fb37f9353 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114)) -- Added `BootStrapper` to easely calculate confidence intervals for metrics ([]()) +- Added `BootStrapper` to easily calculate confidence intervals for metrics ([#101](https://github.com/PyTorchLightning/metrics/pull/101)) ### Changed diff --git a/torchmetrics/wrappers/bootstrapping.py index e0959f97514..9870517a684 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -134,7 +134,7 @@ def __init__( f"but received {generator}") self.generator = generator - def update(self, *args, **kwargs): + def update(self, *args: Any, **kwargs: Any): """ Updates the state of the base metric.
Any tensor passed in will be bootstrapped along dimension 0 """ From 86de867906041e4e0173244c72380fa7a22b4025 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 18 Mar 2021 11:08:55 +0100 Subject: [PATCH 14/36] fix tests --- tests/wrappers/test_bootstrapping.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/wrappers/test_bootstrapping.py b/tests/wrappers/test_bootstrapping.py index 4e633acc61e..f6b50da177d 100644 --- a/tests/wrappers/test_bootstrapping.py +++ b/tests/wrappers/test_bootstrapping.py @@ -62,7 +62,10 @@ def test_bootstrap_sampler(): ) def test_bootstrap(metric, sk_metric): """ Test that the different bootstraps gets updated as we expected and that the compute method works """ - bootstrapper = TestBootStrapper(metric) + if _TORCH_GREATER_EQUAL_1_7: + bootstrapper = TestBootStrapper(metric, mean=True, std=True, quantile=torch.tensor([0.05, 0.95]), raw=True) + else: + bootstrapper = TestBootStrapper(metric, mean=True, std=True, raw=True) collected_preds = [[] for _ in range(10)] collected_target = [[] for _ in range(10)] @@ -79,18 +82,15 @@ def test_bootstrap(metric, sk_metric): sk_scores = [sk_metric(ct, cp, average='micro') for ct, cp in zip(collected_target, collected_preds)] + output = bootstrapper.compute() # quantile only avaible for pytorch v1.7 and forward if _TORCH_GREATER_EQUAL_1_7: - pl_mean, pl_std, pl_quantile, pl_raw = bootstrapper.compute( - mean=True, std=True, quantile=torch.tensor([0.05, 0.95]), raw=True - ) + pl_mean, pl_std, pl_quantile, pl_raw = output assert np.allclose(pl_quantile[0], np.quantile(sk_scores, 0.05)) assert np.allclose(pl_quantile[1], np.quantile(sk_scores, 0.95)) else: - pl_mean, pl_std, pl_raw = bootstrapper.compute(mean=True, std=True, raw=True) + pl_mean, pl_std, pl_raw = output assert np.allclose(pl_mean, np.mean(sk_scores)) - import pdb - pdb.set_trace() assert np.allclose(pl_std, np.std(sk_scores, ddof=1)) assert np.allclose(pl_raw, sk_scores) From 191509b5f2f0511b7cb5c58c64e2557b1a8149ff Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 18 Mar 2021 11:54:28 +0100 Subject: [PATCH 15/36] mypy --- torchmetrics/wrappers/bootstrapping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py index 9870517a684..5dc6bf193af 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -134,7 +134,7 @@ def __init__( f"but received {generator}") self.generator = generator - def update(self, *args: Any, **kwargs: Any): + def update(self, *args: Any, **kwargs: Any) -> None: """ Updates the state of the base metric. 
Any tensor passed in will be bootstrapped along dimension 0 """ From 3bce9c146be18ddad1a82f0def4a630b2c92a8fd Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 18 Mar 2021 13:34:53 +0100 Subject: [PATCH 16/36] remove pdb --- tests/helpers/testers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 6c6b4eb88dc..6ae98f20652 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -127,8 +127,6 @@ def _class_test( sk_batch_result = sk_metric(ddp_preds, ddp_target) # assert for dist_sync_on_step if check_dist_sync_on_step: - import pdb - pdb.set_trace() _assert_allclose(batch_result, sk_batch_result, atol=atol) else: sk_batch_result = sk_metric(preds[i], target[i]) From 56fd2bcecca5d1bebf29f46794e503d0d203d47b Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 24 Mar 2021 00:12:01 +0100 Subject: [PATCH 17/36] versions --- tests/wrappers/test_bootstrapping.py | 3 +- torchmetrics/utilities/__init__.py | 1 - torchmetrics/utilities/imports.py | 42 +++++++------------ torchmetrics/wrappers/bootstrapping.py | 56 ++++++++++++-------------- 4 files changed, 42 insertions(+), 60 deletions(-) diff --git a/tests/wrappers/test_bootstrapping.py b/tests/wrappers/test_bootstrapping.py index f6b50da177d..7fa29d39694 100644 --- a/tests/wrappers/test_bootstrapping.py +++ b/tests/wrappers/test_bootstrapping.py @@ -17,7 +17,8 @@ from sklearn.metrics import precision_score, recall_score from torchmetrics.classification import Precision, Recall -from torchmetrics.utilities import _TORCH_GREATER_EQUAL_1_7, apply_to_collection +from torchmetrics.utilities import apply_to_collection +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_7 from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler _preds = torch.randint(10, (10, 32)) diff --git a/torchmetrics/utilities/__init__.py b/torchmetrics/utilities/__init__.py index 7b0dfd6d950..dff18c0f389 100644 --- a/torchmetrics/utilities/__init__.py +++ b/torchmetrics/utilities/__init__.py @@ -1,4 +1,3 @@ from torchmetrics.utilities.data import apply_to_collection # noqa: F401 from torchmetrics.utilities.distributed import class_reduce, reduce # noqa: F401 -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_7 # noqa: F401 from torchmetrics.utilities.prints import rank_zero_debug, rank_zero_info, rank_zero_warn # noqa: F401 diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py index ae2c0144b67..5129345ac73 100644 --- a/torchmetrics/utilities/imports.py +++ b/torchmetrics/utilities/imports.py @@ -20,43 +20,29 @@ from pkg_resources import DistributionNotFound -def _module_available(module_path: str) -> bool: - """ - Check if a path is available in your environment - >>> _module_available('os') - True - >>> _module_available('bla.bla') - False +def _compare_version(package: str, op, version) -> bool: """ - try: - return find_spec(module_path) is not None - except AttributeError: - # Python 3.6 - return False - except ModuleNotFoundError: - # Python 3.7+ - return False + Compare package version with some requirements - -def _compare_version(package: str, op, version) -> bool: - """Compare package version with some requirements >>> _compare_version("torch", operator.ge, "0.1") True """ - if not _module_available(package): - return False try: pkg = importlib.import_module(package) - assert hasattr(pkg, '__version__') - pkg_version = pkg.__version__ - return op(pkg_version, 
LooseVersion(version)) - except DistributionNotFound: + except (ModuleNotFoundError, DistributionNotFound): return False + try: + pkg_version = LooseVersion(pkg.__version__) + except AttributeError: + return False + if not (hasattr(pkg_version, "vstring") and hasattr(pkg_version, "version")): + # this is mock by sphinx, so it shall return True ro generate all summaries + return True + return op(pkg_version, LooseVersion(version)) +_TORCH_LOWER_1_4 = _compare_version("torch", operator.lt, "1.4.0") +_TORCH_LOWER_1_5 = _compare_version("torch", operator.lt, "1.5.0") +_TORCH_LOWER_1_6 = _compare_version("torch", operator.lt, "1.6.0") _TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0") _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0") - -_TORCH_LOWER_1_4 = LooseVersion(torch.__version__) < LooseVersion("1.4.0") -_TORCH_LOWER_1_5 = LooseVersion(torch.__version__) < LooseVersion("1.5.0") -_TORCH_LOWER_1_6 = LooseVersion(torch.__version__) < LooseVersion("1.6.0") diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py index 5dc6bf193af..f7b009d5f47 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -18,13 +18,12 @@ from torch import nn from torchmetrics.metric import Metric -from torchmetrics.utilities import _TORCH_GREATER_EQUAL_1_7, apply_to_collection +from torchmetrics.utilities import apply_to_collection +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_7 def _bootstrap_sampler( - tensor: torch.Tensor, - size: Optional[int] = None, - generator: Optional[torch.Generator] = None + tensor: torch.Tensor, size: Optional[int] = None, generator: Optional[torch.Generator] = None ) -> torch.Tensor: """ Resample a tensor along its first dimension with replacement Args: @@ -39,28 +38,26 @@ def _bootstrap_sampler( if size is None: size = tensor.shape[0] idx = torch.multinomial( - torch.ones(tensor.shape[0], device=tensor.device), - num_samples=size, - replacement=True, - generator=generator + torch.ones(tensor.shape[0], device=tensor.device), num_samples=size, replacement=True, generator=generator ) return tensor[idx] class BootStrapper(Metric): + def __init__( - self, - base_metric: Metric, - num_bootstraps: int = 10, - mean: bool = True, - std: bool = True, - quantile: Optional[Union[float, torch.Tensor]] = None, - raw: bool = False, - generator: Optional[torch.Generator] = None, - compute_on_step: bool = True, - dist_sync_on_step: bool = False, - process_group: Optional[Any] = None, - dist_sync_fn: Callable = None + self, + base_metric: Metric, + num_bootstraps: int = 10, + mean: bool = True, + std: bool = True, + quantile: Optional[Union[float, torch.Tensor]] = None, + raw: bool = False, + generator: Optional[torch.Generator] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None ) -> None: """ Use to turn a metric into a bootstrapped metric that can automate the process of getting confidence @@ -109,15 +106,12 @@ def __init__( tensor(0.2175) tensor(0.0950) """ - super().__init__( - compute_on_step, - dist_sync_on_step, - process_group, - dist_sync_fn - ) + super().__init__(compute_on_step, dist_sync_on_step, process_group, dist_sync_fn) if not isinstance(base_metric, Metric): - raise ValueError("Expected base metric to be an instance of torchmetrics.Metric" - f" but received {base_metric}") + raise ValueError( + "Expected base metric to be an instance of 
torchmetrics.Metric" + f" but received {base_metric}" + ) self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_bootstraps)]) self.num_bootstraps = num_bootstraps @@ -130,8 +124,10 @@ def __init__( self.raw = raw if generator is not None and not isinstance(generator, torch.Generator): - raise ValueError("Expected argument ``generator`` to be an instance of ``torch.Generator``" - f"but received {generator}") + raise ValueError( + "Expected argument ``generator`` to be an instance of ``torch.Generator``" + f"but received {generator}" + ) self.generator = generator def update(self, *args: Any, **kwargs: Any) -> None: From e3c2a24d142c897b3d48befefcc1c3fc01ec3c3b Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 24 Mar 2021 00:12:15 +0100 Subject: [PATCH 18/36] versions --- .../classification/test_matthews_corrcoef.py | 1 + tests/classification/test_roc.py | 35 ++++++------------- tests/helpers/testers.py | 4 +-- .../classification/matthews_corrcoef.py | 1 + torchmetrics/classification/roc.py | 1 + .../classification/matthews_corrcoef.py | 9 ++--- 6 files changed, 17 insertions(+), 34 deletions(-) diff --git a/tests/classification/test_matthews_corrcoef.py b/tests/classification/test_matthews_corrcoef.py index 8fcdc2f82dd..ad3ec060d9a 100644 --- a/tests/classification/test_matthews_corrcoef.py +++ b/tests/classification/test_matthews_corrcoef.py @@ -98,6 +98,7 @@ def _sk_matthews_corrcoef_multidim_multiclass(preds, target): (_input_mdmc.preds, _input_mdmc.target, _sk_matthews_corrcoef_multidim_multiclass, NUM_CLASSES)] ) class TestMatthewsCorrCoef(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_matthews_corrcoef(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): diff --git a/tests/classification/test_roc.py b/tests/classification/test_roc.py index 99895e6c855..ef674e05522 100644 --- a/tests/classification/test_roc.py +++ b/tests/classification/test_roc.py @@ -73,38 +73,25 @@ def _sk_roc_multidim_multiclass_prob(preds, target, num_classes=1): def _sk_roc_multilabel_prob(preds, target, num_classes=1): sk_preds = preds.numpy() sk_target = target.numpy() - return _sk_roc_curve( - y_true=sk_target, - probas_pred=sk_preds, - num_classes=num_classes, - multilabel=True - ) + return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, multilabel=True) def _sk_roc_multilabel_multidim_prob(preds, target, num_classes=1): sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - return _sk_roc_curve( - y_true=sk_target, - probas_pred=sk_preds, - num_classes=num_classes, - multilabel=True - ) + return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, multilabel=True) @pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES), - (_input_multilabel_prob.preds, _input_multilabel_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES), - ( - _input_multilabel_multidim_prob.preds, - _input_multilabel_multidim_prob.target, - _sk_roc_multilabel_multidim_prob, - NUM_CLASSES - ) - ] + "preds, target, sk_metric, num_classes", + [(_input_binary_prob.preds, 
_input_binary_prob.target, _sk_roc_binary_prob, 1), + (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES), + (_input_multilabel_prob.preds, _input_multilabel_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES), + ( + _input_multilabel_multidim_prob.preds, _input_multilabel_multidim_prob.target, + _sk_roc_multilabel_multidim_prob, NUM_CLASSES + )] ) class TestROC(MetricTester): diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 6ae98f20652..4bdb4dfadf3 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -108,9 +108,7 @@ def _class_test( metric_args = {} # Instanciate lightning metric metric = metric_class( - compute_on_step=check_dist_sync_on_step or check_batch, - dist_sync_on_step=dist_sync_on_step, - **metric_args + compute_on_step=check_dist_sync_on_step or check_batch, dist_sync_on_step=dist_sync_on_step, **metric_args ) # verify metrics work after being loaded from pickled state diff --git a/torchmetrics/classification/matthews_corrcoef.py b/torchmetrics/classification/matthews_corrcoef.py index f4e84b21841..a240d97ec60 100644 --- a/torchmetrics/classification/matthews_corrcoef.py +++ b/torchmetrics/classification/matthews_corrcoef.py @@ -75,6 +75,7 @@ class MatthewsCorrcoef(Metric): tensor(0.5774) """ + def __init__( self, num_classes: int, diff --git a/torchmetrics/classification/roc.py b/torchmetrics/classification/roc.py index 6cab71d7534..bd93695153e 100644 --- a/torchmetrics/classification/roc.py +++ b/torchmetrics/classification/roc.py @@ -110,6 +110,7 @@ class ROC(Metric): tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])] """ + def __init__( self, num_classes: Optional[int] = None, diff --git a/torchmetrics/functional/classification/matthews_corrcoef.py b/torchmetrics/functional/classification/matthews_corrcoef.py index 91db05a7a40..f4c8052114e 100644 --- a/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/torchmetrics/functional/classification/matthews_corrcoef.py @@ -24,15 +24,10 @@ def _matthews_corrcoef_compute(confmat: Tensor) -> Tensor: pk = confmat.sum(dim=1).float() c = torch.trace(confmat).float() s = confmat.sum().float() - return (c * s - sum(tk * pk)) / (torch.sqrt(s ** 2 - sum(pk * pk)) * torch.sqrt(s ** 2 - sum(tk * tk))) + return (c * s - sum(tk * pk)) / (torch.sqrt(s**2 - sum(pk * pk)) * torch.sqrt(s**2 - sum(tk * tk))) -def matthews_corrcoef( - preds: Tensor, - target: Tensor, - num_classes: int, - threshold: float = 0.5 -) -> Tensor: +def matthews_corrcoef(preds: Tensor, target: Tensor, num_classes: int, threshold: float = 0.5) -> Tensor: r""" Calculates `Matthews correlation coefficient `_ that measures From cf02eba196c01f7dcec59efac70df5a3551e5a7c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 24 Mar 2021 09:59:16 +0100 Subject: [PATCH 19/36] Update docs/source/references/modules.rst Co-authored-by: thomas chaton --- docs/source/references/modules.rst | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 21a13cec812..10846195724 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -275,9 +275,8 @@ R2Score Wrappers ******** -Modular wrapper metrics are not metrics in themself, but instead the take in other metrics and alter +Modular wrapper metrics are not metrics in themself, but instead take a metric and alter the internal logic of 
the base metric. .. autoclass:: torchmetrics.BootStrapper :noindex: - From 9ef1b47b0279e81e6d92cd627103d49c09862db3 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 24 Mar 2021 10:20:03 +0100 Subject: [PATCH 20/36] isort --- torchmetrics/wrappers/bootstrapping.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py index b60eae1599d..673d675ceb0 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -15,8 +15,7 @@ from typing import Any, Callable, List, Optional, Union import torch -from torch import nn -from torch import Tensor +from torch import Tensor, nn from torchmetrics.metric import Metric from torchmetrics.utilities import apply_to_collection From 557ae6abcd66cc91eb569fdfe559f763277b6599 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 24 Mar 2021 10:57:07 +0100 Subject: [PATCH 21/36] Apply suggestions from code review --- torchmetrics/utilities/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchmetrics/utilities/__init__.py b/torchmetrics/utilities/__init__.py index 7b0dfd6d950..dff18c0f389 100644 --- a/torchmetrics/utilities/__init__.py +++ b/torchmetrics/utilities/__init__.py @@ -1,4 +1,3 @@ from torchmetrics.utilities.data import apply_to_collection # noqa: F401 from torchmetrics.utilities.distributed import class_reduce, reduce # noqa: F401 -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_7 # noqa: F401 from torchmetrics.utilities.prints import rank_zero_debug, rank_zero_info, rank_zero_warn # noqa: F401 From fa149f2b7caa9471907be0cac74bba56347b9772 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 24 Mar 2021 11:13:28 +0100 Subject: [PATCH 22/36] Update torchmetrics/wrappers/bootstrapping.py Co-authored-by: Jirka Borovec --- torchmetrics/wrappers/bootstrapping.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py index 673d675ceb0..9e7a6429f0a 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -96,8 +96,7 @@ def __init__( will be used to perform the allgather. 
Example:: - >>> from torchmetrics.wrappers import BootStrapper - >>> from torchmetrics import Accuracy + >>> from torchmetrics import Accuracy, BootStrapper >>> generator = torch.manual_seed(0) >>> base_metric = Accuracy() >>> bootstrap = BootStrapper(base_metric, num_bootstraps=20, generator=generator) From c5364d0601971facfe25fce726b85c345f62b82c Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 24 Mar 2021 20:26:52 +0100 Subject: [PATCH 23/36] update --- tests/classification/test_roc.py | 23 +++++++++++------------ tests/helpers/testers.py | 2 +- torchmetrics/wrappers/bootstrapping.py | 16 +++++++--------- 3 files changed, 19 insertions(+), 22 deletions(-) diff --git a/tests/classification/test_roc.py b/tests/classification/test_roc.py index e47536acbf3..ef674e05522 100644 --- a/tests/classification/test_roc.py +++ b/tests/classification/test_roc.py @@ -82,18 +82,17 @@ def _sk_roc_multilabel_multidim_prob(preds, target, num_classes=1): return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, multilabel=True) -@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES), - (_input_multilabel_prob.preds, _input_multilabel_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES), - ( - _input_multilabel_multidim_prob.preds, - _input_multilabel_multidim_prob.target, - _sk_roc_multilabel_multidim_prob, - NUM_CLASSES - ) -]) +@pytest.mark.parametrize( + "preds, target, sk_metric, num_classes", + [(_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1), + (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES), + (_input_multilabel_prob.preds, _input_multilabel_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES), + ( + _input_multilabel_multidim_prob.preds, _input_multilabel_multidim_prob.target, + _sk_roc_multilabel_multidim_prob, NUM_CLASSES + )] +) class TestROC(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 6ae98f20652..62cadd5d95e 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -110,7 +110,7 @@ def _class_test( metric = metric_class( compute_on_step=check_dist_sync_on_step or check_batch, dist_sync_on_step=dist_sync_on_step, - **metric_args + **metric_args, ) # verify metrics work after being loaded from pickled state diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py index 9e7a6429f0a..1e481d3ef25 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -41,12 +41,13 @@ def _bootstrap_sampler( torch.ones(tensor.shape[0], device=tensor.device), num_samples=size, replacement=True, - generator=generator + generator=generator, ) return tensor[idx] class BootStrapper(Metric): + def __init__( self, base_metric: Metric, @@ -107,15 +108,12 @@ def __init__( tensor(0.2175) tensor(0.0950) """ - super().__init__( - compute_on_step, - dist_sync_on_step, - process_group, - dist_sync_fn - ) + super().__init__(compute_on_step, dist_sync_on_step, process_group, dist_sync_fn) if not isinstance(base_metric, Metric): - raise ValueError("Expected base metric to be 
an instance of torchmetrics.Metric" - f" but received {base_metric}") + raise ValueError( + "Expected base metric to be an instance of torchmetrics.Metric" + f" but received {base_metric}" + ) self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_bootstraps)]) self.num_bootstraps = num_bootstraps From a3e9b40b72c86b7ad7e44a9937a7345f2015cbfc Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 24 Mar 2021 20:28:19 +0100 Subject: [PATCH 24/36] update --- torchmetrics/wrappers/bootstrapping.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py index 1e481d3ef25..d88cd5b6a2b 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -133,9 +133,7 @@ def __init__( self.generator = generator def update(self, *args: Any, **kwargs: Any) -> None: - """ Updates the state of the base metric. Any tensor passed in will be bootstrapped - along dimension 0 - """ + """ Updates the state of the base metric. Any tensor passed in will be bootstrapped along dimension 0 """ for idx in range(self.num_bootstraps): new_args = apply_to_collection(args, Tensor, _bootstrap_sampler, generator=self.generator) new_kwargs = apply_to_collection(kwargs, Tensor, _bootstrap_sampler, generator=self.generator) From 7b211eac0d76d35f4f234dd1e69a2971b3db6d96 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 25 Mar 2021 11:06:55 +0100 Subject: [PATCH 25/36] add poisson --- tests/wrappers/test_bootstrapping.py | 20 +++++++-- torchmetrics/wrappers/bootstrapping.py | 60 +++++++++++++++++++------- 2 files changed, 61 insertions(+), 19 deletions(-) diff --git a/tests/wrappers/test_bootstrapping.py b/tests/wrappers/test_bootstrapping.py index 7fa29d39694..2759ccf91d8 100644 --- a/tests/wrappers/test_bootstrapping.py +++ b/tests/wrappers/test_bootstrapping.py @@ -39,12 +39,13 @@ def update(self, *args, **kwargs): self.out.append(new_args) -def test_bootstrap_sampler(): +@pytest.mark.parametrize("sampling_strategy", ['poisson', 'multinomial']) +def test_bootstrap_sampler(sampling_strategy): """ make sure that the bootstrap sampler works as intended """ old_samples = torch.randn(5, 2) # make sure that the new samples are only made up of old samples - new_samples = _bootstrap_sampler(old_samples) + new_samples = _bootstrap_sampler(old_samples, sampling_strategy=sampling_strategy) for ns in new_samples: assert ns in old_samples @@ -52,11 +53,22 @@ def test_bootstrap_sampler(): found_one = False for os in old_samples: cond = os == new_samples - print(cond.sum()) if cond.sum() > 2: found_one = True + break + assert found_one, "resampling did not work because no samples were sampled twice" - + + # make sure some samples are never sampled + found_zero = False + for os in old_samples: + cond = os != new_samples + if cond.sum() > 0: + found_zero = True + break + + assert found_zero, "resampling did not work because all samples were atleast sampled once" + @pytest.mark.parametrize( "metric, sk_metric", [[Precision(average='micro'), precision_score], [Recall(average='micro'), recall_score]] diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py index 673d675ceb0..b450ea382a7 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -23,27 +23,38 @@ def _bootstrap_sampler( - tensor: Tensor, size: Optional[int] = None, generator: Optional[torch.Generator] = None + tensor: Tensor, + size: Optional[int] = None, 
+ generator: Optional[torch.Generator] = None, + sampling_strategy: str = 'poisson' ) -> Tensor: """ Resample a tensor along its first dimension with replacement Args: tensor: tensor to resample - size: number of samples in new tensor. Defauls to same size as input tensor + size: number of samples in new tensor. Defaults to same size as input tensor. Only applies when + sampling strategy is ``'multinomial'`` generator: a instance of ``torch.Generator`` that controls the sampling + sampling_strategy: the strategy to use for sampling, either ``'poisson'`` or ``'multinomial'`` Returns: resampled tensor """ - if size is None: - size = tensor.shape[0] - idx = torch.multinomial( - torch.ones(tensor.shape[0], device=tensor.device), - num_samples=size, - replacement=True, - generator=generator - ) - return tensor[idx] + if sampling_strategy == 'poisson': + p = torch.distributions.Poisson(1) + n = p.sample((tensor.shape[0],)) + return tensor.repeat_interleave(n.long(), dim=0) + elif sampling_strategy == 'multinomial': + if size is None: + size = tensor.shape[0] + idx = torch.multinomial( + torch.ones(tensor.shape[0], device=tensor.device), + num_samples=size, + replacement=True, + generator=generator + ) + return tensor[idx] + raise ValueError('Unknown sampling strategy') class BootStrapper(Metric): @@ -55,13 +66,14 @@ def __init__( std: bool = True, quantile: Optional[Union[float, Tensor]] = None, raw: bool = False, + sampling_strategy: str = 'poisson', generator: Optional[torch.Generator] = None, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None ) -> None: - """ + r""" Use to turn a metric into a bootstrapped metric that can automate the process of getting confidence intervals for metric values. This wrapper class basically keeps multiple copies of the same base metric in memory and whenever ``update`` or ``forward`` is called, all input tensors are resampled (with replacement) along the first dimension. @@ -81,6 +93,12 @@ def __init__( pytorch version 1.7 or higher raw: if ``True``, return all bootstrapped values + sampling_strategy: + Determines how to produce bootstrapped samplings. Either ``'poisson'`` or ``'multinomial'``. + If ``'poisson'`` is chosen, the number of times each sample will be included in the bootstrap + will be given by :math:`n~Poisson(1)`, which approximates the true bootstrap distribution when + the number of samples is large. If ``'multinomial'`` is chosen, we will apply true bootstrapping + at the batch level to approximate bootstrapping over the whole dataset.
generator: A pytorch random number generator for the bootstrap sampler compute_on_step: @@ -124,10 +142,18 @@ def __init__( self.mean = mean self.std = std if quantile is not None and not _TORCH_GREATER_EQUAL_1_7: - raise ValueError('quantile argument can only be used with pytorch v1.6 or higher') + raise ValueError('quantile argument can only be used with pytorch v1.7 or higher') self.quantile = quantile self.raw = raw + allowed_sampling = ('poisson', 'multinomial') + if sampling_strategy not in allowed_sampling: + raise ValueError( + f"Expected argument ``sampling_strategy`` to be one of {allowed_sampling}" + f" but recieved {sampling_strategy}" + ) + self.sampling_strategy = sampling_strategy + if generator is not None and not isinstance(generator, torch.Generator): raise ValueError( "Expected argument ``generator`` to be an instance of ``torch.Generator``" @@ -140,8 +166,12 @@ def update(self, *args: Any, **kwargs: Any) -> None: along dimension 0 """ for idx in range(self.num_bootstraps): - new_args = apply_to_collection(args, Tensor, _bootstrap_sampler, generator=self.generator) - new_kwargs = apply_to_collection(kwargs, Tensor, _bootstrap_sampler, generator=self.generator) + new_args = apply_to_collection( + args, Tensor, _bootstrap_sampler, generator=self.generator, sampling_strategy=self.sampling_strategy + ) + new_kwargs = apply_to_collection( + kwargs, Tensor, _bootstrap_sampler, generator=self.generator, sampling_strategy=self.sampling_strategy + ) self.metrics[idx].update(*new_args, **new_kwargs) def compute(self) -> List[Tensor]: From e80f0ce363ec8bdaaf939e91f58cf0aab5ab1999 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 25 Mar 2021 15:10:57 +0100 Subject: [PATCH 26/36] pep8 --- tests/classification/test_roc.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/classification/test_roc.py b/tests/classification/test_roc.py index ef674e05522..c9b95809dde 100644 --- a/tests/classification/test_roc.py +++ b/tests/classification/test_roc.py @@ -82,17 +82,16 @@ def _sk_roc_multilabel_multidim_prob(preds, target, num_classes=1): return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, multilabel=True) -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", - [(_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES), - (_input_multilabel_prob.preds, _input_multilabel_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES), - ( - _input_multilabel_multidim_prob.preds, _input_multilabel_multidim_prob.target, - _sk_roc_multilabel_multidim_prob, NUM_CLASSES - )] -) +@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [ + (_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1), + (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES), + (_input_multilabel_prob.preds, _input_multilabel_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES), + ( + _input_multilabel_multidim_prob.preds, _input_multilabel_multidim_prob.target, + _sk_roc_multilabel_multidim_prob, NUM_CLASSES + ) +]) class TestROC(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) From cbf8a67d8da9337046df3c639e9e2c38270d191c Mon Sep 17 00:00:00 2001 From: 
Nicki Skafte Date: Thu, 25 Mar 2021 15:18:06 +0100 Subject: [PATCH 27/36] revert --- tests/classification/test_roc.py | 38 ++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/tests/classification/test_roc.py b/tests/classification/test_roc.py index c9b95809dde..99895e6c855 100644 --- a/tests/classification/test_roc.py +++ b/tests/classification/test_roc.py @@ -73,25 +73,39 @@ def _sk_roc_multidim_multiclass_prob(preds, target, num_classes=1): def _sk_roc_multilabel_prob(preds, target, num_classes=1): sk_preds = preds.numpy() sk_target = target.numpy() - return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, multilabel=True) + return _sk_roc_curve( + y_true=sk_target, + probas_pred=sk_preds, + num_classes=num_classes, + multilabel=True + ) def _sk_roc_multilabel_multidim_prob(preds, target, num_classes=1): sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, multilabel=True) + return _sk_roc_curve( + y_true=sk_target, + probas_pred=sk_preds, + num_classes=num_classes, + multilabel=True + ) -@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES), - (_input_multilabel_prob.preds, _input_multilabel_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES), - ( - _input_multilabel_multidim_prob.preds, _input_multilabel_multidim_prob.target, - _sk_roc_multilabel_multidim_prob, NUM_CLASSES - ) -]) +@pytest.mark.parametrize( + "preds, target, sk_metric, num_classes", [ + (_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1), + (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES), + (_input_multilabel_prob.preds, _input_multilabel_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES), + ( + _input_multilabel_multidim_prob.preds, + _input_multilabel_multidim_prob.target, + _sk_roc_multilabel_multidim_prob, + NUM_CLASSES + ) + ] +) class TestROC(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) From b0cb0d77d290c03e57fdefb2a49786d8ac945ad0 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 25 Mar 2021 17:22:32 +0100 Subject: [PATCH 28/36] link --- torchmetrics/wrappers/bootstrapping.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py index 0d4ce2dc7ca..c4bba0877a4 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -67,10 +67,10 @@ def __init__( dist_sync_fn: Callable = None ) -> None: r""" - Use to turn a metric into a bootstrapped metric that can automate the process of getting confidence - intervals for metric values. This wrapper class basically keeps multiple copies of the same base metric - in memory and whenever ``update`` or ``forward`` is called, all input tensors are resampled - (with replacement) along the first dimension. 
+ Use to turn a metric into a `bootstrapped <https://en.wikipedia.org/wiki/Bootstrapping_(statistics)>`_ + metric that can automate the process of getting confidence intervals for metric values. This wrapper + class basically keeps multiple copies of the same base metric in memory and whenever ``update`` or + ``forward`` is called, all input tensors are resampled (with replacement) along the first dimension. Args: base_metric: @@ -89,8 +89,8 @@ def __init__( sampling_strategy: Determines how to produce bootstrapped samplings. Either ``'poisson'`` or ``'multinomial'``. If ``'poisson'`` is chosen, the number of times each sample will be included in the bootstrap - will be given by :math:`n~Poisson(1)`, which approximates the true bootstrap distribution when - the number of samples is large. If ``'multinomial'`` is chosen, we will apply true bootstrapping + will be given by :math:`n\sim Poisson(\lambda=1)`, which approximates the true bootstrap distribution + when the number of samples is large. If ``'multinomial'`` is chosen, we will apply true bootstrapping at the batch level to approximate bootstrapping over the whole dataset. compute_on_step: Forward only calls ``update()`` and return ``None`` if this is set to ``False``. @@ -106,13 +106,14 @@ def __init__( Example:: >>> from torchmetrics import Accuracy, BootStrapper + >>> _ = torch.manual_seed(123) >>> base_metric = Accuracy() >>> bootstrap = BootStrapper(base_metric, num_bootstraps=20) >>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,))) >>> output = bootstrap.compute() >>> mean, std = output >>> print(mean, std) - tensor(0.2175) tensor(0.0950) + tensor(0.2205) tensor(0.0859) From e7922ac6be84aef50b496bd6d689b094aea835fa Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 25 Mar 2021 17:23:47 +0100 Subject: [PATCH 29/36] isort --- tests/wrappers/test_bootstrapping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/wrappers/test_bootstrapping.py index 5cd95dbee39..e3eff658bd4 100644 --- a/tests/wrappers/test_bootstrapping.py +++ b/tests/wrappers/test_bootstrapping.py @@ -14,8 +14,8 @@ import numpy as np import pytest import torch -from torch import Tensor from sklearn.metrics import precision_score, recall_score +from torch import Tensor from torchmetrics.classification import Precision, Recall from torchmetrics.utilities import apply_to_collection from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler From 6b7ebf8da5ddf385684d8b3f951dc4395f0b9ac0 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 25 Mar 2021 17:26:21 +0100 Subject: [PATCH 30/36] roc changes remove --- tests/classification/test_roc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/classification/test_roc.py index caa7d5f00fe..99895e6c855 100644 --- a/tests/classification/test_roc.py +++ b/tests/classification/test_roc.py @@ -22,8 +22,7 @@ from tests.classification.inputs import _input_binary_prob from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob -from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from tests.classification.inputs import _input_multilabel_multidim_prob, _input_multilabel_prob from tests.helpers.testers import NUM_CLASSES, MetricTester from torchmetrics.classification.roc import ROC from
torchmetrics.functional import roc From ed825d5aab103260f0ed4856758dcc4e0f20fd1e Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 25 Mar 2021 17:28:13 +0100 Subject: [PATCH 31/36] fix --- tests/classification/test_roc.py | 35 +++++++++----------------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/tests/classification/test_roc.py b/tests/classification/test_roc.py index 99895e6c855..a8f21b7dad2 100644 --- a/tests/classification/test_roc.py +++ b/tests/classification/test_roc.py @@ -22,7 +22,8 @@ from tests.classification.inputs import _input_binary_prob from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.classification.inputs import _input_multilabel_multidim_prob, _input_multilabel_prob +from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob +from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers.testers import NUM_CLASSES, MetricTester from torchmetrics.classification.roc import ROC from torchmetrics.functional import roc @@ -73,38 +74,22 @@ def _sk_roc_multidim_multiclass_prob(preds, target, num_classes=1): def _sk_roc_multilabel_prob(preds, target, num_classes=1): sk_preds = preds.numpy() sk_target = target.numpy() - return _sk_roc_curve( - y_true=sk_target, - probas_pred=sk_preds, - num_classes=num_classes, - multilabel=True - ) + return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, multilabel=True) def _sk_roc_multilabel_multidim_prob(preds, target, num_classes=1): sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - return _sk_roc_curve( - y_true=sk_target, - probas_pred=sk_preds, - num_classes=num_classes, - multilabel=True - ) + return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, multilabel=True) @pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES), - (_input_multilabel_prob.preds, _input_multilabel_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES), - ( - _input_multilabel_multidim_prob.preds, - _input_multilabel_multidim_prob.target, - _sk_roc_multilabel_multidim_prob, - NUM_CLASSES - ) - ] + "preds, target, sk_metric, num_classes", + [(_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1), + (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES), + (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES), + (_input_mlmd_prob.preds, _input_mlmd_prob.target, _sk_roc_multilabel_multidim_prob, NUM_CLASSES)] ) class TestROC(MetricTester): From 944a1310bb5ca412c5fc5cf7f4dffb5295333075 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 26 Mar 2021 09:52:40 +0100 Subject: [PATCH 32/36] fix tests --- torchmetrics/utilities/imports.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py index b4780076c36..670fbed5714 100644 --- 
a/torchmetrics/utilities/imports.py +++ b/torchmetrics/utilities/imports.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Import utilities""" +import operator from distutils.version import LooseVersion from importlib import import_module from importlib.util import find_spec From cf527e90882b23dac826e3980a59d49d07578591 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 26 Mar 2021 09:56:01 +0100 Subject: [PATCH 33/36] pep8 --- torchmetrics/utilities/imports.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py index 670fbed5714..b412e7f1914 100644 --- a/torchmetrics/utilities/imports.py +++ b/torchmetrics/utilities/imports.py @@ -17,7 +17,7 @@ from importlib import import_module from importlib.util import find_spec -import torch +import torch # noqa: F401 from pkg_resources import DistributionNotFound @@ -61,7 +61,7 @@ def _compare_version(package: str, op, version) -> bool: return True return op(pkg_version, LooseVersion(version)) - + _TORCH_LOWER_1_4 = _compare_version("torch", operator.lt, "1.4.0") _TORCH_LOWER_1_5 = _compare_version("torch", operator.lt, "1.5.0") _TORCH_LOWER_1_6 = _compare_version("torch", operator.lt, "1.6.0") From 0987e8cbd3e7859950be79f96416b34311fb8c76 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 26 Mar 2021 10:24:29 +0100 Subject: [PATCH 34/36] Apply suggestions from code review --- tests/wrappers/test_bootstrapping.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/wrappers/test_bootstrapping.py b/tests/wrappers/test_bootstrapping.py index e3eff658bd4..20a0a8350e2 100644 --- a/tests/wrappers/test_bootstrapping.py +++ b/tests/wrappers/test_bootstrapping.py @@ -80,9 +80,9 @@ def test_bootstrap(sampling_strategy, metric, sk_metric): """ Test that the different bootstraps gets updated as we expected and that the compute method works """ _kwargs = {'base_metric': metric, 'mean': True, 'std': True, 'raw': True, 'sampling_strategy': sampling_strategy} if _TORCH_GREATER_EQUAL_1_7: - bootstrapper = TestBootStrapper(**_kwargs, quantile=torch.tensor([0.05, 0.95])) - else: - bootstrapper = TestBootStrapper(**_kwargs) + _kwargs.update(dict(quantile=torch.tensor([0.05, 0.95]))) + + bootstrapper = TestBootStrapper(**_kwargs) collected_preds = [[] for _ in range(10)] collected_target = [[] for _ in range(10)] From 7997a03ad0db485bbe200132d5de37925dd56a1a Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 26 Mar 2021 11:35:19 +0100 Subject: [PATCH 35/36] suggestions --- tests/wrappers/test_bootstrapping.py | 45 ++++++++++++-------------- torchmetrics/wrappers/bootstrapping.py | 25 +++++++------- 2 files changed, 32 insertions(+), 38 deletions(-) diff --git a/tests/wrappers/test_bootstrapping.py b/tests/wrappers/test_bootstrapping.py index 20a0a8350e2..afea77d4267 100644 --- a/tests/wrappers/test_bootstrapping.py +++ b/tests/wrappers/test_bootstrapping.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import operator + import numpy as np import pytest import torch @@ -40,6 +42,16 @@ def update(self, *args) -> None: self.out.append(new_args) +def _sample_checker(old_samples, new_samples, op: operator, threshold: int): + found_one = False + for os in old_samples: + cond = op(os, new_samples) + if cond.sum() > threshold: + found_one = True + break + return found_one + + @pytest.mark.parametrize("sampling_strategy", ['poisson', 'multinomial']) def test_bootstrap_sampler(sampling_strategy): """ make sure that the bootstrap sampler works as intended """ @@ -51,24 +63,10 @@ def test_bootstrap_sampler(sampling_strategy): for ns in new_samples: assert ns in old_samples - # make sure some samples are also sampled twice - found_one = False - for os in old_samples: - cond = os == new_samples - if cond.sum() > 2: - found_one = True - break - + found_one = _sample_checker(old_samples, new_samples, operator.eq, 2) assert found_one, "resampling did not work because no samples were sampled twice" - # make sure some samples are never sampled - found_zero = False - for os in old_samples: - cond = os != new_samples - if cond.sum() > 0: - found_zero = True - break - + found_zero = _sample_checker(old_samples, new_samples, operator.ne, 0) assert found_zero, "resampling did not work because all samples were atleast sampled once" @@ -102,12 +100,9 @@ def test_bootstrap(sampling_strategy, metric, sk_metric): output = bootstrapper.compute() # quantile only avaible for pytorch v1.7 and forward if _TORCH_GREATER_EQUAL_1_7: - pl_mean, pl_std, pl_quantile, pl_raw = output - assert np.allclose(pl_quantile[0], np.quantile(sk_scores, 0.05)) - assert np.allclose(pl_quantile[1], np.quantile(sk_scores, 0.95)) - else: - pl_mean, pl_std, pl_raw = output - - assert np.allclose(pl_mean, np.mean(sk_scores)) - assert np.allclose(pl_std, np.std(sk_scores, ddof=1)) - assert np.allclose(pl_raw, sk_scores) + assert np.allclose(output['quantile'][0], np.quantile(sk_scores, 0.05)) + assert np.allclose(output['quantile'][1], np.quantile(sk_scores, 0.95)) + + assert np.allclose(output['mean'], np.mean(sk_scores)) + assert np.allclose(output['std'], np.std(sk_scores, ddof=1)) + assert np.allclose(output['raw'], sk_scores) diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py index c4bba0877a4..759bf5f772e 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import torch from torch import Tensor, nn @@ -111,9 +111,8 @@ class basically keeps multiple copies of the same base metric in memory and when >>> bootstrap = BootStrapper(base_metric, num_bootstraps=20) >>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,))) >>> output = bootstrap.compute() - >>> mean, std = output - >>> print(mean, std) - tensor(0.2205) tensor(0.0859) + >>> print(output) + {'mean': tensor(0.2205), 'std': tensor(0.0859)} """ super().__init__(compute_on_step, dist_sync_on_step, process_group, dist_sync_fn) @@ -157,18 +156,18 @@ def update(self, *args: Any, **kwargs: Any) -> None: new_kwargs = apply_to_collection(kwargs, Tensor, torch.index_select, dim=0, index=sample_idx) self.metrics[idx].update(*new_args, **new_kwargs) - def compute(self) -> List[Tensor]: - """ Computes the bootstrapped metric values. 
Allways returns a list of tensors, but the content of - the list depends on how the class was initialized + def compute(self) -> Dict[str, Tensor]: + """ Computes the bootstrapped metric values. Allways returns a dict of tensors, which can contain the + following keys: ``mean``, ``std``, ``quantile`` and ``raw`` depending on how the class was initialized """ computed_vals = torch.stack([m.compute() for m in self.metrics], dim=0) - output = [] + output_dict = {} if self.mean: - output.append(computed_vals.mean(dim=0)) + output_dict['mean'] = computed_vals.mean(dim=0) if self.std: - output.append(computed_vals.std(dim=0)) + output_dict['std'] = computed_vals.std(dim=0) if self.quantile is not None: - output.append(torch.quantile(computed_vals, self.quantile)) + output_dict['quantile'] = torch.quantile(computed_vals, self.quantile) if self.raw: - output.append(computed_vals) - return output + output_dict['raw'] = computed_vals + return output_dict From 03024d87e8b1e38809ef9fd473dcc8c7bfbde61d Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sat, 27 Mar 2021 09:11:41 +0100 Subject: [PATCH 36/36] pprint --- torchmetrics/collections.py | 9 +++++---- torchmetrics/wrappers/bootstrapping.py | 3 ++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py index cda4dacaa4d..224da83df88 100644 --- a/torchmetrics/collections.py +++ b/torchmetrics/collections.py @@ -49,6 +49,7 @@ class MetricCollection(nn.ModuleDict): Example: >>> # input as list >>> import torch + >>> from pprint import pprint >>> from torchmetrics import MetricCollection, Accuracy, Precision, Recall >>> target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) >>> preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) @@ -62,10 +63,10 @@ class MetricCollection(nn.ModuleDict): >>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'), ... 'macro_recall': Recall(num_classes=3, average='macro')}) >>> same_metric = metrics.clone() - >>> metrics(preds, target) - {'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)} - >>> same_metric(preds, target) - {'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)} + >>> pprint(metrics(preds, target)) + {'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)} + >>> pprint(same_metric(preds, target)) + {'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)} >>> metrics.persistent() """ diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py index 759bf5f772e..1bd8be2040b 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -105,13 +105,14 @@ class basically keeps multiple copies of the same base metric in memory and when will be used to perform the allgather. Example:: + >>> from pprint import pprint >>> from torchmetrics import Accuracy, BootStrapper >>> _ = torch.manual_seed(123) >>> base_metric = Accuracy() >>> bootstrap = BootStrapper(base_metric, num_bootstraps=20) >>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,))) >>> output = bootstrap.compute() - >>> print(output) + >>> pprint(output) {'mean': tensor(0.2205), 'std': tensor(0.0859)} """