Skip to content

Commit

Permalink
Fix modelsave (#52)
Browse files Browse the repository at this point in the history
* refactor Model API

Signed-off-by: YujiOshima <yuji.oshima0x3fd@gmail.com>

* fix ModelStore IF name

Signed-off-by: YujiOshima <yuji.oshima0x3fd@gmail.com>

* fix model save bugs

Signed-off-by: YujiOshima <yuji.oshima0x3fd@gmail.com>

* avoid nil pointer

Signed-off-by: YujiOshima <yuji.oshima0x3fd@gmail.com>
  • Loading branch information
YujiOshima authored and k8s-ci-robot committed Apr 19, 2018
1 parent 5b0929b commit 56e143a
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 40 deletions.
2 changes: 1 addition & 1 deletion dlk/dlkmanager/learningTask.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,12 @@ func (lt *learningTask) run() {
}

case <-time.After(1 * time.Second):
lt.pollJobs()
err := lt.checkPodStatus(podState)
if err != nil {
fmt.Println(err.Error())
os.Exit(1)
}
lt.pollJobs()

if lt.nrCompletedWorkers == lt.ltc.NrWorker {
state = ltStateCompleted
Expand Down
92 changes: 53 additions & 39 deletions manager/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,56 @@ type server struct {
StudyChList map[string]studyCh
}

// saveCompletedModels stores every completed trial of a study that is not yet
// present among the study's saved models.
//
// It fetches the models already saved under conf.Name and the trial list for
// studyId. For each trial whose status is TrialState_COMPLETED and whose ID is
// not found among the saved models, it calls SaveModel with the trial's
// parameter set and, for each metric named in conf.Metrics, the value of the
// last log entry returned by GetTrialLogs.
//
// An error is returned only when the saved-model list or the trial list cannot
// be fetched. A failure that affects a single trial is logged and that trial
// is skipped, so the remaining trials are still processed.
func (s *server) saveCompletedModels(studyId string, conf *pb.StudyConfig) error {
	ctx := context.Background()

	ret, err := s.GetSavedModels(ctx, &pb.GetSavedModelsRequest{StudyName: conf.Name})
	if err != nil {
		log.Printf("GetSavedModels Err %v", err)
		return err
	}
	ts, err := dbIf.GetTrialList(studyId)
	if err != nil {
		log.Printf("GetTrials Err %v", err)
		return err
	}

	for _, t := range ts {
		tid := t.TrialId
		tst, err := dbIf.GetTrialStatus(tid)
		if err != nil {
			log.Printf("GetTrialStatus Err %v", err)
			continue
		}
		if tst != pb.TrialState_COMPLETED {
			continue
		}

		// Skip trials that already have a saved model.
		alreadySaved := false
		for _, m := range ret.Models {
			if m.TrialId == tid {
				alreadySaved = true
				break
			}
		}
		if alreadySaved {
			continue
		}

		// Collect the latest logged value of each metric. An entry stays nil
		// when the metric has no logs or its logs cannot be read.
		met := make([]*pb.Metrics, len(conf.Metrics))
		for i, mn := range conf.Metrics {
			l, err := dbIf.GetTrialLogs(tid, &kdb.GetTrialLogOpts{Name: mn})
			if err != nil {
				log.Printf("GetTrialLogs Err %v", err)
				continue
			}
			if len(l) > 0 {
				met[i] = &pb.Metrics{Name: mn, Value: l[len(l)-1].Value}
			}
		}

		// Check the lookup error before touching the result; dereferencing
		// the trial after a failed lookup could panic.
		trial, err := dbIf.GetTrial(tid)
		if err != nil {
			log.Printf("GetTrial Err %v", err)
			continue
		}

		_, err = s.SaveModel(ctx, &pb.SaveModelRequest{
			Model: &pb.ModelInfo{
				StudyName:  conf.Name,
				TrialId:    tid,
				Parameters: trial.ParameterSet,
				Metrics:    met,
			},
		})
		if err != nil {
			// Do not report the trial as saved when the save failed.
			log.Printf("SaveModel Err %v", err)
			continue
		}
		log.Printf("Trial %v in Study %v is saved", tid, conf.Name)
	}
	return nil
}

func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh studyCh) error {
defer delete(s.StudyChList, study_id)
defer s.wIF.CleanWorkers(study_id)
Expand Down Expand Up @@ -85,7 +135,7 @@ func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh study
}
if r.Completed {
log.Printf("Study %v completed.", study_id)
return nil
return s.saveCompletedModels(study_id, conf)
} else if len(r.Trials) > 0 {
for _, trial := range r.Trials {
trial.Status = pb.TrialState_PENDING
Expand All @@ -112,43 +162,7 @@ func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh study
tm.Reset(1 * time.Second)
}
case <-strtm.C:
ret, err := s.GetSavedModels(context.Background(), &pb.GetSavedModelsRequest{StudyName: conf.Name})
if err != nil {
log.Printf("GetSavedModels Err %v", err)
}
ts, err := dbIf.GetTrialList(study_id)
if err != nil {
log.Printf("GetTrials Err %v", err)
}
for _, t := range ts {
tid := t.TrialId
tst, err := dbIf.GetTrialStatus(tid)
if err != nil {
log.Printf("GetTrialStatus Err %v", err)
continue
}
if tst == pb.TrialState_COMPLETED {
for _, m := range ret.Models {
if m.TrialId == tid {
met := make([]*pb.Metrics, len(conf.Metrics))
for i, mn := range conf.Metrics {
l, _ := dbIf.GetTrialLogs(tid, &kdb.GetTrialLogOpts{Name: mn})
met[i] = &pb.Metrics{Name: mn, Value: l[len(l)-1].Value}
}
t, _ := dbIf.GetTrial(tid)
s.SaveModel(context.Background(), &pb.SaveModelRequest{
Model: &pb.ModelInfo{
StudyName: conf.Name,
TrialId: tid,
Parameters: t.ParameterSet,
Metrics: met,
},
})
break
}
}
}
}
s.saveCompletedModels(study_id, conf)
strtm.Reset(defaultSaveInterval * time.Second)

case <-estm.C:
Expand All @@ -168,7 +182,7 @@ func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh study
for _, t := range s.wIF.GetRunningTrials(study_id) {
t.Status = pb.TrialState_KILLED
}
return nil
return s.saveCompletedModels(study_id, conf)
case m := <-sCh.addMetricsCh:
conf.Metrics = append(conf.Metrics, m)
}
Expand Down
3 changes: 3 additions & 0 deletions manager/modelstore/modeldb.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ func (m *ModelDB) SaveModel(in *api.SaveModelRequest) error {
}
md.ID = fres.ModelId
for _, met := range in.Model.Metrics {
if met == nil {
continue
}
mv, err := strconv.ParseFloat(met.Value, 64)
if err != nil {
continue
Expand Down

0 comments on commit 56e143a

Please sign in to comment.