Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

fix annotation key-error #806

Merged
merged 4 commits into from
Mar 7, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 25 additions & 37 deletions src/sdk/pynni/nni/smartparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,52 +82,40 @@ def function_choice(*funcs, name=None):

else:

def choice(options, name=None):
return options[_get_param('choice', name)]
def choice(options, name=None, key=None):
return options[_get_param('choice', key)]

def randint(upper, name=None):
return _get_param('randint', name)
def randint(upper, name=None, key=None):
return _get_param('randint', key)

def uniform(low, high, name=None):
return _get_param('uniform', name)
def uniform(low, high, name=None, key=None):
return _get_param('uniform', key)

def quniform(low, high, q, name=None):
return _get_param('quniform', name)
def quniform(low, high, q, name=None, key=None):
return _get_param('quniform', key)

def loguniform(low, high, name=None):
return _get_param('loguniform', name)
def loguniform(low, high, name=None, key=None):
return _get_param('loguniform', key)

def qloguniform(low, high, q, name=None):
return _get_param('qloguniform', name)
def qloguniform(low, high, q, name=None, key=None):
return _get_param('qloguniform', key)

def normal(mu, sigma, name=None):
return _get_param('normal', name)
def normal(mu, sigma, name=None, key=None):
return _get_param('normal', key)

def qnormal(mu, sigma, q, name=None):
return _get_param('qnormal', name)
def qnormal(mu, sigma, q, name=None, key=None):
return _get_param('qnormal', key)

def lognormal(mu, sigma, name=None):
return _get_param('lognormal', name)
def lognormal(mu, sigma, name=None, key=None):
return _get_param('lognormal', key)

def qlognormal(mu, sigma, q, name=None):
return _get_param('qlognormal', name)

def function_choice(funcs, name=None):
return funcs[_get_param('function_choice', name)]()

def _get_param(func, name):
# frames:
# layer 0: this function
# layer 1: the API function (caller of this function)
# layer 2: caller of the API function
frame = inspect.stack(0)[2]
filename = frame.filename
lineno = frame.lineno # NOTE: this is the lineno of caller's last argument
del frame # see official doc
module = inspect.getmodulename(filename)
if name is None:
name = '__line{:d}'.format(lineno)
key = '{}/{}/{}'.format(module, name, func)
def qlognormal(mu, sigma, q, name=None, key=None):
return _get_param('qlognormal', key)

def function_choice(funcs, name=None, key=None):
return funcs[_get_param('function_choice', key)]()

def _get_param(func, key):
if trial._params is None:
trial.get_next_parameter()
return trial.get_current_parameter(key)
5 changes: 4 additions & 1 deletion tools/nni_annotation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,15 @@ def generate_search_space(code_dir):
def _generate_file_search_space(path, module):
with open(path) as src:
try:
return search_space_generator.generate(module, src.read())
search_space, code = search_space_generator.generate(module, src.read())
except Exception as exc: # pylint: disable=broad-except
if exc.args:
raise RuntimeError(path + ' ' + '\n'.join(exc.args))
else:
raise RuntimeError('Failed to generate search space for %s: %r' % (path, exc))
with open(path, 'w') as dst:
dst.write(code)
return search_space


def expand_annotations(src_dir, dst_dir):
Expand Down
19 changes: 13 additions & 6 deletions tools/nni_annotation/search_space_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@


import ast
import astor

# pylint: disable=unidiomatic-typecheck

Expand All @@ -40,7 +41,7 @@
]


class SearchSpaceGenerator(ast.NodeVisitor):
class SearchSpaceGenerator(ast.NodeTransformer):
"""Generate search space from smart parater APIs"""

def __init__(self, module_name):
Expand All @@ -53,16 +54,16 @@ def visit_Call(self, node): # pylint: disable=invalid-name

# ignore if the function is not 'nni.*'
if type(node.func) is not ast.Attribute:
return
return node
if type(node.func.value) is not ast.Name:
return
return node
if node.func.value.id != 'nni':
return
return node

# ignore if its not a search space function (e.g. `report_final_result`)
func = node.func.attr
if func not in _ss_funcs:
return
return node

self.last_line = node.lineno

Expand All @@ -77,6 +78,7 @@ def visit_Call(self, node): # pylint: disable=invalid-name
# generate the missing name automatically
name = '__line' + str(str(node.args[-1].lineno))
specified_name = False
node.keywords = list()

if func in ('choice', 'function_choice'):
# we will use keys in the dict as the choices, which is generated by code_generator according to the args given by user
Expand All @@ -89,6 +91,9 @@ def visit_Call(self, node): # pylint: disable=invalid-name
args = [arg.n for arg in node.args]

key = self.module_name + '/' + name + '/' + func
# store key in ast.Call
node.keywords.append(ast.keyword(arg='key', value=ast.Str(s=key)))

if func == 'function_choice':
func = 'choice'
value = {'_type': func, '_value': args}
Expand All @@ -103,6 +108,8 @@ def visit_Call(self, node): # pylint: disable=invalid-name

self.search_space[key] = value

return node


def generate(module_name, code):
"""Generate search space.
Expand All @@ -120,4 +127,4 @@ def generate(module_name, code):
visitor.visit(ast_tree)
except AssertionError as exc:
raise RuntimeError('%d: %s' % (visitor.last_line, exc.args[0]))
return visitor.search_space
return visitor.search_space, astor.to_source(ast_tree)