Skip to content

Commit

Permalink
Merge pull request microsoft#139 from Microsoft/master
Browse files Browse the repository at this point in the history
fix annotation key-error (microsoft#806)
  • Loading branch information
SparkSnail authored Mar 11, 2019
2 parents 21165b5 + 7108466 commit d25f7b5
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 67 deletions.
62 changes: 25 additions & 37 deletions src/sdk/pynni/nni/smartparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,52 +82,40 @@ def function_choice(*funcs, name=None):

else:

def choice(options, name=None):
return options[_get_param('choice', name)]
def choice(options, name=None, key=None):
return options[_get_param(key)]

def randint(upper, name=None):
return _get_param('randint', name)
def randint(upper, name=None, key=None):
return _get_param(key)

def uniform(low, high, name=None):
return _get_param('uniform', name)
def uniform(low, high, name=None, key=None):
return _get_param(key)

def quniform(low, high, q, name=None):
return _get_param('quniform', name)
def quniform(low, high, q, name=None, key=None):
return _get_param(key)

def loguniform(low, high, name=None):
return _get_param('loguniform', name)
def loguniform(low, high, name=None, key=None):
return _get_param(key)

def qloguniform(low, high, q, name=None):
return _get_param('qloguniform', name)
def qloguniform(low, high, q, name=None, key=None):
return _get_param(key)

def normal(mu, sigma, name=None):
return _get_param('normal', name)
def normal(mu, sigma, name=None, key=None):
return _get_param(key)

def qnormal(mu, sigma, q, name=None):
return _get_param('qnormal', name)
def qnormal(mu, sigma, q, name=None, key=None):
return _get_param(key)

def lognormal(mu, sigma, name=None):
return _get_param('lognormal', name)
def lognormal(mu, sigma, name=None, key=None):
return _get_param(key)

def qlognormal(mu, sigma, q, name=None):
return _get_param('qlognormal', name)

def function_choice(funcs, name=None):
return funcs[_get_param('function_choice', name)]()

def _get_param(func, name):
# frames:
# layer 0: this function
# layer 1: the API function (caller of this function)
# layer 2: caller of the API function
frame = inspect.stack(0)[2]
filename = frame.filename
lineno = frame.lineno # NOTE: this is the lineno of caller's last argument
del frame # see official doc
module = inspect.getmodulename(filename)
if name is None:
name = '__line{:d}'.format(lineno)
key = '{}/{}/{}'.format(module, name, func)
def qlognormal(mu, sigma, q, name=None, key=None):
return _get_param(key)

def function_choice(funcs, name=None, key=None):
return funcs[_get_param(key)]()

def _get_param(key):
if trial._params is None:
trial.get_next_parameter()
return trial.get_current_parameter(key)
31 changes: 8 additions & 23 deletions src/sdk/pynni/tests/test_smartparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
from unittest import TestCase, main


lineno1 = 61
lineno2 = 75

class SmartParamTestCase(TestCase):
def setUp(self):
Expand All @@ -39,43 +37,30 @@ def setUp(self):
'test_smartparam/choice2/choice': '3*2+1',
'test_smartparam/choice3/choice': '[1, 2]',
'test_smartparam/choice4/choice': '{"a", 2}',
'test_smartparam/__line{:d}/uniform'.format(lineno1): '5',
'test_smartparam/func/function_choice': 'bar',
'test_smartparam/lambda_func/function_choice': "lambda: 2*3",
'test_smartparam/__line{:d}/function_choice'.format(lineno2): 'max(1, 2, 3)'
'test_smartparam/lambda_func/function_choice': "lambda: 2*3"
}
nni.trial._params = { 'parameter_id': 'test_trial', 'parameters': params }


def test_specified_name(self):
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice1')
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice1', key='test_smartparam/choice1/choice')
self.assertEqual(val, 'a')
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice2')
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice2', key='test_smartparam/choice2/choice')
self.assertEqual(val, 7)
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice3')
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice3', key='test_smartparam/choice3/choice')
self.assertEqual(val, [1, 2])
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice4')
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice4', key='test_smartparam/choice4/choice')
self.assertEqual(val, {"a", 2})

def test_default_name(self):
val = nni.uniform(1, 10) # NOTE: assign this line number to lineno1
self.assertEqual(val, '5')

def test_specified_name_func(self):
val = nni.function_choice({'foo': foo, 'bar': bar}, name = 'func')
def test_func(self):
val = nni.function_choice({'foo': foo, 'bar': bar}, name='func', key='test_smartparam/func/function_choice')
self.assertEqual(val, 'bar')

def test_lambda_func(self):
val = nni.function_choice({"lambda: 2*3": lambda: 2*3, "lambda: 3*4": lambda: 3*4}, name = 'lambda_func')
val = nni.function_choice({"lambda: 2*3": lambda: 2*3, "lambda: 3*4": lambda: 3*4}, name = 'lambda_func', key='test_smartparam/lambda_func/function_choice')
self.assertEqual(val, 6)

def test_default_name_func(self):
val = nni.function_choice({
'max(1, 2, 3)': lambda: max(1, 2, 3),
'min(1, 2)': lambda: min(1, 2) # NOTE: assign this line number to lineno2
})
self.assertEqual(val, 3)


def foo():
return 'foo'
Expand Down
5 changes: 4 additions & 1 deletion tools/nni_annotation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,15 @@ def generate_search_space(code_dir):
def _generate_file_search_space(path, module):
with open(path) as src:
try:
return search_space_generator.generate(module, src.read())
search_space, code = search_space_generator.generate(module, src.read())
except Exception as exc: # pylint: disable=broad-except
if exc.args:
raise RuntimeError(path + ' ' + '\n'.join(exc.args))
else:
raise RuntimeError('Failed to generate search space for %s: %r' % (path, exc))
with open(path, 'w') as dst:
dst.write(code)
return search_space


def expand_annotations(src_dir, dst_dir):
Expand Down
19 changes: 13 additions & 6 deletions tools/nni_annotation/search_space_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@


import ast
import astor

# pylint: disable=unidiomatic-typecheck

Expand All @@ -40,7 +41,7 @@
]


class SearchSpaceGenerator(ast.NodeVisitor):
class SearchSpaceGenerator(ast.NodeTransformer):
"""Generate search space from smart parater APIs"""

def __init__(self, module_name):
Expand All @@ -53,16 +54,16 @@ def visit_Call(self, node): # pylint: disable=invalid-name

# ignore if the function is not 'nni.*'
if type(node.func) is not ast.Attribute:
return
return node
if type(node.func.value) is not ast.Name:
return
return node
if node.func.value.id != 'nni':
return
return node

# ignore if its not a search space function (e.g. `report_final_result`)
func = node.func.attr
if func not in _ss_funcs:
return
return node

self.last_line = node.lineno

Expand All @@ -77,6 +78,7 @@ def visit_Call(self, node): # pylint: disable=invalid-name
# generate the missing name automatically
name = '__line' + str(str(node.args[-1].lineno))
specified_name = False
node.keywords = list()

if func in ('choice', 'function_choice'):
# we will use keys in the dict as the choices, which is generated by code_generator according to the args given by user
Expand All @@ -89,6 +91,9 @@ def visit_Call(self, node): # pylint: disable=invalid-name
args = [arg.n for arg in node.args]

key = self.module_name + '/' + name + '/' + func
# store key in ast.Call
node.keywords.append(ast.keyword(arg='key', value=ast.Str(s=key)))

if func == 'function_choice':
func = 'choice'
value = {'_type': func, '_value': args}
Expand All @@ -103,6 +108,8 @@ def visit_Call(self, node): # pylint: disable=invalid-name

self.search_space[key] = value

return node


def generate(module_name, code):
"""Generate search space.
Expand All @@ -120,4 +127,4 @@ def generate(module_name, code):
visitor.visit(ast_tree)
except AssertionError as exc:
raise RuntimeError('%d: %s' % (visitor.last_line, exc.args[0]))
return visitor.search_space
return visitor.search_space, astor.to_source(ast_tree)

0 comments on commit d25f7b5

Please sign in to comment.