Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
add ut
Browse files Browse the repository at this point in the history
  • Loading branch information
tabVersion committed Sep 22, 2020
1 parent b838829 commit 11f375e
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/sdk/pynni/nni/evo_nas_tuner/evo_nas_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _random_model(self):

def _mutate_model(self, model):
new_individual = copy.deepcopy(model.parameters)
mutate_key = random.choice(new_individual.keys())
mutate_key = random.choice(list(new_individual.keys()))
mutate_val = self.search_space[mutate_key]
if mutate_val['_type'] == 'layer_choice':
idx = random.randint(0, len(mutate_val['_value']) - 1)
Expand Down
26 changes: 26 additions & 0 deletions src/sdk/pynni/tests/assets/classic_nas_search_space.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
{
"first_conv": {
"_type": "layer_choice",
"_value": [
"conv5x5",
"conv3x3"
]
},
"mid_conv": {
"_type": "layer_choice",
"_value": [
"0",
"1"
]
},
"skip": {
"_type": "input_choice",
"_value": {
"candidates": [
"",
""
],
"n_chosen": 1
}
}
}
90 changes: 90 additions & 0 deletions src/sdk/pynni/tests/test_classic_nas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import glob
import json
import logging
import os
import random
import shutil
import sys
from collections import deque
from unittest import TestCase, main

from nni.tuner import Tuner
from nni.evo_nas_tuner.evo_nas_tuner import EvoNasTuner

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('test_tuner')


class ClassicNasTestCase(TestCase):
def setUp(self):
self.test_round = 3
self.params_each_round = 50
self.exhaustive = False

def check_range(self, generated_params, search_space):
for params in generated_params:
for k in params:
v = params[k]
items = search_space[k]
if items['_type'] == 'layer_choice':
self.assertIn(v['_value'], items['_value'])
elif items['_type'] == 'input_choice':
for choice in v['_value']:
self.assertIn(choice, items['_value']['candidates'])
else:
raise KeyError

def send_trial_result(self, tuner, parameter_id, parameters, metrics):
tuner.receive_trial_result(parameter_id, parameters, metrics)
tuner.trial_end(parameter_id, True)

def search_space_test_one(self, tuner_factory, search_space):
tuner = tuner_factory()
self.assertIsInstance(tuner, Tuner)
tuner.update_search_space(search_space)
for i in range(self.test_round):
queue = deque()
parameters = tuner.generate_multiple_parameters(list(range(i * self.params_each_round,
(i + 1) * self.params_each_round)),
st_callback=self.send_trial_callback(queue))
logger.debug(parameters)
self.check_range(parameters, search_space)
for k in range(min(len(parameters), self.params_each_round)):
self.send_trial_result(tuner, self.params_each_round * i + k, parameters[k], random.uniform(-100, 100))
while queue:
id_, params = queue.popleft()
self.check_range([params], search_space)
self.send_trial_result(tuner, id_, params, random.uniform(-100, 100))
if not parameters and not self.exhaustive:
raise ValueError("No parameters generated")

def send_trial_callback(self, param_queue):
def receive(*args):
param_queue.append(tuple(args))
return receive

def search_space_test_all(self, tuner_factory):
# Since classic tuner should support only LayerChoice and InputChoice,
# ignore type and fail type are dismissed here.
with open(os.path.join(os.path.dirname(__file__), "assets/classic_nas_search_space.json"), "r") as fp:
search_space_all = json.load(fp)
full_supported_search_space = dict()
for single in search_space_all:
space = search_space_all[single]
single_search_space = {single: space}
self.search_space_test_one(tuner_factory, single_search_space)
full_supported_search_space.update(single_search_space)
logger.info("Full supported search space: %s", full_supported_search_space)
self.search_space_test_one(tuner_factory, full_supported_search_space)

def test_evo_nas_tuner(self):
tuner_fn = lambda: EvoNasTuner()
self.search_space_test_all(tuner_fn)


if __name__ == '__main__':
main()

0 comments on commit 11f375e

Please sign in to comment.