Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[resubmit] [pytorch][PR] Fix for num_threads==1 in OpenMP "parallel for" #39533

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion aten/src/ATen/ParallelOpenMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@ inline void parallel_for(
#ifdef _OPENMP
std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
std::exception_ptr eptr;
// Work around memory leak when using 1 thread in nested "omp parallel"
// caused by some buggy OpenMP versions and the fact that omp_in_parallel()
// returns false when omp_get_max_threads() == 1 inside nested "omp parallel"
// See issue gh-32284

#pragma omp parallel if (!omp_in_parallel() && ((end - begin) > grain_size))
#pragma omp parallel if (omp_get_max_threads() > 1 && !omp_in_parallel() && ((end - begin) > grain_size))
{
// choose number of tasks based on grain size and number of threads
// can't use num_threads clause due to bugs in GOMP's thread pool (See #32008)
Expand Down
1 change: 1 addition & 0 deletions test/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
'test_overrides',
'test_jit_fuser_te',
'test_tensorexpr',
'test_openmp',
'distributed/rpc/faulty_agent/test_dist_autograd_spawn',
'distributed/rpc/faulty_agent/test_rpc_spawn',
'distributed/rpc/jit/test_dist_autograd_spawn',
Expand Down
69 changes: 69 additions & 0 deletions test/test_openmp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import collections
import unittest

import numpy as np
import torch
from torch.testing._internal.common_utils import (
TestCase, run_tests, TEST_WITH_ASAN)

try:
import psutil
HAS_PSUTIL = True
except ImportError:
HAS_PSUTIL = False

device = torch.device('cpu')


class Network(torch.nn.Module):
maxp1 = torch.nn.MaxPool2d(1, 1)

def forward(self, x):
return self.maxp1(x)


@unittest.skipIf(not HAS_PSUTIL, "Requires psutil to run")
@unittest.skipIf(TEST_WITH_ASAN, "Cannot test with ASAN")
class TestOpenMP_ParallelFor(TestCase):
batch = 20
channels = 1
side_dim = 80
x = torch.randn([batch, channels, side_dim, side_dim], device=device)
model = Network()
ncores = min(5, psutil.cpu_count(logical=False))

def func(self, runs):
p = psutil.Process()
# warm up for 5 runs, then things should be stable for the last 5
last_rss = collections.deque(maxlen=5)
for n in range(10):
for i in range(runs):
self.model(self.x)
last_rss.append(p.memory_info().rss)
return last_rss

def func_rss(self, runs):
last_rss = self.func(runs)
# Do a least-mean-squares fit of last_rss to a line
poly = np.polynomial.Polynomial.fit(
range(len(last_rss)), np.array(last_rss), 1)
coefs = poly.convert().coef
# The coefs are (b, m) for the line y = m * x + b that fits the data.
# If m == 0 it will not be present. Assert it is missing or < 1000.
self.assertTrue(len(coefs) < 2 or coefs[1] < 1000,
msg='memory did not stabilize, {}'.format(str(list(last_rss))))

def test_one_thread(self):
"""Make sure there is no memory leak with one thread: issue gh-32284
"""
torch.set_num_threads(1)
self.func_rss(300)

def test_n_threads(self):
"""Make sure there is no memory leak with many threads
"""
torch.set_num_threads(self.ncores)
self.func_rss(300)

if __name__ == '__main__':
run_tests()