Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENAS Check Algorithm Settings in Validate Function #1146

Merged
merged 1 commit into from
Apr 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 23 additions & 58 deletions pkg/suggestion/v1alpha3/nas/enas/AlgorithmSettings.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
def parseAlgorithmSettings(params_raw, logger):

param_standard = {
"controller_hidden_size": ['value', int, [1, 'inf']],
"controller_temperature": ['value', float, [0, 'inf']],
"controller_tanh_const": ['value', float, [0, 'inf']],
"controller_entropy_weight": ['value', float, [0.0, 'inf']],
"controller_baseline_decay": ['value', float, [0.0, 1.0]],
"controller_learning_rate": ['value', float, [0.0, 1.0]],
"controller_skip_target": ['value', float, [0.0, 1.0]],
"controller_skip_weight": ['value', float, [0.0, 'inf']],
"controller_train_steps": ['value', int, [1, 'inf']],
"controller_log_every_steps": ['value', int, [1, 'inf']],
}
algorithmSettingsValidator = {
"controller_hidden_size": [int, [1, 'inf']],
"controller_temperature": [float, [0, 'inf']],
"controller_tanh_const": [float, [0, 'inf']],
"controller_entropy_weight": [float, [0.0, 'inf']],
"controller_baseline_decay": [float, [0.0, 1.0]],
"controller_learning_rate": [float, [0.0, 1.0]],
"controller_skip_target": [float, [0.0, 1.0]],
"controller_skip_weight": [float, [0.0, 'inf']],
"controller_train_steps": [int, [1, 'inf']],
"controller_log_every_steps": [int, [1, 'inf']],
}


algorithm_settings = {
# TODO: Enable to add None values, e.g in controller_temperature parameter
def parseAlgorithmSettings(settings_raw):

algorithm_settings_default = {
"controller_hidden_size": 64,
"controller_temperature": 5.,
"controller_tanh_const": 2.25,
Expand All @@ -26,48 +29,10 @@ def parseAlgorithmSettings(params_raw, logger):
"controller_log_every_steps": 10,
}

# TODO: Enable to add None values, e.g in controller_temperature parameter
# TODO: Delete it and add to the Validation part
def checktype(param_name, param_value, check_mode, supposed_type, supposed_range=None, logger=None):
correct = True

try:
converted_value = supposed_type(param_value)
except:
correct = False
logger.info("Parameter {} is of wrong type. Set back to default value {}"
.format(param_name, algorithm_settings[param_name]))

if correct and check_mode == 'value':
if (
(supposed_range[0] != '-inf' and
((supposed_type == float and converted_value <= supposed_range[0]) or
converted_value < supposed_range[0])
) or
(supposed_range[1] != 'inf' and converted_value > supposed_range[1])
):
correct = False
logger.info("Parameter {} out of range. Set back to default value {}"
.format(param_name, algorithm_settings[param_name]))

elif correct and check_mode == 'categorical':
if converted_value not in supposed_range:
correct = False
logger.info("Parameter {} out of range. Set back to default value {}"
.format(param_name, algorithm_settings[param_name]))

if correct:
algorithm_settings[param_name] = converted_value

for param in params_raw:
if param.name in algorithm_settings.keys():
checktype(param.name,
param.value,
param_standard[param.name][0], # mode
param_standard[param.name][1], # type
param_standard[param.name][2], # range
logger)
else:
logger.info("Unknown Parameter name: {}".format(param.name))
for setting in settings_raw:
s_name = setting.name
s_value = setting.value
s_type = algorithmSettingsValidator[s_name][0]
algorithm_settings_default[s_name] = s_type(s_value)

return algorithm_settings
return algorithm_settings_default
33 changes: 28 additions & 5 deletions pkg/suggestion/v1alpha3/nas/enas_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pkg.apis.manager.v1alpha3.python import api_pb2_grpc
from pkg.suggestion.v1alpha3.nas.enas.Controller import Controller
from pkg.suggestion.v1alpha3.nas.enas.Operation import SearchSpace
from pkg.suggestion.v1alpha3.nas.enas.AlgorithmSettings import parseAlgorithmSettings
from pkg.suggestion.v1alpha3.nas.enas.AlgorithmSettings import parseAlgorithmSettings, algorithmSettingsValidator
from pkg.suggestion.v1alpha3.base_health_service import HealthServicer


Expand Down Expand Up @@ -66,9 +66,9 @@ def _get_experiment_param(self):

self.print_search_space()

# Get Experiment Parameters
params_raw = self.experiment.spec.algorithm.algorithm_setting
self.algorithm_settings = parseAlgorithmSettings(params_raw, self.logger)
# Get Experiment Algorithm Settings
settings_raw = self.experiment.spec.algorithm.algorithm_setting
self.algorithm_settings = parseAlgorithmSettings(settings_raw)

self.print_algorithm_settings()

Expand Down Expand Up @@ -150,7 +150,6 @@ def ValidateAlgorithmSettings(self, request, context):
self.logger.info("Validate Algorithm Settings start")
graph_config = request.experiment.spec.nas_config.graph_config

# TODO: Refactor this since we validate it in Katib Controller
# Validate GraphConfig
# Check InputSize
if not graph_config.input_sizes:
Expand Down Expand Up @@ -202,6 +201,30 @@ def ValidateAlgorithmSettings(self, request, context):
if parameter.parameter_type == api_pb2.DOUBLE and (not parameter.feasible_space.step or float(parameter.feasible_space.step) <= 0):
return self.SetValidateContextError(context, "Step parameter should be > 0 in ParameterConfig.feasibleSpace:\n{}".format(parameter))

# Validate Algorithm Settings
settings_raw = request.experiment.spec.algorithm.algorithm_setting
for setting in settings_raw:
if setting.name in algorithmSettingsValidator.keys():
setting_type = algorithmSettingsValidator[setting.name][0]
setting_range = algorithmSettingsValidator[setting.name][1]
try:
converted_value = setting_type(setting.value)
except:
return self.SetValidateContextError(context, "Algorithm Setting {} must be {} type".format(setting.name, setting_type.__name__))

if setting_type == float:
if converted_value <= setting_range[0] or (setting_range[1] != 'inf' and converted_value > setting_range[1]):
return self.SetValidateContextError(context, "Algorithm Setting {}: {} with {} type must be in range ({}, {}]".format(
setting.name, converted_value, setting_type.__name__, setting_range[0], setting_range[1]
))

elif converted_value < setting_range[0]:
return self.SetValidateContextError(context, "Algorithm Setting {}: {} with {} type must be in range [{}, {})".format(
setting.name, converted_value, setting_type.__name__, setting_range[0], setting_range[1]
))
else:
return self.SetValidateContextError(context, "Unknown Algorithm Setting name: {}".format(setting.name))

self.logger.info("All Experiment Settings are Valid")
return api_pb2.ValidateAlgorithmSettingsReply()

Expand Down