diff --git a/dlk/dlkmanager/learningTask.go b/dlk/dlkmanager/learningTask.go index f2630fe196e..dec6b942361 100644 --- a/dlk/dlkmanager/learningTask.go +++ b/dlk/dlkmanager/learningTask.go @@ -307,12 +307,12 @@ func (lt *learningTask) run() { } case <-time.After(1 * time.Second): - lt.pollJobs() err := lt.checkPodStatus(podState) if err != nil { fmt.Println(err.Error()) os.Exit(1) } + lt.pollJobs() if lt.nrCompletedWorkers == lt.ltc.NrWorker { state = ltStateCompleted diff --git a/manager/main.go b/manager/main.go index a85aa0ed736..d602c99fb09 100644 --- a/manager/main.go +++ b/manager/main.go @@ -49,6 +49,56 @@ type server struct { StudyChList map[string]studyCh } +func (s *server) saveCompletedModels(studyId string, conf *pb.StudyConfig) error { + ret, err := s.GetSavedModels(context.Background(), &pb.GetSavedModelsRequest{StudyName: conf.Name}) + if err != nil { + log.Printf("GetSavedModels Err %v", err) + return err + } + ts, err := dbIf.GetTrialList(studyId) + if err != nil { + log.Printf("GetTrials Err %v", err) + return err + } + for _, t := range ts { + tid := t.TrialId + tst, err := dbIf.GetTrialStatus(tid) + if err != nil { + log.Printf("GetTrialStatus Err %v", err) + continue + } + isin := false + if tst == pb.TrialState_COMPLETED { + for _, m := range ret.Models { + if m.TrialId == tid { + isin = true + break + } + } + if !isin { + met := make([]*pb.Metrics, len(conf.Metrics)) + for i, mn := range conf.Metrics { + l, _ := dbIf.GetTrialLogs(tid, &kdb.GetTrialLogOpts{Name: mn}) + if len(l) > 0 { + met[i] = &pb.Metrics{Name: mn, Value: l[len(l)-1].Value} + } + } + t, _ := dbIf.GetTrial(tid) + s.SaveModel(context.Background(), &pb.SaveModelRequest{ + Model: &pb.ModelInfo{ + StudyName: conf.Name, + TrialId: tid, + Parameters: t.ParameterSet, + Metrics: met, + }, + }) + log.Printf("Trial %v in Study %v is saved", tid, conf.Name) + } + } + } + return nil +} + func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh studyCh) error { defer delete(s.StudyChList, study_id) defer s.wIF.CleanWorkers(study_id) @@ -85,7 +135,7 @@ func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh study } if r.Completed { log.Printf("Study %v completed.", study_id) - return nil + return s.saveCompletedModels(study_id, conf) } else if len(r.Trials) > 0 { for _, trial := range r.Trials { trial.Status = pb.TrialState_PENDING @@ -112,43 +162,7 @@ func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh study tm.Reset(1 * time.Second) } case <-strtm.C: - ret, err := s.GetSavedModels(context.Background(), &pb.GetSavedModelsRequest{StudyName: conf.Name}) - if err != nil { - log.Printf("GetSavedModels Err %v", err) - } - ts, err := dbIf.GetTrialList(study_id) - if err != nil { - log.Printf("GetTrials Err %v", err) - } - for _, t := range ts { - tid := t.TrialId - tst, err := dbIf.GetTrialStatus(tid) - if err != nil { - log.Printf("GetTrialStatus Err %v", err) - continue - } - if tst == pb.TrialState_COMPLETED { - for _, m := range ret.Models { - if m.TrialId == tid { - met := make([]*pb.Metrics, len(conf.Metrics)) - for i, mn := range conf.Metrics { - l, _ := dbIf.GetTrialLogs(tid, &kdb.GetTrialLogOpts{Name: mn}) - met[i] = &pb.Metrics{Name: mn, Value: l[len(l)-1].Value} - } - t, _ := dbIf.GetTrial(tid) - s.SaveModel(context.Background(), &pb.SaveModelRequest{ - Model: &pb.ModelInfo{ - StudyName: conf.Name, - TrialId: tid, - Parameters: t.ParameterSet, - Metrics: met, - }, - }) - break - } - } - } - } + s.saveCompletedModels(study_id, conf) strtm.Reset(defaultSaveInterval * time.Second) case <-estm.C: @@ -168,7 +182,7 @@ func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh study for _, t := range s.wIF.GetRunningTrials(study_id) { t.Status = pb.TrialState_KILLED } - return nil + return s.saveCompletedModels(study_id, conf) case m := <-sCh.addMetricsCh: conf.Metrics = append(conf.Metrics, m) } diff --git a/manager/modelstore/modeldb.go b/manager/modelstore/modeldb.go index f59f9a22be8..691cd32a295 100644 --- a/manager/modelstore/modeldb.go +++ b/manager/modelstore/modeldb.go @@ -123,6 +123,9 @@ func (m *ModelDB) SaveModel(in *api.SaveModelRequest) error { } md.ID = fres.ModelId for _, met := range in.Model.Metrics { + if met == nil { + continue + } mv, err := strconv.ParseFloat(met.Value, 64) if err != nil { continue