Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add search space validation for choice types #3975

Merged
merged 3 commits into from
Jul 26, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions nni/common/hpo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def validate_search_space(
raise ValueError(f'search space "{name}"\'s value is not a list : {spec}')

if type_ == 'choice':
if not all(isinstance(arg, (float, int, str)) for arg in args):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will it complain for nested space?

# FIXME: need further check for each algorithm which types are actually supported
# for now validation only prints warning so it doesn't harm
raise ValueError(f'search space "{name}" (choice) should only contain numbers or strings : {spec}')
continue

if type_.startswith('q'):
Expand Down
2 changes: 2 additions & 0 deletions test/ut/sdk/test_hpo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
bad_fields = { 'x': { 'type': 'choice', 'value': ['a', 'b'] } }
bad_type_name = { 'x': { '_type': 'choic', '_value': ['a'] } }
bad_value = { 'x': { '_type': 'choice', '_value': 'ab' } }
bad_choice_args = { 'x': { '_type': 'choice', 'value': [ 'a', object() ] } }
bad_2_args = { 'x': { '_type': 'randint', '_value': [1, 2, 3] } }
bad_3_args = { 'x': { '_type': 'quniform', '_value': [0] } }
bad_int_args = { 'x': { '_type': 'randint', '_value': [1.0, 2.0] } }
Expand All @@ -37,6 +38,7 @@ def test_hpo_utils():
assert not validate_search_space(bad_fields, raise_exception=False)
assert not validate_search_space(bad_type_name, raise_exception=False)
assert not validate_search_space(bad_value, raise_exception=False)
assert not validate_search_space(bad_choice_args, raise_exception=False)
assert not validate_search_space(bad_2_args, raise_exception=False)
assert not validate_search_space(bad_3_args, raise_exception=False)
assert not validate_search_space(bad_int_args, raise_exception=False)
Expand Down