
Investigate potential memory bug #902

Merged · 16 commits · Mar 23, 2022
2 changes: 2 additions & 0 deletions requirements/test.txt
@@ -10,8 +10,10 @@ mypy>=0.790
phmdoctest>=1.1.1
pre-commit>=1.0

psutil
requests
fire

cloudpickle>=1.3
scikit-learn>=0.24
psutil
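For context on the new dependency: `psutil` is what the test below uses to read the current process's resident set size (RSS) on CPU. A minimal sketch of that measurement, equivalent to the `memory_info()[0]` indexing used in the test:

```python
import os

import psutil

# RSS (resident set size) of the current process, converted to GiB.
# memory_info().rss is the same value as memory_info()[0].
proc = psutil.Process(os.getpid())
rss_gib = proc.memory_info().rss / 2 ** 30
print(f"current RSS: {rss_gib:.3f} GiB")
```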
41 changes: 41 additions & 0 deletions tests/bases/test_metric.py
@@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
from collections import OrderedDict

import cloudpickle
import numpy as np
import psutil
import pytest
import torch
from torch import Tensor, nn, tensor
@@ -362,3 +364,42 @@ def device(self):
    assert module.device == module.metric.device
    if isinstance(module.metric.x, Tensor):
        assert module.device == module.metric.x.device


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("requires_grad", [True, False])
def test_constant_memory(device, requires_grad):
    """Check that repeatedly updating a metric does not increase memory usage."""
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("Test requires GPU support")

    def get_memory_usage():
        if device == "cpu":
            pid = os.getpid()
            py = psutil.Process(pid)
            return py.memory_info()[0] / 2.0 ** 30  # resident set size in GiB
        else:
            return torch.cuda.memory_allocated()

    x = torch.randn(10, requires_grad=requires_grad, device=device)

    # try update method
    metric = DummyMetricSum().to(device)

    metric.update(x.sum())
    base_memory_level = get_memory_usage()

    for _ in range(10):
        metric.update(x.sum())
        memory = get_memory_usage()
        assert base_memory_level >= memory, "memory increased above base level"

    # try forward method
    metric = DummyMetricSum().to(device)
    metric(x.sum())
    base_memory_level = get_memory_usage()

    for _ in range(10):
        metric(x.sum())
        memory = get_memory_usage()
        assert base_memory_level >= memory, "memory increased above base level"
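
A note on why the `requires_grad` parametrization matters: if a metric accumulated its state without detaching, every update would keep the incoming tensor's autograd graph alive, and memory would grow with the number of updates. A self-contained sketch of that failure mode in plain PyTorch (not torchmetrics internals, just the pattern the assertions above guard against):

```python
import torch

x = torch.randn(10, requires_grad=True)

# Leaky pattern: accumulating a grad-requiring tensor chains all graphs together.
acc = torch.zeros(1)
for _ in range(100):
    acc = acc + x.sum()
assert acc.grad_fn is not None  # the whole 100-step history is still referenced

# Safe pattern: detach before accumulating, so no graph is retained.
acc = torch.zeros(1)
for _ in range(100):
    acc = acc + x.sum().detach()
assert acc.grad_fn is None
```

To run only this test locally, the usual pytest selection applies: `pytest tests/bases/test_metric.py -k test_constant_memory`.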