From 1b4bcdba7ee258cd676b7414e0eb398f606b255f Mon Sep 17 00:00:00 2001
From: Lee
Date: Tue, 5 Mar 2019 00:29:06 -0600
Subject: [PATCH] Optimize MetisTuner (#811)

* add different tuner config files for config_test
* change MetisTuner config test because the integration test has no lightgbm python module
* install smac package in azure-pipelines
* SMAC needs swig to be installed
* try to install swig from source code
* remove SMAC test because the dependency cannot be installed
* use sudo to install the swig
* sleep 10s to make sure the port has been released
* remove tuner test for networkmorphism because it needs more than 30s to release the tcp port
* word "down" to "done"
* add config test for Curvefitting assessor
* change file name
* fix data type mismatch bug
* optimize MetisTuner
* prettify the code
* follow the review comments
* add exploration probability
---
 .../Regression_GP/OutlierDetection.py         |  26 ++--
 src/sdk/pynni/nni/metis_tuner/metis_tuner.py  | 131 +++++++++++-------
 2 files changed, 91 insertions(+), 66 deletions(-)

diff --git a/src/sdk/pynni/nni/metis_tuner/Regression_GP/OutlierDetection.py b/src/sdk/pynni/nni/metis_tuner/Regression_GP/OutlierDetection.py
index 353c56f2b0..95a8abfdc6 100644
--- a/src/sdk/pynni/nni/metis_tuner/Regression_GP/OutlierDetection.py
+++ b/src/sdk/pynni/nni/metis_tuner/Regression_GP/OutlierDetection.py
@@ -37,13 +37,13 @@ def _outlierDetection_threaded(inputs):
     sys.stderr.write("[%s] DEBUG: Evaluating %dth of %d samples\n"\
                      % (os.path.basename(__file__), samples_idx + 1, len(samples_x)))
     outlier = None
-
+
     # Create a diagnostic regression model which removes the sample that we want to evaluate
-    diagnostic_regressor_gp = gp_create_model.createModel(\
+    diagnostic_regressor_gp = gp_create_model.create_model(\
                                 samples_x[0:samples_idx] + samples_x[samples_idx + 1:],\
                                 samples_y_aggregation[0:samples_idx] + samples_y_aggregation[samples_idx + 1:])
     mu, sigma = gp_prediction.predict(samples_x[samples_idx], diagnostic_regressor_gp['model'])
-
+
     # 2.33 is the z-score for 98% confidence level
     if abs(samples_y_aggregation[samples_idx] - mu) > (2.33 * sigma):
         outlier = {"samples_idx": samples_idx,
@@ -51,26 +51,26 @@ def _outlierDetection_threaded(inputs):
                    "expected_sigma": sigma,
                    "difference": abs(samples_y_aggregation[samples_idx] - mu) - (2.33 * sigma)}
     return outlier
-
+
 def outlierDetection_threaded(samples_x, samples_y_aggregation):
-    '''
+    '''
     Use multiple threads to detect outliers
     '''
     outliers = []
-
+
     threads_inputs = [[samples_idx, samples_x, samples_y_aggregation]\
                       for samples_idx in range(0, len(samples_x))]
     threads_pool = ThreadPool(min(4, len(threads_inputs)))
    threads_results = threads_pool.map(_outlierDetection_threaded, threads_inputs)
     threads_pool.close()
     threads_pool.join()
-
+
     for threads_result in threads_results:
         if threads_result is not None:
             outliers.append(threads_result)
         else:
             print("error here.")
-
+
     outliers = None if len(outliers) == 0 else outliers
     return outliers
@@ -79,21 +79,19 @@ def outlierDetection(samples_x, samples_y_aggregation):
     '''
     outliers = []
     for samples_idx in range(0, len(samples_x)):
-        #sys.stderr.write("[%s] DEBUG: Evaluating %d of %d samples\n"
+        #sys.stderr.write("[%s] DEBUG: Evaluating %d of %d samples\n"
         #                 \ % (os.path.basename(__file__), samples_idx + 1, len(samples_x)))
-        diagnostic_regressor_gp = gp_create_model.createModel(\
+        diagnostic_regressor_gp = gp_create_model.create_model(\
                                     samples_x[0:samples_idx] + samples_x[samples_idx + 1:],\
                                     samples_y_aggregation[0:samples_idx] + samples_y_aggregation[samples_idx + 1:])
         mu, sigma = gp_prediction.predict(samples_x[samples_idx],
-                                        diagnostic_regressor_gp['model'])
+                                          diagnostic_regressor_gp['model'])
         # 2.33 is the z-score for 98% confidence level
         if abs(samples_y_aggregation[samples_idx] - mu) > (2.33 * sigma):
             outliers.append({"samples_idx": samples_idx,
                              "expected_mu": mu,
                              "expected_sigma": sigma,
                              "difference": abs(samples_y_aggregation[samples_idx] - mu) - (2.33 * sigma)})
-
+
     outliers = None if len(outliers) == 0 else outliers
     return outliers
-
-
\ No newline at end of file
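
The outlier test in the hunks above is a leave-one-out check: fit a GP on every sample except one, and flag that sample when it lands outside the 98% confidence band of the prediction. For reference, a minimal standalone sketch of the same idea, with scikit-learn's GaussianProcessRegressor standing in for this module's gp_create_model/gp_prediction helpers (Z_98 and is_outlier are illustrative names, not part of NNI):

    # Leave-one-out outlier check, sketched with scikit-learn instead of
    # NNI's internal GP helpers; Z_98 and is_outlier are illustrative names.
    import numpy as np
    from sklearn.gaussian_process import GaussianProcessRegressor

    Z_98 = 2.33  # z-score for a 98% confidence level, as in the diff above

    def is_outlier(samples_x, samples_y, idx):
        # Fit a diagnostic model on every sample except the one under test.
        mask = np.arange(len(samples_x)) != idx
        gp = GaussianProcessRegressor().fit(np.asarray(samples_x)[mask],
                                            np.asarray(samples_y)[mask])
        mu, sigma = gp.predict(np.asarray([samples_x[idx]]), return_std=True)
        # Flag the held-out sample if it falls outside the 98% band.
        return abs(samples_y[idx] - mu[0]) > Z_98 * sigma[0]

    print(is_outlier([[0.0], [1.0], [2.0], [3.0]], [0.0, 1.1, 1.9, 9.0], 3))
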
diff --git a/src/sdk/pynni/nni/metis_tuner/metis_tuner.py b/src/sdk/pynni/nni/metis_tuner/metis_tuner.py
index d9e32cd4c3..4c743651aa 100644
--- a/src/sdk/pynni/nni/metis_tuner/metis_tuner.py
+++ b/src/sdk/pynni/nni/metis_tuner/metis_tuner.py
@@ -24,22 +24,20 @@
 import random
 import statistics
 import sys
-
-import numpy as np
-
 from enum import Enum, unique
 from multiprocessing.dummy import Pool as ThreadPool
 
-from nni.tuner import Tuner
+import numpy as np
 
-import nni.metis_tuner.lib_data as lib_data
 import nni.metis_tuner.lib_constraint_summation as lib_constraint_summation
-import nni.metis_tuner.Regression_GP.CreateModel as gp_create_model
-import nni.metis_tuner.Regression_GP.Selection as gp_selection
-import nni.metis_tuner.Regression_GP.Prediction as gp_prediction
-import nni.metis_tuner.Regression_GP.OutlierDetection as gp_outlier_detection
+import nni.metis_tuner.lib_data as lib_data
 import nni.metis_tuner.Regression_GMM.CreateModel as gmm_create_model
 import nni.metis_tuner.Regression_GMM.Selection as gmm_selection
+import nni.metis_tuner.Regression_GP.CreateModel as gp_create_model
+import nni.metis_tuner.Regression_GP.OutlierDetection as gp_outlier_detection
+import nni.metis_tuner.Regression_GP.Prediction as gp_prediction
+import nni.metis_tuner.Regression_GP.Selection as gp_selection
+from nni.tuner import Tuner
 
 logger = logging.getLogger("Metis_Tuner_AutoML")
@@ -67,33 +65,37 @@ class MetisTuner(Tuner):
     """
 
     def __init__(self, optimize_mode="maximize", no_resampling=True, no_candidates=True,
-                 selection_num_starting_points=10, cold_start_num=10):
+                 selection_num_starting_points=600, cold_start_num=10, exploration_probability=0.1):
         """
         Parameters
         ----------
         optimize_mode : str
             optimize_mode is a string; it can be either "maximize" or "minimize"
-
+
         no_resampling : bool
             True or False. Should Metis consider re-sampling as part of the search strategy?
             If you are confident that the training dataset is noise-free, then you do not need re-sampling.
-
+
         no_candidates : bool
             True or False. Should Metis suggest parameters for the next benchmark?
             If you do not plan to do more benchmarks, Metis can skip this step.
-
+
         selection_num_starting_points : int
             How many times Metis should try to find the global optimum in the search space.
             The higher the number, the longer it takes to output the solution.
-
+
         cold_start_num : int
             Metis needs some trial results for a cold start. When the number of trial results is
             less than cold_start_num, Metis will randomly sample hyper-parameters for the trials.
+
+        exploration_probability : float
+            The probability that Metis selects parameters from exploration instead of exploitation.
         """
         self.samples_x = []
         self.samples_y = []
         self.samples_y_aggregation = []
+        self.history_parameters = set()
         self.space = None
         self.no_resampling = no_resampling
         self.no_candidates = no_candidates
@@ -101,6 +103,7 @@ def __init__(self, optimize_mode="maximize", no_resampling=True, no_candidates=T
         self.key_order = []
         self.cold_start_num = cold_start_num
         self.selection_num_starting_points = selection_num_starting_points
+        self.exploration_probability = exploration_probability
         self.minimize_constraints_fun = None
         self.minimize_starting_points = None
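
The new exploration_probability knob amounts to an epsilon-greedy rule layered on top of Metis's candidate machinery: with that probability, or whenever the predicted optimum has already been suggested, the tuner returns the exploration/exploitation candidate instead of repeating the current optimum (STEP 7 later in this patch). A toy sketch of just that decision, assuming hashable configurations; choose and history are illustrative names, not part of the patch:

    # Epsilon-greedy fallback sketched in isolation.
    import random

    def choose(optimum, candidate, history, exploration_probability=0.1):
        # Explore when the optimum was already suggested, or with probability eps.
        if optimum in history or random.uniform(0, 1) <= exploration_probability:
            return candidate  # explore
        return optimum        # exploit the predicted optimum

    history = {("lr", 0.01)}
    # The optimum was already suggested, so the candidate is returned.
    print(choose(("lr", 0.01), ("lr", 0.1), history))
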
""" self.samples_x = [] self.samples_y = [] self.samples_y_aggregation = [] + self.history_parameters = set() self.space = None self.no_resampling = no_resampling self.no_candidates = no_candidates @@ -101,6 +103,7 @@ def __init__(self, optimize_mode="maximize", no_resampling=True, no_candidates=T self.key_order = [] self.cold_start_num = cold_start_num self.selection_num_starting_points = selection_num_starting_points + self.exploration_probability = exploration_probability self.minimize_constraints_fun = None self.minimize_starting_points = None @@ -128,7 +131,7 @@ def update_search_space(self, search_space): except Exception as ex: logger.exception(ex) raise RuntimeError("The format search space contains \ - some key that didn't define in key_order.") + some key that didn't define in key_order." ) if key_type == 'quniform': if key_range[2] == 1: @@ -191,7 +194,7 @@ def generate_parameters(self, parameter_id): Parameters ---------- parameter_id : int - + Returns ------- result : dict @@ -200,13 +203,15 @@ def generate_parameters(self, parameter_id): init_parameter = _rand_init(self.x_bounds, self.x_types, 1)[0] results = self._pack_output(init_parameter) else: + self.minimize_starting_points = _rand_init(self.x_bounds, self.x_types, \ + self.selection_num_starting_points) results = self._selection(self.samples_x, self.samples_y_aggregation, self.samples_y, self.x_bounds, self.x_types, threshold_samplessize_resampling=(None if self.no_resampling is True else 50), no_candidates=self.no_candidates, minimize_starting_points=self.minimize_starting_points, minimize_constraints_fun=self.minimize_constraints_fun) - + logger.info("Generate paramageters:\n" + str(results)) return results @@ -245,7 +250,7 @@ def receive_trial_result(self, parameter_id, parameters, value): # calculate y aggregation median = get_median(temp_y) - self.samples_y_aggregation[idx] = median + self.samples_y_aggregation[idx] = [median] else: self.samples_x.append(sample_x) self.samples_y.append([value]) @@ -264,17 +269,21 @@ def _selection(self, samples_x, samples_y_aggregation, samples_y, candidates = [] samples_size_all = sum([len(i) for i in samples_y]) samples_size_unique = len(samples_y) - + # ===== STEP 1: Compute the current optimum ===== #sys.stderr.write("[%s] Predicting the optimal configuration from the current training dataset...\n" % (os.path.basename(__file__))) gp_model = gp_create_model.create_model(samples_x, samples_y_aggregation) - lm_current = gp_selection.selection("lm", samples_y_aggregation, x_bounds, - x_types, gp_model['model'], - minimize_starting_points, - minimize_constraints_fun=minimize_constraints_fun) + lm_current = gp_selection.selection( + "lm", + samples_y_aggregation, + x_bounds, + x_types, + gp_model['model'], + minimize_starting_points, + minimize_constraints_fun=minimize_constraints_fun) if not lm_current: return None - + if no_candidates is False: candidates.append({'hyperparameter': lm_current['hyperparameter'], 'expected_mu': lm_current['expected_mu'], @@ -284,10 +293,14 @@ def _selection(self, samples_x, samples_y_aggregation, samples_y, # ===== STEP 2: Get recommended configurations for exploration ===== #sys.stderr.write("[%s] Getting candidates for exploration...\n" #% \(os.path.basename(__file__))) - results_exploration = gp_selection.selection("lc", samples_y_aggregation, - x_bounds, x_types, gp_model['model'], - minimize_starting_points, - minimize_constraints_fun=minimize_constraints_fun) + results_exploration = gp_selection.selection( + "lc", + samples_y_aggregation, + 
@@ -308,12 +321,13 @@ def _selection(self, samples_x, samples_y_aggregation, samples_y,
                 print("Getting candidates for exploitation...\n")
                 try:
                     gmm = gmm_create_model.create_model(samples_x, samples_y_aggregation)
-                    results_exploitation = gmm_selection.selection(x_bounds,
-                                                                   x_types,
-                                                                   gmm['clusteringmodel_good'],
-                                                                   gmm['clusteringmodel_bad'],
-                                                                   minimize_starting_points,
-                                                                   minimize_constraints_fun=minimize_constraints_fun)
+                    results_exploitation = gmm_selection.selection(
+                        x_bounds,
+                        x_types,
+                        gmm['clusteringmodel_good'],
+                        gmm['clusteringmodel_bad'],
+                        minimize_starting_points,
+                        minimize_constraints_fun=minimize_constraints_fun)
 
                     if results_exploitation is not None:
                         if _num_past_samples(results_exploitation['hyperparameter'], samples_x, samples_y) == 0:
@@ -326,9 +340,9 @@ def _selection(self, samples_x, samples_y_aggregation, samples_y,
                         logger.info("DEBUG: No suitable exploitation_gmm candidates were found\n")
 
                 except ValueError as exception:
-                    # The exception: ValueError: Fitting the mixture model failed
-                    # because some components have ill-defined empirical covariance
-                    # (for instance caused by singleton or collapsed samples).
+                    # The exception: ValueError: Fitting the mixture model failed
+                    # because some components have ill-defined empirical covariance
+                    # (for instance caused by singleton or collapsed samples).
                     # Try to decrease the number of components, or increase reg_covar.
                     logger.info("DEBUG: No suitable exploitation_gmm candidates were found due to exception.")
                     logger.info(exception)
@@ -340,7 +354,6 @@ def _selection(self, samples_x, samples_y_aggregation, samples_y,
                 results_outliers = gp_outlier_detection.outlierDetection_threaded(samples_x,
                                                                                   samples_y_aggregation)
                 if results_outliers is not None:
-                    #temp = len(candidates)
                     for results_outlier in results_outliers:
                         if _num_past_samples(samples_x[results_outlier['samples_idx']], samples_x, samples_y) < max_resampling_per_x:
@@ -357,7 +370,10 @@ def _selection(self, samples_x, samples_y_aggregation, samples_y,
                 logger.info("Evaluating information gain of %d candidates...\n")
                 next_improvement = 0
 
-                threads_inputs = [[candidate, samples_x, samples_y, x_bounds, x_types, minimize_constraints_fun, minimize_starting_points] for candidate in candidates]
+                threads_inputs = [[
+                    candidate, samples_x, samples_y, x_bounds, x_types,
+                    minimize_constraints_fun, minimize_starting_points
+                ] for candidate in candidates]
                 threads_pool = ThreadPool(4)
                 # Evaluate what would happen if we actually sample each candidate
                 threads_results = threads_pool.map(_calculate_lowest_mu_threaded, threads_inputs)
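
The information-gain step above scores each candidate by the optimum the model would predict if that candidate were actually sampled, fanning the work out over a small thread pool. A reduced sketch of that scoring loop, where expected_lowest_mu is a stand-in for the refit-and-reselect work _calculate_lowest_mu_threaded actually does:

    # Candidate scoring reduced to its skeleton; expected_lowest_mu is a
    # stand-in for refitting the GP with the candidate and re-running "lm".
    from multiprocessing.dummy import Pool as ThreadPool

    def expected_lowest_mu(candidate):
        return candidate["expected_mu"]  # placeholder for the real refit

    candidates = [{"expected_mu": 0.42}, {"expected_mu": 0.17}, {"expected_mu": 0.58}]
    with ThreadPool(4) as pool:
        scores = pool.map(expected_lowest_mu, candidates)
    # Keep the candidate that promises the lowest expected optimum.
    print(candidates[scores.index(min(scores))])
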
@@ -368,21 +384,23 @@ def _selection(self, samples_x, samples_y_aggregation, samples_y,
                     if threads_result['expected_lowest_mu'] < lm_current['expected_mu']:
                         # Information gain
                         temp_improvement = threads_result['expected_lowest_mu'] - lm_current['expected_mu']
-
+
                         if next_improvement > temp_improvement:
-                            # logger.info("DEBUG: \"next_candidate\" changed: \
-                            # lowest mu might reduce from %f (%s) to %f (%s), %s\n" %\
-                            # lm_current['expected_mu'], str(lm_current['hyperparameter']),\
-                            # threads_result['expected_lowest_mu'],\
-                            # str(threads_result['candidate']['hyperparameter']),\
-                            # threads_result['candidate']['reason'])
+                            logger.info("DEBUG: \"next_candidate\" changed: \
+                                        lowest mu might reduce from %f (%s) to %f (%s), %s\n" %\
+                                        (lm_current['expected_mu'], str(lm_current['hyperparameter']),\
+                                         threads_result['expected_lowest_mu'],\
+                                         str(threads_result['candidate']['hyperparameter']),\
+                                         threads_result['candidate']['reason']))
                             next_improvement = temp_improvement
                             next_candidate = threads_result['candidate']
             else:
                 # ===== STEP 6: If we have no candidates, randomly pick one =====
-                logger.info("DEBUG: No candidates from exploration, exploitation,\
-                            and resampling. We will random a candidate for next_candidate\n")
+                logger.info(
+                    "DEBUG: No candidates from exploration, exploitation,\
+                    and resampling. We will randomly pick a candidate for next_candidate\n"
+                )
 
                 next_candidate = _rand_with_constraints(x_bounds, x_types) \
                     if minimize_starting_points is None else minimize_starting_points[0]
@@ -391,7 +409,12 @@ def _selection(self, samples_x, samples_y_aggregation, samples_y,
             next_candidate = {'hyperparameter': next_candidate, 'reason': "random",
                               'expected_mu': expected_mu, 'expected_sigma': expected_sigma}
 
+        # ===== STEP 7: Take the next candidate as an exploration step if the current optimum was already suggested, or the random draw falls below the exploration probability (str() keeps the parameter dict hashable for the set) =====
         outputs = self._pack_output(lm_current['hyperparameter'])
+        ap = random.uniform(0, 1)
+        if str(outputs) in self.history_parameters or ap <= self.exploration_probability:
+            outputs = self._pack_output(next_candidate['hyperparameter'])
+        self.history_parameters.add(str(outputs))
         return outputs
@@ -437,10 +460,14 @@ def _calculate_lowest_mu_threaded(inputs):
     # Aggregate multiple observations of the same sampling points
     temp_y_aggregation = [statistics.median(temp_sample_y) for temp_sample_y in temp_samples_y]
     temp_gp = gp_create_model.create_model(temp_samples_x, temp_y_aggregation)
-    temp_results = gp_selection.selection("lm", temp_y_aggregation,
-                                          x_bounds, x_types, temp_gp['model'],
-                                          minimize_starting_points,
-                                          minimize_constraints_fun=minimize_constraints_fun)
+    temp_results = gp_selection.selection(
+        "lm",
+        temp_y_aggregation,
+        x_bounds,
+        x_types,
+        temp_gp['model'],
+        minimize_starting_points,
+        minimize_constraints_fun=minimize_constraints_fun)
 
     if outputs["expected_lowest_mu"] is None or outputs["expected_lowest_mu"] > temp_results['expected_mu']:
         outputs["expected_lowest_mu"] = temp_results['expected_mu']
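
End to end, the user-visible change is two constructor arguments: selection_num_starting_points now defaults to 600 random restarts, drawn fresh on every generate_parameters call, and exploration_probability defaults to 0.1. A usage sketch under the assumption of an installed nni package; the search space here is illustrative:

    # Instantiating the tuner with the defaults introduced by this patch.
    from nni.metis_tuner.metis_tuner import MetisTuner

    tuner = MetisTuner(optimize_mode="maximize",
                       selection_num_starting_points=600,
                       exploration_probability=0.1)
    tuner.update_search_space({"lr": {"_type": "uniform", "_value": [0.0001, 0.1]}})
    # With no trial results yet, Metis is in its cold-start phase and
    # samples the first hyper-parameters randomly.
    print(tuner.generate_parameters(parameter_id=0))
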