Skip to content

Commit

Permalink
Add supported data types for PBT tuner (microsoft#2271)
Browse files Browse the repository at this point in the history
  • Loading branch information
RayMeng8 authored Apr 7, 2020
1 parent c61700f commit d2c5777
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 17 deletions.
Empty file.
4 changes: 2 additions & 2 deletions examples/trials/mnist-pbt-tuner-pytorch/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ def get_params():
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--epochs', type=int, default=1, metavar='N',
help='number of epochs to train (default: 1)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--no_cuda', action='store_true', default=False,
Expand Down
124 changes: 109 additions & 15 deletions src/sdk/pynni/nni/pbt_tuner/pbt_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,54 @@
import copy
import logging
import os
import random
import numpy as np

import nni
import nni.parameter_expressions
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward, split_index, json2parameter, json2space


logger = logging.getLogger('pbt_tuner_AutoML')


def exploit_and_explore(bot_trial_info, top_trial_info, factors, epoch, search_space):
def perturbation(hyperparameter_type, value, resample_probablity, uv, ub, lv, lb, random_state):
    """
    Perturb a single hyperparameter: either resample it from the search
    space, or nudge it to a neighbouring value clipped at the bounds.

    Parameters
    ----------
    hyperparameter_type : str
        type of hyperparameter (e.g. "choice", "uniform", ...); used as the
        name of the sampling function in ``nni.parameter_expressions``
    value : list
        parameters for sampling the hyperparameter (the search-space
        ``_value`` entry)
    resample_probablity : float
        probability of resampling instead of perturbing
        (NOTE: parameter name keeps this spelling for backward compatibility
        with existing positional/keyword callers)
    uv : float or int
        upper value after perturbation
    ub : float or int
        upper bound
    lv : float or int
        lower value after perturbation
    lb : float or int
        lower bound
    random_state : RandomState
        numpy random state passed to the resampling functions

    Returns
    -------
    float or int
        the perturbed (or resampled) value; for "choice" the *index* of the
        chosen value is returned, not the value itself
    """
    # With probability `resample_probablity`, draw a fresh sample from the
    # search space instead of perturbing the current value.
    if random.random() < resample_probablity:
        if hyperparameter_type == "choice":
            # choice() returns the sampled value; convert back to its index
            return value.index(nni.parameter_expressions.choice(value, random_state))
        return getattr(nni.parameter_expressions, hyperparameter_type)(*(value + [random_state]))
    # Otherwise perturb: move up or down with equal probability, clipping
    # the result to the search-space bound on that side.
    if random.random() > 0.5:
        return min(uv, ub)
    return max(lv, lb)


def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probability, epoch, search_space):
"""
Replace checkpoint of bot_trial with top, and perturb hyperparameters
Expand All @@ -24,8 +61,10 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factors, epoch, search_s
bottom model whose parameters should be replaced
top_trial_info : TrialInfo
better model
factors : float
factors for perturbation
factor : float
factor for perturbation
resample_probability : float
probability for resampling
epoch : int
step of PBTTuner
search_space : dict
Expand All @@ -34,21 +73,72 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factors, epoch, search_s
bot_checkpoint_dir = bot_trial_info.checkpoint_dir
top_hyper_parameters = top_trial_info.hyper_parameters
hyper_parameters = copy.deepcopy(top_hyper_parameters)
# TODO think about different type of hyperparameters for 1.perturbation 2.within search space
random_state = np.random.RandomState()
for key in hyper_parameters.keys():
hyper_parameter = hyper_parameters[key]
if key == 'load_checkpoint_dir':
hyper_parameters[key] = hyper_parameters['save_checkpoint_dir']
continue
elif key == 'save_checkpoint_dir':
hyper_parameters[key] = os.path.join(bot_checkpoint_dir, str(epoch))
elif isinstance(hyper_parameters[key], float):
perturb = np.random.choice(factors)
val = hyper_parameters[key] * perturb
continue
elif search_space[key]["_type"] == "choice":
choices = search_space[key]["_value"]
ub, uv = len(choices) - 1, choices.index(hyper_parameter["_value"]) + 1
lb, lv = 0, choices.index(hyper_parameter["_value"]) - 1
elif search_space[key]["_type"] == "randint":
lb, ub = search_space[key]["_value"][:2]
if search_space[key]["_type"] in ("uniform", "normal"):
val = np.clip(val, lb, ub).item()
hyper_parameters[key] = val
ub -= 1
uv = hyper_parameter + 1
lv = hyper_parameter - 1
elif search_space[key]["_type"] == "uniform":
lb, ub = search_space[key]["_value"][:2]
perturb = (ub - lb) * factor
uv = hyper_parameter + perturb
lv = hyper_parameter - perturb
elif search_space[key]["_type"] == "quniform":
lb, ub, q = search_space[key]["_value"][:3]
multi = round(hyper_parameter / q)
uv = (multi + 1) * q
lv = (multi - 1) * q
elif search_space[key]["_type"] == "loguniform":
lb, ub = search_space[key]["_value"][:2]
perturb = (np.log(ub) - np.log(lb)) * factor
uv = np.exp(min(np.log(hyper_parameter) + perturb, np.log(ub)))
lv = np.exp(max(np.log(hyper_parameter) - perturb, np.log(lb)))
elif search_space[key]["_type"] == "qloguniform":
lb, ub, q = search_space[key]["_value"][:3]
multi = round(hyper_parameter / q)
uv = (multi + 1) * q
lv = (multi - 1) * q
elif search_space[key]["_type"] == "normal":
sigma = search_space[key]["_value"][1]
perturb = sigma * factor
uv = ub = hyper_parameter + perturb
lv = lb = hyper_parameter - perturb
elif search_space[key]["_type"] == "qnormal":
q = search_space[key]["_value"][2]
uv = ub = hyper_parameter + q
lv = lb = hyper_parameter - q
elif search_space[key]["_type"] == "lognormal":
sigma = search_space[key]["_value"][1]
perturb = sigma * factor
uv = ub = np.exp(np.log(hyper_parameter) + perturb)
lv = lb = np.exp(np.log(hyper_parameter) - perturb)
elif search_space[key]["_type"] == "qlognormal":
q = search_space[key]["_value"][2]
uv = ub = hyper_parameter + q
lv, lb = hyper_parameter - q, 1E-10
else:
logger.warning("Illegal type to perturb: %s", search_space[key]["_type"])
continue
if search_space[key]["_type"] == "choice":
idx = perturbation(search_space[key]["_type"], search_space[key]["_value"],
resample_probability, uv, ub, lv, lb, random_state)
hyper_parameters[key] = {'_index': idx, '_value': choices[idx]}
else:
hyper_parameters[key] = perturbation(search_space[key]["_type"], search_space[key]["_value"],
resample_probability, uv, ub, lv, lb, random_state)
bot_trial_info.hyper_parameters = hyper_parameters
bot_trial_info.clean_id()

Expand All @@ -70,7 +160,8 @@ def clean_id(self):


class PBTTuner(Tuner):
def __init__(self, optimize_mode="maximize", all_checkpoint_dir=None, population_size=10, factors=(1.2, 0.8), fraction=0.2):
def __init__(self, optimize_mode="maximize", all_checkpoint_dir=None, population_size=10, factor=0.2,
resample_probability=0.25, fraction=0.2):
"""
Initialization
Expand All @@ -82,8 +173,10 @@ def __init__(self, optimize_mode="maximize", all_checkpoint_dir=None, population
directory to store training model checkpoint
population_size : int
number of trials for each epoch
factors : tuple
factors for perturbation
factor : float
factor for perturbation
resample_probability : float
probability for resampling
fraction : float
fraction for selecting bottom and top trials
"""
Expand All @@ -93,7 +186,8 @@ def __init__(self, optimize_mode="maximize", all_checkpoint_dir=None, population
logger.info("Checkpoint dir is set to %s by default.", all_checkpoint_dir)
self.all_checkpoint_dir = all_checkpoint_dir
self.population_size = population_size
self.factors = factors
self.factor = factor
self.resample_probability = resample_probability
self.fraction = fraction
# defined in trial code
#self.perturbation_interval = perturbation_interval
Expand Down Expand Up @@ -237,7 +331,7 @@ def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
bottoms = self.finished[self.finished_trials - cutoff:]
for bottom in bottoms:
top = np.random.choice(tops)
exploit_and_explore(bottom, top, self.factors, self.epoch, self.searchspace_json)
exploit_and_explore(bottom, top, self.factor, self.resample_probability, self.epoch, self.searchspace_json)
for trial in self.finished:
if trial not in bottoms:
trial.clean_id()
Expand Down

0 comments on commit d2c5777

Please sign in to comment.