diff --git a/CHANGELOG.md b/CHANGELOG.md
index b86a52d6261..d4a44c74db0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -133,6 +133,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fixed NaN or Inf results returned by `signal_distortion_ratio` ([#899](https://github.com/PyTorchLightning/metrics/pull/899))
 
+- Fixed memory leak when using `update` method with tensor where `requires_grad=True` ([#902](https://github.com/PyTorchLightning/metrics/pull/902))
+
 
 ## [0.7.2] - 2022-02-10
 
diff --git a/requirements/test.txt b/requirements/test.txt
index 2ed4e84ebd0..dfffc35433d 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -10,6 +10,7 @@ mypy>=0.790
 phmdoctest>=1.1.1
 pre-commit>=1.0
 
+psutil
 requests
 fire
 
diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py
index bad93b2b0d7..30a26fa5edd 100644
--- a/tests/bases/test_metric.py
+++ b/tests/bases/test_metric.py
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import pickle
 from collections import OrderedDict
 
 import cloudpickle
 import numpy as np
+import psutil
 import pytest
 import torch
 from torch import Tensor, nn, tensor
@@ -362,3 +364,44 @@ def device(self):
         assert module.device == module.metric.device
         if isinstance(module.metric.x, Tensor):
             assert module.device == module.metric.x.device
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("requires_grad", [True, False])
+def test_constant_memory(device, requires_grad):
+    """Checks that when updating a metric the memory does not increase."""
+    if not torch.cuda.is_available() and device == "cuda":
+        pytest.skip("Test requires GPU support")
+
+    def get_memory_usage():
+        if device == "cpu":
+            pid = os.getpid()
+            py = psutil.Process(pid)
+            return py.memory_info()[0] / 2.0 ** 30
+        else:
+            return torch.cuda.memory_allocated()
+
+    x = torch.randn(10, requires_grad=requires_grad, device=device)
+
+    # try update method
+    metric = DummyMetricSum().to(device)
+
+    metric.update(x.sum())
+
+    # we allow for a 5% fluctuation due to measurement noise
+    base_memory_level = 1.05 * get_memory_usage()
+
+    for _ in range(10):
+        metric.update(x.sum())
+        memory = get_memory_usage()
+        assert base_memory_level >= memory, "memory increased above base level"
+
+    # try forward method
+    metric = DummyMetricSum().to(device)
+    metric(x.sum())
+    base_memory_level = get_memory_usage()
+
+    for _ in range(10):
+        metric.update(x.sum())
+        memory = get_memory_usage()
+        assert base_memory_level >= memory, "memory increased above base level"
diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py
index 37bd83c5194..3d04cb4b98f 100644
--- a/tests/retrieval/helpers.py
+++ b/tests/retrieval/helpers.py
@@ -485,7 +485,7 @@ def run_precision_test_gpu(
         metric_functional: Callable,
     ):
         if not torch.cuda.is_available():
-            pytest.skip()
+            pytest.skip("Test requires GPU")
 
         def metric_functional_ignore_indexes(preds, target, indexes):
             return metric_functional(preds, target)
diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index 7508f438261..8efea6b72bf 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -144,6 +144,7 @@ def __init__(
         self._update_called = False
         self._to_sync = True
         self._should_unsync = True
+        self._enable_grad = False
 
         # initialize state
         self._defaults: Dict[str, Union[List, Tensor]] = {}
@@ -236,8 +237,8 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
                 "HINT: Did you forget to call ``unsync`` ?."
             )
 
-        with torch.no_grad():
-            self.update(*args, **kwargs)
+        # global accumulation
+        self.update(*args, **kwargs)
 
         self._to_sync = self.dist_sync_on_step  # type: ignore
         # skip restore cache operation from compute as cache is stored below.
@@ -247,6 +248,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         cache = {attr: getattr(self, attr) for attr in self._defaults}
 
         # call reset, update, compute, on single batch
+        self._enable_grad = True  # allow grads for batch computation
         self.reset()
        self.update(*args, **kwargs)
         self._forward_cache = self.compute()
@@ -259,6 +261,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         self._should_unsync = True
         self._to_sync = True
         self._computed = None
+        self._enable_grad = False
 
         return self._forward_cache
 
@@ -294,7 +297,8 @@ def _wrap_update(self, update: Callable) -> Callable:
         def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]:
             self._computed = None
             self._update_called = True
-            return update(*args, **kwargs)
+            with torch.set_grad_enabled(self._enable_grad):
+                return update(*args, **kwargs)
 
         return wrapped_func
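
For context, a minimal sketch of the behaviour this patch targets; it is not taken from the patch, and the plain `state` tensor below is only a hypothetical stand-in for a metric state such as `DummyMetricSum.x`. It shows why accumulating a `requires_grad=True` tensor under grad mode retains memory, and why the wrapped `update` now runs under `torch.set_grad_enabled(self._enable_grad)` for the global accumulation.

import torch

# Hypothetical sum-style metric state and an input that carries gradients.
state = torch.tensor(0.0)
x = torch.randn(10, requires_grad=True)

# Accumulating under grad mode chains every call onto the autograd graph,
# so the graph (and its memory) is retained across updates.
for _ in range(3):
    state = state + x.sum()
print(state.grad_fn is not None)  # True -> graph kept alive, memory grows

# Accumulating with gradients disabled, as the global accumulation now does,
# keeps the running state detached and the memory footprint constant.
state = torch.tensor(0.0)
with torch.set_grad_enabled(False):
    for _ in range(3):
        state = state + x.sum()
print(state.grad_fn is None)  # True -> nothing from the graph is retained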