From da7b6bc319b8ac0831b0d6d95c7b737f54bc867a Mon Sep 17 00:00:00 2001 From: avelichk Date: Fri, 3 Apr 2020 13:55:24 +0100 Subject: [PATCH] Support step in int parameter for hyperopt and chocolate --- .../v1alpha3/chocolate/base_chocolate_service.py | 2 +- .../v1alpha3/hyperopt/base_hyperopt_service.py | 3 ++- pkg/suggestion/v1alpha3/internal/search_space.py | 10 +++++++--- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/pkg/suggestion/v1alpha3/chocolate/base_chocolate_service.py b/pkg/suggestion/v1alpha3/chocolate/base_chocolate_service.py index 89c3b353439..2e4a3b6f397 100644 --- a/pkg/suggestion/v1alpha3/chocolate/base_chocolate_service.py +++ b/pkg/suggestion/v1alpha3/chocolate/base_chocolate_service.py @@ -30,7 +30,7 @@ def getSuggestions(self, search_space, trials, request_number): key = BaseChocolateService.encode(param.name) if param.type == INTEGER: chocolate_search_space[key] = choco.quantized_uniform( - int(param.min), int(param.max), 1) + int(param.min), int(param.max), int(param.step)) elif param.type == DOUBLE: chocolate_search_space[key] = choco.quantized_uniform( float(param.min), float(param.max), float(param.step)) diff --git a/pkg/suggestion/v1alpha3/hyperopt/base_hyperopt_service.py b/pkg/suggestion/v1alpha3/hyperopt/base_hyperopt_service.py index 29d6e5a20e0..3674823ba7e 100644 --- a/pkg/suggestion/v1alpha3/hyperopt/base_hyperopt_service.py +++ b/pkg/suggestion/v1alpha3/hyperopt/base_hyperopt_service.py @@ -39,7 +39,8 @@ def create_hyperopt_domain(self): hyperopt_search_space[param.name] = hyperopt.hp.quniform( param.name, float(param.min), - float(param.max), 1) + float(param.max), + float(param.step)) elif param.type == DOUBLE: hyperopt_search_space[param.name] = hyperopt.hp.uniform( param.name, diff --git a/pkg/suggestion/v1alpha3/internal/search_space.py b/pkg/suggestion/v1alpha3/internal/search_space.py index 99007ec32ac..fcc0dcf3375 100644 --- a/pkg/suggestion/v1alpha3/internal/search_space.py +++ b/pkg/suggestion/v1alpha3/internal/search_space.py @@ -37,7 +37,11 @@ def __str__(self): @staticmethod def convertParameter(p): if p.parameter_type == api.INT: - return HyperParameter.int(p.name, p.feasible_space.min, p.feasible_space.max) + # Default value for INT parameter step is 1 + step = 1 + if p.feasible_space.step != None and p.feasible_space.step != "": + step = p.feasible_space.step + return HyperParameter.int(p.name, p.feasible_space.min, p.feasible_space.max, step) elif p.parameter_type == api.DOUBLE: return HyperParameter.double(p.name, p.feasible_space.min, p.feasible_space.max, p.feasible_space.step) elif p.parameter_type == api.CATEGORICAL: @@ -67,8 +71,8 @@ def __str__(self): self.name, self.type, ", ".join(self.list)) @staticmethod - def int(name, min_, max_): - return HyperParameter(name, INTEGER, min_, max_, [], 0) + def int(name, min_, max_, step): + return HyperParameter(name, INTEGER, min_, max_, [], step) @staticmethod def double(name, min_, max_, step):