diff --git a/src/sdk/pynni/nni/smartparam.py b/src/sdk/pynni/nni/smartparam.py index 2b2bba5812..725a2c4924 100644 --- a/src/sdk/pynni/nni/smartparam.py +++ b/src/sdk/pynni/nni/smartparam.py @@ -82,52 +82,40 @@ def function_choice(*funcs, name=None): else: - def choice(options, name=None): - return options[_get_param('choice', name)] + def choice(options, name=None, key=None): + return options[_get_param(key)] - def randint(upper, name=None): - return _get_param('randint', name) + def randint(upper, name=None, key=None): + return _get_param(key) - def uniform(low, high, name=None): - return _get_param('uniform', name) + def uniform(low, high, name=None, key=None): + return _get_param(key) - def quniform(low, high, q, name=None): - return _get_param('quniform', name) + def quniform(low, high, q, name=None, key=None): + return _get_param(key) - def loguniform(low, high, name=None): - return _get_param('loguniform', name) + def loguniform(low, high, name=None, key=None): + return _get_param(key) - def qloguniform(low, high, q, name=None): - return _get_param('qloguniform', name) + def qloguniform(low, high, q, name=None, key=None): + return _get_param(key) - def normal(mu, sigma, name=None): - return _get_param('normal', name) + def normal(mu, sigma, name=None, key=None): + return _get_param(key) - def qnormal(mu, sigma, q, name=None): - return _get_param('qnormal', name) + def qnormal(mu, sigma, q, name=None, key=None): + return _get_param(key) - def lognormal(mu, sigma, name=None): - return _get_param('lognormal', name) + def lognormal(mu, sigma, name=None, key=None): + return _get_param(key) - def qlognormal(mu, sigma, q, name=None): - return _get_param('qlognormal', name) - - def function_choice(funcs, name=None): - return funcs[_get_param('function_choice', name)]() - - def _get_param(func, name): - # frames: - # layer 0: this function - # layer 1: the API function (caller of this function) - # layer 2: caller of the API function - frame = inspect.stack(0)[2] - filename = frame.filename - lineno = frame.lineno # NOTE: this is the lineno of caller's last argument - del frame # see official doc - module = inspect.getmodulename(filename) - if name is None: - name = '__line{:d}'.format(lineno) - key = '{}/{}/{}'.format(module, name, func) + def qlognormal(mu, sigma, q, name=None, key=None): + return _get_param(key) + + def function_choice(funcs, name=None, key=None): + return funcs[_get_param(key)]() + + def _get_param(key): if trial._params is None: trial.get_next_parameter() return trial.get_current_parameter(key) diff --git a/src/sdk/pynni/tests/test_smartparam.py b/src/sdk/pynni/tests/test_smartparam.py index 3dda81bd1d..33fb783afc 100644 --- a/src/sdk/pynni/tests/test_smartparam.py +++ b/src/sdk/pynni/tests/test_smartparam.py @@ -29,8 +29,6 @@ from unittest import TestCase, main -lineno1 = 61 -lineno2 = 75 class SmartParamTestCase(TestCase): def setUp(self): @@ -39,43 +37,30 @@ def setUp(self): 'test_smartparam/choice2/choice': '3*2+1', 'test_smartparam/choice3/choice': '[1, 2]', 'test_smartparam/choice4/choice': '{"a", 2}', - 'test_smartparam/__line{:d}/uniform'.format(lineno1): '5', 'test_smartparam/func/function_choice': 'bar', - 'test_smartparam/lambda_func/function_choice': "lambda: 2*3", - 'test_smartparam/__line{:d}/function_choice'.format(lineno2): 'max(1, 2, 3)' + 'test_smartparam/lambda_func/function_choice': "lambda: 2*3" } nni.trial._params = { 'parameter_id': 'test_trial', 'parameters': params } def test_specified_name(self): - val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice1') + val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice1', key='test_smartparam/choice1/choice') self.assertEqual(val, 'a') - val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice2') + val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice2', key='test_smartparam/choice2/choice') self.assertEqual(val, 7) - val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice3') + val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice3', key='test_smartparam/choice3/choice') self.assertEqual(val, [1, 2]) - val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice4') + val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice4', key='test_smartparam/choice4/choice') self.assertEqual(val, {"a", 2}) - def test_default_name(self): - val = nni.uniform(1, 10) # NOTE: assign this line number to lineno1 - self.assertEqual(val, '5') - - def test_specified_name_func(self): - val = nni.function_choice({'foo': foo, 'bar': bar}, name = 'func') + def test_func(self): + val = nni.function_choice({'foo': foo, 'bar': bar}, name='func', key='test_smartparam/func/function_choice') self.assertEqual(val, 'bar') def test_lambda_func(self): - val = nni.function_choice({"lambda: 2*3": lambda: 2*3, "lambda: 3*4": lambda: 3*4}, name = 'lambda_func') + val = nni.function_choice({"lambda: 2*3": lambda: 2*3, "lambda: 3*4": lambda: 3*4}, name = 'lambda_func', key='test_smartparam/lambda_func/function_choice') self.assertEqual(val, 6) - def test_default_name_func(self): - val = nni.function_choice({ - 'max(1, 2, 3)': lambda: max(1, 2, 3), - 'min(1, 2)': lambda: min(1, 2) # NOTE: assign this line number to lineno2 - }) - self.assertEqual(val, 3) - def foo(): return 'foo' diff --git a/tools/nni_annotation/__init__.py b/tools/nni_annotation/__init__.py index b8b57d0585..5ee89d2454 100644 --- a/tools/nni_annotation/__init__.py +++ b/tools/nni_annotation/__init__.py @@ -59,12 +59,15 @@ def generate_search_space(code_dir): def _generate_file_search_space(path, module): with open(path) as src: try: - return search_space_generator.generate(module, src.read()) + search_space, code = search_space_generator.generate(module, src.read()) except Exception as exc: # pylint: disable=broad-except if exc.args: raise RuntimeError(path + ' ' + '\n'.join(exc.args)) else: raise RuntimeError('Failed to generate search space for %s: %r' % (path, exc)) + with open(path, 'w') as dst: + dst.write(code) + return search_space def expand_annotations(src_dir, dst_dir): diff --git a/tools/nni_annotation/search_space_generator.py b/tools/nni_annotation/search_space_generator.py index c560c649ab..833d989c1d 100644 --- a/tools/nni_annotation/search_space_generator.py +++ b/tools/nni_annotation/search_space_generator.py @@ -20,6 +20,7 @@ import ast +import astor # pylint: disable=unidiomatic-typecheck @@ -40,7 +41,7 @@ ] -class SearchSpaceGenerator(ast.NodeVisitor): +class SearchSpaceGenerator(ast.NodeTransformer): """Generate search space from smart parater APIs""" def __init__(self, module_name): @@ -53,16 +54,16 @@ def visit_Call(self, node): # pylint: disable=invalid-name # ignore if the function is not 'nni.*' if type(node.func) is not ast.Attribute: - return + return node if type(node.func.value) is not ast.Name: - return + return node if node.func.value.id != 'nni': - return + return node # ignore if its not a search space function (e.g. `report_final_result`) func = node.func.attr if func not in _ss_funcs: - return + return node self.last_line = node.lineno @@ -77,6 +78,7 @@ def visit_Call(self, node): # pylint: disable=invalid-name # generate the missing name automatically name = '__line' + str(str(node.args[-1].lineno)) specified_name = False + node.keywords = list() if func in ('choice', 'function_choice'): # we will use keys in the dict as the choices, which is generated by code_generator according to the args given by user @@ -89,6 +91,9 @@ def visit_Call(self, node): # pylint: disable=invalid-name args = [arg.n for arg in node.args] key = self.module_name + '/' + name + '/' + func + # store key in ast.Call + node.keywords.append(ast.keyword(arg='key', value=ast.Str(s=key))) + if func == 'function_choice': func = 'choice' value = {'_type': func, '_value': args} @@ -103,6 +108,8 @@ def visit_Call(self, node): # pylint: disable=invalid-name self.search_space[key] = value + return node + def generate(module_name, code): """Generate search space. @@ -120,4 +127,4 @@ def generate(module_name, code): visitor.visit(ast_tree) except AssertionError as exc: raise RuntimeError('%d: %s' % (visitor.last_line, exc.args[0])) - return visitor.search_space + return visitor.search_space, astor.to_source(ast_tree)