From 4371d9811aa5f4fa6dafec638e65eb1e7462efc8 Mon Sep 17 00:00:00 2001 From: IWAMOTO Toshihiro Date: Wed, 21 Nov 2018 14:05:22 +0900 Subject: [PATCH] Add Update{Study,Trial} Only tested with unit tests. Signed-off-by: IWAMOTO Toshihiro --- pkg/db/interface.go | 54 +++++++++++++++++++++++++++++++++++++--- pkg/db/interface_test.go | 31 +++++++++++++++++++++++ pkg/mock/db/db.go | 24 ++++++++++++++++++ 3 files changed, 105 insertions(+), 4 deletions(-) diff --git a/pkg/db/interface.go b/pkg/db/interface.go index f030ed494fb..1ebdd99c076 100644 --- a/pkg/db/interface.go +++ b/pkg/db/interface.go @@ -46,11 +46,13 @@ type VizierDBInterface interface { GetStudyConfig(string) (*api.StudyConfig, error) GetStudyList() ([]string, error) CreateStudy(*api.StudyConfig) (string, error) + UpdateStudy(string, *api.StudyConfig) error DeleteStudy(string) error GetTrial(string) (*api.Trial, error) GetTrialList(string) ([]*api.Trial, error) CreateTrial(*api.Trial) error + UpdateTrial(*api.Trial) error DeleteTrial(string) error GetWorker(string) (*api.Worker, error) @@ -307,6 +309,30 @@ func (d *dbConn) CreateStudy(in *api.StudyConfig) (string, error) { return studyID, nil } +// UpdateStudy updates the corresponding row in the DB. +// It only updates name, owner, tags and job_id. +// Other columns are silently ignored. +func (d *dbConn) UpdateStudy(studyID string, in *api.StudyConfig) error { + var err error + + tags := make([]string, len(in.Tags)) + for i, elem := range in.Tags { + tags[i], err = (&jsonpb.Marshaler{}).MarshalToString(elem) + if err != nil { + log.Printf("Error marshalling %v: %v", elem, err) + continue + } + } + _, err = d.db.Exec(`UPDATE studies SET name = ?, owner = ?, tags = ?, + job_id = ? WHERE id = ?`, + in.Name, + in.Owner, + strings.Join(tags, ",\n"), + in.JobId, + studyID) + return err +} + func (d *dbConn) DeleteStudy(id string) error { _, err := d.db.Exec("DELETE FROM studies WHERE id = ?", id) return err @@ -395,9 +421,7 @@ func (d *dbConn) GetTrialList(id string) ([]*api.Trial, error) { return trials, err } -func (d *dbConn) CreateTrial(trial *api.Trial) error { - // This function sets trial.id, unlike old dbInsertTrials(). - // Users should not overwrite trial.id +func marshalTrial(trial *api.Trial) ([]string, []string, error) { var err, lastErr error params := make([]string, len(trial.ParameterSet)) @@ -418,12 +442,20 @@ func (d *dbConn) CreateTrial(trial *api.Trial) error { lastErr = err } } + return params, tags, lastErr +} + +// CreateTrial stores into the trials DB table. +// As a side-effect, it generates and sets trial.TrialId. +// Users should not overwrite TrialId. +func (d *dbConn) CreateTrial(trial *api.Trial) error { + params, tags, lastErr := marshalTrial(trial) var trialID string i := 3 for true { trialID = generateRandid() - _, err = d.db.Exec("INSERT INTO trials VALUES (?, ?, ?, ?, ?)", + _, err := d.db.Exec("INSERT INTO trials VALUES (?, ?, ?, ?, ?)", trialID, trial.StudyId, strings.Join(params, ",\n"), trial.ObjectiveValue, strings.Join(tags, ",\n")) if err == nil { @@ -440,6 +472,20 @@ func (d *dbConn) CreateTrial(trial *api.Trial) error { return lastErr } +// UpdateTrial updates the corresponding row in the DB. +// It only updates parameters and tags. Other columns are silently ignored. +func (d *dbConn) UpdateTrial(trial *api.Trial) error { + params, tags, lastErr := marshalTrial(trial) + _, err := d.db.Exec(`UPDATE trials SET parameters = ?, tags = ?, + WHERE id = ?`, + strings.Join(params, ",\n"), strings.Join(tags, ",\n"), + trial.TrialId) + if err != nil { + return err + } + return lastErr +} + func (d *dbConn) DeleteTrial(id string) error { _, err := d.db.Exec("DELETE FROM trials WHERE id = ?", id) return err diff --git a/pkg/db/interface_test.go b/pkg/db/interface_test.go index 23ba07c1742..206ce2cbc5d 100644 --- a/pkg/db/interface_test.go +++ b/pkg/db/interface_test.go @@ -135,6 +135,22 @@ func TestGetStudyList(t *testing.T) { } } +func TestUpdateStudy(t *testing.T) { + studyID := generateRandid() + var in api.StudyConfig + in.Name = "hoge" + in.Owner = "joe" + in.JobId = "foobar123" + + mock.ExpectExec(`UPDATE studies SET name = \?, owner = \?, tags = \?, + job_id = \? WHERE id = \?`, + ).WithArgs(in.Name, in.Owner, "", in.JobId, studyID).WillReturnResult(sqlmock.NewResult(1, 1)) + err := dbInterface.UpdateStudy(studyID, &in) + if err != nil { + t.Errorf("UpdateStudy error %v", err) + } +} + func TestDeleteStudy(t *testing.T) { studyID := generateRandid() mock.ExpectExec(`DELETE FROM studies WHERE id = \?`).WithArgs(studyID).WillReturnResult(sqlmock.NewResult(1, 1)) @@ -225,6 +241,21 @@ func TestCreateTrial(t *testing.T) { } } +func TestUpdateTrial(t *testing.T) { + var trial api.Trial + trial.TrialId = generateRandid() + trial.StudyId = generateRandid() + trial.ParameterSet = make([]*api.Parameter, 1) + trial.ParameterSet[0] = &api.Parameter{Name: "abc"} + mock.ExpectExec(`UPDATE trials SET parameters = \?, tags = \?, + WHERE id = \?`, + ).WithArgs("{\"name\":\"abc\"}", "", trial.TrialId).WillReturnResult(sqlmock.NewResult(1, 1)) + err := dbInterface.UpdateTrial(&trial) + if err != nil { + t.Errorf("UpdateTrial error %v", err) + } +} + func TestDeleteTrial(t *testing.T) { id := generateRandid() mock.ExpectExec(`DELETE FROM trials WHERE id = \?`).WithArgs(id).WillReturnResult(sqlmock.NewResult(1, 1)) diff --git a/pkg/mock/db/db.go b/pkg/mock/db/db.go index a52c3a9ad15..b0b494cce83 100644 --- a/pkg/mock/db/db.go +++ b/pkg/mock/db/db.go @@ -351,6 +351,18 @@ func (mr *MockVizierDBInterfaceMockRecorder) UpdateEarlyStopParam(arg0, arg1 int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEarlyStopParam", reflect.TypeOf((*MockVizierDBInterface)(nil).UpdateEarlyStopParam), arg0, arg1) } +// UpdateStudy mocks base method +func (m *MockVizierDBInterface) UpdateStudy(arg0 string, arg1 *api.StudyConfig) error { + ret := m.ctrl.Call(m, "UpdateStudy", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateStudy indicates an expected call of UpdateStudy +func (mr *MockVizierDBInterfaceMockRecorder) UpdateStudy(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateStudy", reflect.TypeOf((*MockVizierDBInterface)(nil).UpdateStudy), arg0, arg1) +} + // UpdateSuggestionParam mocks base method func (m *MockVizierDBInterface) UpdateSuggestionParam(arg0 string, arg1 []*api.SuggestionParameter) error { ret := m.ctrl.Call(m, "UpdateSuggestionParam", arg0, arg1) @@ -363,6 +375,18 @@ func (mr *MockVizierDBInterfaceMockRecorder) UpdateSuggestionParam(arg0, arg1 in return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSuggestionParam", reflect.TypeOf((*MockVizierDBInterface)(nil).UpdateSuggestionParam), arg0, arg1) } +// UpdateTrial mocks base method +func (m *MockVizierDBInterface) UpdateTrial(arg0 *api.Trial) error { + ret := m.ctrl.Call(m, "UpdateTrial", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateTrial indicates an expected call of UpdateTrial +func (mr *MockVizierDBInterfaceMockRecorder) UpdateTrial(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTrial", reflect.TypeOf((*MockVizierDBInterface)(nil).UpdateTrial), arg0) +} + // UpdateWorker mocks base method func (m *MockVizierDBInterface) UpdateWorker(arg0 string, arg1 api.State) error { ret := m.ctrl.Call(m, "UpdateWorker", arg0, arg1)