diff --git a/mmseg/utils/__init__.py b/mmseg/utils/__init__.py index 4b34f4c386..ed002c7de4 100644 --- a/mmseg/utils/__init__.py +++ b/mmseg/utils/__init__.py @@ -2,5 +2,9 @@ from .collect_env import collect_env from .logger import get_root_logger from .misc import find_latest_checkpoint +from .set_env import setup_multi_processes -__all__ = ['get_root_logger', 'collect_env', 'find_latest_checkpoint'] +__all__ = [ + 'get_root_logger', 'collect_env', 'find_latest_checkpoint', + 'setup_multi_processes' +] diff --git a/mmseg/utils/set_env.py b/mmseg/utils/set_env.py new file mode 100644 index 0000000000..b2d3aaf14b --- /dev/null +++ b/mmseg/utils/set_env.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import platform + +import cv2 +import torch.multiprocessing as mp + +from ..utils import get_root_logger + + +def setup_multi_processes(cfg): + """Setup multi-processing environment variables.""" + logger = get_root_logger() + + # set multi-process start method + if platform.system() != 'Windows': + mp_start_method = cfg.get('mp_start_method', None) + current_method = mp.get_start_method(allow_none=True) + if mp_start_method in ('fork', 'spawn', 'forkserver'): + logger.info( + f'Multi-processing start method `{mp_start_method}` is ' + f'different from the previous setting `{current_method}`.' + f'It will be force set to `{mp_start_method}`.') + mp.set_start_method(mp_start_method, force=True) + else: + logger.info( + f'Multi-processing start method is `{mp_start_method}`') + + # disable opencv multithreading to avoid system being overloaded + opencv_num_threads = cfg.get('opencv_num_threads', None) + if isinstance(opencv_num_threads, int): + logger.info(f'OpenCV num_threads is `{opencv_num_threads}`') + cv2.setNumThreads(opencv_num_threads) + else: + logger.info(f'OpenCV num_threads is `{cv2.getNumThreads}') + + if cfg.data.workers_per_gpu > 1: + # setup OMP threads + # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa + omp_num_threads = cfg.get('omp_num_threads', None) + if 'OMP_NUM_THREADS' not in os.environ: + if isinstance(omp_num_threads, int): + logger.info(f'OMP num threads is {omp_num_threads}') + os.environ['OMP_NUM_THREADS'] = str(omp_num_threads) + else: + logger.info(f'OMP num threads is {os.environ["OMP_NUM_THREADS"] }') + + # setup MKL threads + if 'MKL_NUM_THREADS' not in os.environ: + mkl_num_threads = cfg.get('mkl_num_threads', None) + if isinstance(mkl_num_threads, int): + logger.info(f'MKL num threads is {mkl_num_threads}') + os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) + else: + logger.info(f'MKL num threads is {os.environ["MKL_NUM_THREADS"]}') diff --git a/tests/test_utils/test_set_env.py b/tests/test_utils/test_set_env.py new file mode 100644 index 0000000000..0af4424b1d --- /dev/null +++ b/tests/test_utils/test_set_env.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import multiprocessing as mp +import os +import platform + +import cv2 +import pytest +from mmcv import Config + +from mmseg.utils import setup_multi_processes + + +@pytest.mark.parametrize('workers_per_gpu', (0, 2)) +@pytest.mark.parametrize(('valid', 'env_cfg'), [(True, + dict( + mp_start_method='fork', + opencv_num_threads=0, + omp_num_threads=1, + mkl_num_threads=1)), + (False, + dict( + mp_start_method=1, + opencv_num_threads=0.1, + omp_num_threads='s', + mkl_num_threads='1'))]) +def test_setup_multi_processes(workers_per_gpu, valid, env_cfg): + # temp save system setting + sys_start_mehod = mp.get_start_method(allow_none=True) + sys_cv_threads = cv2.getNumThreads() + # pop and temp save system env vars + sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None) + sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None) + + config = dict(data=dict(workers_per_gpu=workers_per_gpu)) + config.update(env_cfg) + cfg = Config(config) + setup_multi_processes(cfg) + + # test when cfg is valid and workers_per_gpu > 0 + # setup_multi_processes will work + if valid and workers_per_gpu > 0: + # test config without setting env + + assert os.getenv('OMP_NUM_THREADS') == str(env_cfg['omp_num_threads']) + assert os.getenv('MKL_NUM_THREADS') == str(env_cfg['mkl_num_threads']) + # when set to 0, the num threads will be 1 + assert cv2.getNumThreads() == env_cfg[ + 'opencv_num_threads'] if env_cfg['opencv_num_threads'] > 0 else 1 + if platform.system() != 'Windows': + assert mp.get_start_method() == env_cfg['mp_start_method'] + + # revert setting to avoid affecting other programs + if sys_start_mehod: + mp.set_start_method(sys_start_mehod, force=True) + cv2.setNumThreads(sys_cv_threads) + if sys_omp_threads: + os.environ['OMP_NUM_THREADS'] = sys_omp_threads + else: + os.environ.pop('OMP_NUM_THREADS') + if sys_mkl_threads: + os.environ['MKL_NUM_THREADS'] = sys_mkl_threads + else: + os.environ.pop('MKL_NUM_THREADS') + + elif valid and workers_per_gpu == 0: + + if platform.system() != 'Windows': + assert mp.get_start_method() == env_cfg['mp_start_method'] + assert cv2.getNumThreads() == env_cfg[ + 'opencv_num_threads'] if env_cfg['opencv_num_threads'] > 0 else 1 + assert 'OMP_NUM_THREADS' not in os.environ + assert 'MKL_NUM_THREADS' not in os.environ + if sys_start_mehod: + mp.set_start_method(sys_start_mehod, force=True) + cv2.setNumThreads(sys_cv_threads) + if sys_omp_threads: + os.environ['OMP_NUM_THREADS'] = sys_omp_threads + if sys_mkl_threads: + os.environ['MKL_NUM_THREADS'] = sys_mkl_threads + + else: + assert mp.get_start_method() == sys_start_mehod + assert cv2.getNumThreads() == sys_cv_threads + assert 'OMP_NUM_THREADS' not in os.environ + assert 'MKL_NUM_THREADS' not in os.environ diff --git a/tools/test.py b/tools/test.py index a9d88b8074..03d8754a90 100644 --- a/tools/test.py +++ b/tools/test.py @@ -16,6 +16,7 @@ from mmseg.apis import multi_gpu_test, single_gpu_test from mmseg.datasets import build_dataloader, build_dataset from mmseg.models import build_segmentor +from mmseg.utils import setup_multi_processes def parse_args(): @@ -124,6 +125,10 @@ def main(): cfg = mmcv.Config.fromfile(args.config) if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) + + # set multi-process settings + setup_multi_processes(cfg) + # set cudnn_benchmark if cfg.get('cudnn_benchmark', False): torch.backends.cudnn.benchmark = True diff --git a/tools/train.py b/tools/train.py index 81c7d854ea..70ca4c85c0 100644 --- a/tools/train.py +++ b/tools/train.py @@ -16,7 +16,7 @@ from mmseg.apis import init_random_seed, set_random_seed, train_segmentor from mmseg.datasets import build_dataset from mmseg.models import build_segmentor -from mmseg.utils import collect_env, get_root_logger +from mmseg.utils import collect_env, get_root_logger, setup_multi_processes def parse_args(): @@ -102,6 +102,10 @@ def main(): cfg = Config.fromfile(args.config) if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) + + # set multi-process settings + setup_multi_processes(cfg) + # set cudnn_benchmark if cfg.get('cudnn_benchmark', False): torch.backends.cudnn.benchmark = True