Merge branch 'master' into master
mergify[bot] authored Mar 23, 2022
2 parents eebe982 + 8c052e7 commit e950d31
Showing 5 changed files with 54 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -136,6 +136,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed NaN or Inf results returned by `signal_distortion_ratio` ([#899](https://github.com/PyTorchLightning/metrics/pull/899))


+- Fixed memory leak when using `update` method with tensor where `requires_grad=True` ([#902](https://github.com/PyTorchLightning/metrics/pull/902))
+

 ## [0.7.2] - 2022-02-10
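To make the changelog entry above concrete, here is a minimal sketch of the leak pattern this commit fixes. It is not taken from the repository: `Accumulator` and its two update methods are hypothetical stand-ins for a sum metric like the `DummyMetricSum` used in the tests below.

```python
import torch


class Accumulator:
    """Hypothetical stand-in for a sum metric; not part of torchmetrics."""

    def __init__(self) -> None:
        self.total = torch.tensor(0.0)

    def leaky_update(self, value: torch.Tensor) -> None:
        # With grad mode on and value.requires_grad=True, every call chains
        # self.total deeper into the autograd graph, so intermediate tensors
        # from every batch stay alive -> unbounded memory growth.
        self.total = self.total + value

    def fixed_update(self, value: torch.Tensor) -> None:
        # What the patched _wrap_update effectively does by default:
        # accumulate with grad disabled, keeping the state graph-free.
        with torch.set_grad_enabled(False):
            self.total = self.total + value


x = torch.randn(10, requires_grad=True)

leaky = Accumulator()
leaky.leaky_update(x.sum())
print(leaky.total.grad_fn)  # AddBackward0: the graph is being retained

fixed = Accumulator()
fixed.fixed_update(x.sum())
print(fixed.total.grad_fn)  # None: nothing kept alive between updates
```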
1 change: 1 addition & 0 deletions requirements/test.txt
@@ -10,6 +10,7 @@ mypy>=0.790
 phmdoctest>=1.1.1
 pre-commit>=1.0

+psutil
 requests
 fire

43 changes: 43 additions & 0 deletions tests/bases/test_metric.py
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import pickle
 from collections import OrderedDict

 import cloudpickle
 import numpy as np
+import psutil
 import pytest
 import torch
 from torch import Tensor, nn, tensor
@@ -362,3 +364,44 @@ def device(self):
         assert module.device == module.metric.device
         if isinstance(module.metric.x, Tensor):
             assert module.device == module.metric.x.device
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("requires_grad", [True, False])
+def test_constant_memory(device, requires_grad):
+    """Checks that when updating a metric the memory does not increase."""
+    if not torch.cuda.is_available() and device == "cuda":
+        pytest.skip("Test requires GPU support")
+
+    def get_memory_usage():
+        if device == "cpu":
+            pid = os.getpid()
+            py = psutil.Process(pid)
+            return py.memory_info()[0] / 2.0 ** 30  # resident set size in GiB
+        else:
+            return torch.cuda.memory_allocated()
+
+    x = torch.randn(10, requires_grad=requires_grad, device=device)
+
+    # try update method
+    metric = DummyMetricSum().to(device)
+
+    metric.update(x.sum())
+
+    # we allow for 5% fluctuation due to measuring
+    base_memory_level = 1.05 * get_memory_usage()
+
+    for _ in range(10):
+        metric.update(x.sum())
+        memory = get_memory_usage()
+        assert base_memory_level >= memory, "memory increased above base level"
+
+    # try forward method
+    metric = DummyMetricSum().to(device)
+    metric(x.sum())
+    base_memory_level = get_memory_usage()
+
+    for _ in range(10):
+        metric(x.sum())
+        memory = get_memory_usage()
+        assert base_memory_level >= memory, "memory increased above base level"
2 changes: 1 addition & 1 deletion tests/retrieval/helpers.py
@@ -485,7 +485,7 @@ def run_precision_test_gpu(
         metric_functional: Callable,
     ):
         if not torch.cuda.is_available():
-            pytest.skip()
+            pytest.skip("Test requires GPU")

        def metric_functional_ignore_indexes(preds, target, indexes):
            return metric_functional(preds, target)
10 changes: 7 additions & 3 deletions torchmetrics/metric.py
@@ -144,6 +144,7 @@ def __init__(
         self._update_called = False
         self._to_sync = True
         self._should_unsync = True
+        self._enable_grad = False

         # initialize state
         self._defaults: Dict[str, Union[List, Tensor]] = {}
@@ -236,8 +237,8 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
                 "HINT: Did you forget to call ``unsync`` ?."
             )

-        with torch.no_grad():
-            self.update(*args, **kwargs)
+        # global accumulation
+        self.update(*args, **kwargs)

         self._to_sync = self.dist_sync_on_step  # type: ignore
         # skip restore cache operation from compute as cache is stored below.
@@ -247,6 +248,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         cache = {attr: getattr(self, attr) for attr in self._defaults}

         # call reset, update, compute, on single batch
+        self._enable_grad = True  # allow grads for batch computation
         self.reset()
         self.update(*args, **kwargs)
         self._forward_cache = self.compute()
@@ -259,6 +261,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         self._should_unsync = True
         self._to_sync = True
         self._computed = None
+        self._enable_grad = False

         return self._forward_cache
@@ -294,7 +297,8 @@ def _wrap_update(self, update: Callable) -> Callable:
         def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]:
             self._computed = None
             self._update_called = True
-            return update(*args, **kwargs)
+            with torch.set_grad_enabled(self._enable_grad):
+                return update(*args, **kwargs)

         return wrapped_func
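Taken together, the changes above replace the hard `torch.no_grad()` around `update` with an `_enable_grad` flag that `forward` flips on only for its single-batch pass, then restores. A condensed sketch of that control flow, assuming a toy sum metric: `MiniMetric` and its attributes are simplified stand-ins, not the real `Metric` class, and the `try/finally` is a sketch choice rather than the library's exact structure.

```python
import torch


class MiniMetric:
    """Simplified stand-in for torchmetrics.Metric (names and flow condensed)."""

    def __init__(self) -> None:
        self._enable_grad = False  # same default the patch adds in __init__
        self.total = torch.tensor(0.0)

    def update(self, value: torch.Tensor) -> None:
        # The real class wraps update via _wrap_update; the grad context
        # is inlined here for brevity.
        with torch.set_grad_enabled(self._enable_grad):
            self.total = self.total + value

    def forward(self, value: torch.Tensor) -> torch.Tensor:
        self.update(value)              # global accumulation, runs grad-free
        cache = self.total              # save accumulated, graph-free state
        self.total = torch.tensor(0.0)  # reset for the single-batch pass
        self._enable_grad = True        # allow grads for batch computation
        try:
            self.update(value)          # batch update, graph kept this time
            batch_result = self.total   # stand-in for compute()
        finally:
            self._enable_grad = False
            self.total = cache          # restore the graph-free global state
        return batch_result


metric = MiniMetric()
x = torch.randn(10, requires_grad=True)
out = metric.forward(x.sum())
assert out.requires_grad             # batch result is still differentiable
assert metric.total.grad_fn is None  # accumulated state retains no graph
```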
