From 262f57cf7d1b4a5bc8df4bbcb850c4f5ff383f95 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Tue, 22 Mar 2022 13:55:29 +0100
Subject: [PATCH 01/10] test

---
 tests/bases/test_metric.py | 41 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 41 insertions(+)

diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py
index bad93b2b0d7..de94cfc0b92 100644
--- a/tests/bases/test_metric.py
+++ b/tests/bases/test_metric.py
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import pickle
 from collections import OrderedDict
 
 import cloudpickle
 import numpy as np
+import psutil
 import pytest
 import torch
 from torch import Tensor, nn, tensor
@@ -362,3 +364,42 @@ def device(self):
     assert module.device == module.metric.device
     if isinstance(module.metric.x, Tensor):
         assert module.device == module.metric.x.device
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("requires_grad", [True, False])
+def test_constant_memory(device, requires_grad):
+    """Checks that when updating a metric the memory does not increase."""
+    if not torch.cuda.is_available() and device == "cuda":
+        pytest.mark.skip("Test requires GPU support")
+
+    def get_memory_usage():
+        if device == "cpu":
+            pid = os.getpid()
+            py = psutil.Process(pid)
+            return py.memory_info()[0] / 2.0 ** 30
+        else:
+            return torch.cuda.memory_allocated()
+
+    x = torch.randn(10, requires_grad=requires_grad, device=device)
+
+    # try update method
+    metric = DummyMetricSum().to(device)
+
+    metric.update(x.sum())
+    base_memory_level = get_memory_usage()
+
+    for _ in range(10):
+        metric.update(x.sum())
+        memory = get_memory_usage()
+        assert base_memory_level >= memory
+
+    # try forward method
+    metric = DummyMetricSum().cuda()
+    metric(x.sum())
+    base_memory_level = get_memory_usage()
+
+    for _ in range(10):
+        metric.update(x.sum())
+        memory = get_memory_usage()
+        assert base_memory_level >= memory

From 8570a959b1b6ed93a41e9446214ca9c66b155355 Mon Sep 17 00:00:00 2001
From: Jirka
Date: Wed, 23 Mar 2022 06:05:55 +0100
Subject: [PATCH 02/10] psutil

---
 requirements/test.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements/test.txt b/requirements/test.txt
index 2ed4e84ebd0..dfffc35433d 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -10,6 +10,7 @@ mypy>=0.790
 phmdoctest>=1.1.1
 pre-commit>=1.0
 
+psutil
 requests
 fire
 

From 91db4402e5762d1c0f4e08b3a9db56e17c524a44 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Wed, 23 Mar 2022 08:51:25 +0100
Subject: [PATCH 03/10] add psutil

---
 requirements/test.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements/test.txt b/requirements/test.txt
index 2ed4e84ebd0..cdbfaabdfda 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -15,3 +15,4 @@ fire
 
 cloudpickle>=1.3
 scikit-learn>=0.24
+psutil

From 1f08fd2df3e12d7bdc0e87e76c864608c891dc28 Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Wed, 23 Mar 2022 08:52:53 +0100
Subject: [PATCH 04/10] Update requirements/test.txt

---
 requirements/test.txt | 1 -
 1 file changed, 1 deletion(-)

diff --git a/requirements/test.txt b/requirements/test.txt
index eb46e5840ab..dfffc35433d 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -16,4 +16,3 @@ fire
 
 cloudpickle>=1.3
 scikit-learn>=0.24
-psutil
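Note on PATCH 01 above: ``test_constant_memory`` relies on a ``DummyMetricSum`` helper that is defined elsewhere in the test suite and does not appear in this series. A minimal sketch of what such a metric plausibly looks like (the class body, default value and reduction function below are assumptions, not taken from these patches):

    import torch
    from torchmetrics import Metric


    class DummyMetricSum(Metric):
        """Toy metric that keeps a single scalar sum as its state (illustrative sketch)."""

        def __init__(self):
            super().__init__()
            # scalar state, reduced with a sum when syncing across processes
            self.add_state("x", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, x):
            # accumulate the incoming value into the state tensor
            self.x += x

        def compute(self):
            # return the accumulated sum
            return self.x

Because ``update`` adds the incoming tensor into a state tensor, updating with a value that requires grad would otherwise keep the autograd graph of every batch alive, which is exactly the memory growth the new test guards against.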
From e2b3d5a7dec6d5df56d49eaeb30cd31d6963b926 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Wed, 23 Mar 2022 09:16:12 +0100
Subject: [PATCH 05/10] fix skipping

---
 tests/bases/test_metric.py | 2 +-
 tests/retrieval/helpers.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py
index de94cfc0b92..1e0fb5593c9 100644
--- a/tests/bases/test_metric.py
+++ b/tests/bases/test_metric.py
@@ -371,7 +371,7 @@ def device(self):
 def test_constant_memory(device, requires_grad):
     """Checks that when updating a metric the memory does not increase."""
     if not torch.cuda.is_available() and device == "cuda":
-        pytest.mark.skip("Test requires GPU support")
+        pytest.skip("Test requires GPU support")
 
     def get_memory_usage():
         if device == "cpu":
diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py
index 37bd83c5194..3d04cb4b98f 100644
--- a/tests/retrieval/helpers.py
+++ b/tests/retrieval/helpers.py
@@ -485,7 +485,7 @@ def run_precision_test_gpu(
         metric_functional: Callable,
     ):
         if not torch.cuda.is_available():
-            pytest.skip()
+            pytest.skip("Test requires GPU")
 
         def metric_functional_ignore_indexes(preds, target, indexes):
             return metric_functional(preds, target)

From ddd0fc5d2d75184fb888d058e31bc581f4884d5e Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Wed, 23 Mar 2022 09:36:21 +0100
Subject: [PATCH 06/10] fix test

---
 tests/bases/test_metric.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py
index 1e0fb5593c9..d3a6c054b87 100644
--- a/tests/bases/test_metric.py
+++ b/tests/bases/test_metric.py
@@ -392,14 +392,14 @@ def get_memory_usage():
     for _ in range(10):
         metric.update(x.sum())
         memory = get_memory_usage()
-        assert base_memory_level >= memory
+        assert base_memory_level >= memory, "memory increased above base level"
 
     # try forward method
-    metric = DummyMetricSum().cuda()
+    metric = DummyMetricSum().to(device)
     metric(x.sum())
     base_memory_level = get_memory_usage()
 
     for _ in range(10):
         metric.update(x.sum())
         memory = get_memory_usage()
-        assert base_memory_level >= memory
+        assert base_memory_level >= memory, "memory increased above base level"

From 669da1a487fbfe5fd3124bd2ce853e92993d77a5 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Wed, 23 Mar 2022 10:42:45 +0100
Subject: [PATCH 07/10] fix

---
 torchmetrics/metric.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index 7508f438261..ee6911ac13a 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -144,6 +144,7 @@ def __init__(
         self._update_called = False
         self._to_sync = True
         self._should_unsync = True
+        self._enable_grad = False
 
         # initialize state
         self._defaults: Dict[str, Union[List, Tensor]] = {}
@@ -247,6 +248,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         cache = {attr: getattr(self, attr) for attr in self._defaults}
 
         # call reset, update, compute, on single batch
+        self._enable_grad = True  # allow grads for batch computation
         self.reset()
         self.update(*args, **kwargs)
         self._forward_cache = self.compute()
@@ -259,6 +261,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         self._should_unsync = True
         self._to_sync = True
         self._computed = None
+        self._enable_grad = False
 
         return self._forward_cache
 
@@ -294,7 +297,8 @@ def _wrap_update(self, update: Callable) -> Callable:
         def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]:
             self._computed = None
             self._update_called = True
-            return update(*args, **kwargs)
+            with torch.set_grad_enabled(self._enable_grad):
+                return update(*args, **kwargs)
 
         return wrapped_func
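Note on PATCH 07 above: it is the core of the fix. ``update`` now runs under ``torch.set_grad_enabled(self._enable_grad)``, so the globally accumulated state is built without an autograd graph, while ``forward`` briefly re-enables gradients for the per-batch computation. A standalone sketch of the underlying mechanism (variable names are illustrative, nothing here is library code):

    import torch

    x = torch.randn(10, requires_grad=True)

    # Leaky pattern: each addition extends the autograd graph hanging off ``state``,
    # so the graph (and the buffers it references) grows with every update.
    state = torch.tensor(0.0)
    for _ in range(5):
        state = state + x.sum()
    print(state.grad_fn is not None)  # True: the chain of additions is retained

    # Pattern used by the fix: with gradients disabled the result is detached,
    # so repeated updates keep a single scalar and nothing accumulates.
    state = torch.tensor(0.0)
    with torch.set_grad_enabled(False):
        for _ in range(5):
            state = state + x.sum()
    print(state.grad_fn is None)  # True: no graph is kept for the accumulated state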
From 7faa5f9c8f7626af565cdb574f577653b3199587 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Wed, 23 Mar 2022 10:48:29 +0100
Subject: [PATCH 08/10] remove no_grad

---
 torchmetrics/metric.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index ee6911ac13a..8efea6b72bf 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -237,8 +237,8 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
                 "HINT: Did you forget to call ``unsync`` ?."
             )
 
-        with torch.no_grad():
-            self.update(*args, **kwargs)
+        # global accumulation
+        self.update(*args, **kwargs)
 
         self._to_sync = self.dist_sync_on_step  # type: ignore
         # skip restore cache operation from compute as cache is stored below.

From c5d6501dd1b110a19a666ae7e6e56507c72762b7 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Wed, 23 Mar 2022 10:50:32 +0100
Subject: [PATCH 09/10] changelog

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b86a52d6261..d4a44c74db0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -133,6 +133,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fixed NaN or Inf results returned by `signal_distortion_ratio` ([#899](https://github.com/PyTorchLightning/metrics/pull/899))
 
+- Fixed memory leak when using `update` method with tensor where `requires_grad=True` ([#902](https://github.com/PyTorchLightning/metrics/pull/902))
+
 
 ## [0.7.2] - 2022-02-10
 

From dce6fe6f20f55b9b553f355a9adda6c93fe72ea7 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Wed, 23 Mar 2022 11:25:52 +0100
Subject: [PATCH 10/10] memory

---
 tests/bases/test_metric.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py
index d3a6c054b87..30a26fa5edd 100644
--- a/tests/bases/test_metric.py
+++ b/tests/bases/test_metric.py
@@ -387,7 +387,9 @@ def get_memory_usage():
     metric = DummyMetricSum().to(device)
 
     metric.update(x.sum())
-    base_memory_level = get_memory_usage()
+
+    # we allow for 5% flucturation due to measuring
+    base_memory_level = 1.05 * get_memory_usage()
 
     for _ in range(10):
         metric.update(x.sum())
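Note on PATCH 10 above: the baseline is loosened by 5% because process-level memory readings fluctuate between calls. A rough sketch of how the constant-memory behaviour can be reproduced by hand on CPU with any built-in metric (the metric choice, tensor sizes and iteration count are assumptions that mirror the test rather than quote it):

    import os

    import psutil
    import torch
    from torchmetrics import MeanSquaredError


    def rss_gb() -> float:
        """Resident memory of the current process in GiB (noisy between calls)."""
        return psutil.Process(os.getpid()).memory_info().rss / 2 ** 30


    metric = MeanSquaredError()
    preds = torch.randn(10_000, requires_grad=True)
    target = torch.randn(10_000)

    metric.update(preds, target)
    base = 1.05 * rss_gb()  # allow ~5% measurement noise, as the test does

    for _ in range(100):
        metric.update(preds, target)
        assert rss_gb() <= base, "memory increased above base level"

Before the change in PATCH 07/08 the accumulated state retained an autograd graph whenever the inputs required grad, so a loop like this would grow slowly; afterwards the footprint stays flat.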