Skip to content

Commit

Permalink
Fix Optuna Validation for CMA-ES (#2240)
Browse files Browse the repository at this point in the history
* Fix Optuna Validation for CMA-ES

* Fix Optuna test
  • Loading branch information
andreyvelich authored Nov 2, 2023
1 parent d2e311f commit 700e64e
Show file tree
Hide file tree
Showing 3 changed files with 390 additions and 132 deletions.
2 changes: 1 addition & 1 deletion cmd/suggestion/optuna/v1beta1/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
grpcio>=1.41.1
protobuf>=3.19.5, <=3.20.3
googleapis-common-protos==1.53.0
optuna>=3.0.0
optuna==3.3.0
99 changes: 70 additions & 29 deletions pkg/suggestion/v1beta1/optuna/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@


class OptunaService(api_pb2_grpc.SuggestionServicer, HealthServicer):

def __init__(self):
super(OptunaService, self).__init__()
self.lock = threading.Lock()
Expand All @@ -39,23 +38,29 @@ def GetSuggestions(self, request, context):
Main function to provide suggestion.
"""
with self.lock:
name, config = OptimizerConfiguration.convert_algorithm_spec(request.experiment.spec.algorithm)
name, config = OptimizerConfiguration.convert_algorithm_spec(
request.experiment.spec.algorithm
)
if self.base_service is None:
search_space = HyperParameterSearchSpace.convert(request.experiment)
self.base_service = BaseOptunaService(
algorithm_name=name,
algorithm_config=config,
search_space=search_space)
search_space=search_space,
)

trials = Trial.convert(request.trials)
list_of_assignments = self.base_service.get_suggestions(trials, request.current_request_number)
list_of_assignments = self.base_service.get_suggestions(
trials, request.current_request_number
)
return api_pb2.GetSuggestionsReply(
parameter_assignments=Assignment.generate(list_of_assignments)
)

def ValidateAlgorithmSettings(self, request, context):
is_valid, message = OptimizerConfiguration.validate_algorithm_spec(
request.experiment)
request.experiment
)
if not is_valid:
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
context.set_details(message)
Expand Down Expand Up @@ -88,7 +93,7 @@ class OptimizerConfiguration(object):
},
"grid": {
"seed": lambda x: int(x),
}
},
}

@classmethod
Expand Down Expand Up @@ -117,11 +122,12 @@ def validate_algorithm_spec(cls, experiment):
algorithm_spec = experiment.spec.algorithm
algorithm_name = algorithm_spec.algorithm_name
algorithm_settings = algorithm_spec.algorithm_settings
parameters = experiment.spec.parameter_specs.parameters

if algorithm_name == "tpe" or algorithm_name == "multivariate-tpe":
return cls._validate_tpe_setting(algorithm_spec)
elif algorithm_name == "cmaes":
return cls._validate_cmaes_setting(algorithm_settings)
return cls._validate_cmaes_setting(algorithm_settings, parameters)
elif algorithm_name == "random":
return cls._validate_random_setting(algorithm_settings)
elif algorithm_name == "grid":
Expand All @@ -138,37 +144,58 @@ def _validate_tpe_setting(cls, algorithm_spec):
try:
if s.name in ["n_startup_trials", "n_ei_candidates", "random_state"]:
if not int(s.value) >= 0:
return False, "{} should be greate or equal than zero".format(s.name)
return False, "{} should be greate or equal than zero".format(
s.name
)
else:
return False, "unknown setting {} for algorithm {}".format(s.name, algorithm_name)
return False, "unknown setting {} for algorithm {}".format(
s.name, algorithm_name
)
except Exception as e:
return False, "failed to validate {name}({value}): {exception}".format(name=s.name, value=s.value,
exception=e)
return False, "failed to validate {name}({value}): {exception}".format(
name=s.name, value=s.value, exception=e
)

return True, ""

@classmethod
def _validate_cmaes_setting(cls, algorithm_settings):
if len(algorithm_settings) < 2:
return False, "cmaes only supports two or more dimensional continuous search space."

def _validate_cmaes_setting(cls, algorithm_settings, parameters):
for s in algorithm_settings:
try:
if s.name == "restart_strategy":
if s.value not in ["ipop", "None", "none"]:
return False, "restart_strategy {} is not supported in CMAES optimization".format(s.value)
return (
False,
"restart_strategy {} is not supported in CMAES optimization".format(
s.value
),
)
elif s.name == "sigma":
if not float(s.value) >= 0:
return False, "sigma should be greate or equal than zero"
elif s.name == "random_state":
if not int(s.value) >= 0:
return False, "random_state should be greate or equal than zero"
else:
return False, "unknown setting {} for algorithm cmaes".format(s.name)
return False, "unknown setting {} for algorithm cmaes".format(
s.name
)

except Exception as e:
return False, "failed to validate {name}({value}): {exception}".format(name=s.name, value=s.value,
exception=e)
return False, "failed to validate {name}({value}): {exception}".format(
name=s.name, value=s.value, exception=e
)

cnt = 0
for p in parameters:
if p.parameter_type == api_pb2.DOUBLE or p.parameter_type == api_pb2.INT:
cnt += 1
if cnt < 2:
return (
False,
"cmaes only supports two or more dimensional continuous search space.",
)

return True, ""

@classmethod
Expand All @@ -179,11 +206,14 @@ def _validate_random_setting(cls, algorithm_settings):
if not int(s.value) >= 0:
return False, ""
else:
return False, "unknown setting {} for algorithm random".format(s.name)
return False, "unknown setting {} for algorithm random".format(
s.name
)

except Exception as e:
return False, "failed to validate {name}({value}): {exception}".format(name=s.name, value=s.value,
exception=e)
return False, "failed to validate {name}({value}): {exception}".format(
name=s.name, value=s.value, exception=e
)

return True, ""

Expand All @@ -201,19 +231,30 @@ def _validate_grid_setting(cls, experiment):
return False, "unknown setting {} for algorithm grid".format(s.name)

except Exception as e:
return False, "failed to validate {name}({value}): {exception}".format(name=s.name, value=s.value,
exception=e)
return False, "failed to validate {name}({value}): {exception}".format(
name=s.name, value=s.value, exception=e
)

try:
combinations = HyperParameterSearchSpace.convert_to_combinations(search_space)
combinations = HyperParameterSearchSpace.convert_to_combinations(
search_space
)
num_combinations = len(list(itertools.product(*combinations.values())))
max_trial_count = experiment.spec.max_trial_count
if max_trial_count > num_combinations:
return False, "Max Trial Count: {max_trial} > all possible search combinations: {combinations}".\
format(max_trial=max_trial_count, combinations=num_combinations)
return (
False,
"Max Trial Count: {max_trial} > all possible search combinations: {combinations}".format(
max_trial=max_trial_count, combinations=num_combinations
),
)

except Exception as e:
return False, "failed to validate parameters({parameters}): {exception}".\
format(parameters=search_space.params, exception=e)
return (
False,
"failed to validate parameters({parameters}): {exception}".format(
parameters=search_space.params, exception=e
),
)

return True, ""
Loading

0 comments on commit 700e64e

Please sign in to comment.