Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test util for checking stand-alone python scripts #1007

Merged
merged 11 commits into from
May 10, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions mmcv/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
track_parallel_progress, track_progress)
from .testing import (assert_attrs_equal, assert_dict_contains_subset,
assert_dict_has_keys, assert_is_norm_layer,
assert_keys_equal, assert_params_all_zeros)
assert_keys_equal, assert_params_all_zeros,
check_python_script)
from .timer import Timer, TimerError, check_time
from .version_utils import digit_version, get_git_hash

Expand All @@ -28,7 +29,7 @@
'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
'digit_version', 'get_git_hash', 'import_modules_from_strings',
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal'
'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script'
]
else:
from .env import collect_env
Expand Down Expand Up @@ -57,5 +58,5 @@
'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena',
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
'assert_params_all_zeros'
'assert_params_all_zeros', 'check_python_script'
]
19 changes: 19 additions & 0 deletions mmcv/utils/testing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
# Copyright (c) Open-MMLab.
import sys
from collections.abc import Iterable
from runpy import run_path
from shlex import split
from typing import Any, Dict, List
from unittest.mock import patch


def check_python_script(cmd):
"""Run the python cmd script with `__main__`. The difference between
`os.system` is that, this function exectues code in the current process, so
that it can be tracked by coverage tools. Currently it supports two forms:

- ./tests/data/scripts/hello.py zz
- python tests/data/scripts/hello.py zz
"""
args = split(cmd)
if args[0] == 'python':
args = args[1:]
with patch.object(sys, 'argv', args):
run_path(args[0], run_name='__main__')


def _any(judge_result):
Expand Down
24 changes: 24 additions & 0 deletions tests/data/scripts/hello.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/usr/bin/env python

import argparse
import warnings


def parse_args():
parser = argparse.ArgumentParser(description='Say hello.')
parser.add_argument('name', help='To whom.')

args = parser.parse_args()

return args


def main():
args = parse_args()
print(f'hello {args.name}!')
if args.name == 'lizz':
innerlee marked this conversation as resolved.
Show resolved Hide resolved
warnings.warn('I have a secret!')


if __name__ == '__main__':
main()
11 changes: 11 additions & 0 deletions tests/test_utils/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,14 @@ def test_assert_params_all_zeros():

nn.init.normal_(demo_module.weight, mean=0, std=0.01)
assert not mmcv.assert_params_all_zeros(demo_module)


def test_check_python_script(capsys):
mmcv.utils.check_python_script('./tests/data/scripts/hello.py zz')
captured = capsys.readouterr().out
assert captured == 'hello zz!\n'
mmcv.utils.check_python_script('./tests/data/scripts/hello.py lizz')
captured = capsys.readouterr().out
assert captured == 'hello lizz!\n'
with pytest.raises(SystemExit):
innerlee marked this conversation as resolved.
Show resolved Hide resolved
mmcv.utils.check_python_script('./tests/data/scripts/hello.py li zz')