From 313b0f67bf9900ad9366df9ad62fd96656b92b48 Mon Sep 17 00:00:00 2001 From: chicm-ms <38930155+chicm-ms@users.noreply.github.com> Date: Wed, 9 Oct 2019 15:57:25 +0800 Subject: [PATCH] Fix gp tuner (#1592) * fix gp tuner --- azure-pipelines.yml | 2 +- src/sdk/pynni/nni/gp_tuner/gp_tuner.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index a2932fd217..336d2375b8 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -58,7 +58,7 @@ jobs: - script: | python3 -m pip install torch==0.4.1 --user python3 -m pip install torchvision==0.2.1 --user - python3 -m pip install tensorflow --user + python3 -m pip install tensorflow==1.13.1 --user displayName: 'Install dependencies for integration' - script: | source install.sh diff --git a/src/sdk/pynni/nni/gp_tuner/gp_tuner.py b/src/sdk/pynni/nni/gp_tuner/gp_tuner.py index 3f0b5506bc..5398cd7b13 100644 --- a/src/sdk/pynni/nni/gp_tuner/gp_tuner.py +++ b/src/sdk/pynni/nni/gp_tuner/gp_tuner.py @@ -83,7 +83,7 @@ def update_search_space(self, search_space): """ self._space = TargetSpace(search_space, self._random_state) - def generate_parameters(self, parameter_id): + def generate_parameters(self, parameter_id, **kwargs): """Generate next parameter for trial If the number of trial result is lower than cold start number, gp will first randomly generate some parameters. @@ -123,7 +123,7 @@ def generate_parameters(self, parameter_id): logger.info("Generate paramageters:\n %s", results) return results - def receive_trial_result(self, parameter_id, parameters, value): + def receive_trial_result(self, parameter_id, parameters, value, **kwargs): """Tuner receive result from trial. Parameters