-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Done] Sync master client between passes and fix recordio split #2948
Changes from 6 commits
6cea7ba
c950b73
0391bf5
5a402b5
30adaa8
56309b2
9215501
e3d7c22
419d553
31bf3fb
bd2a610
149ced5
ec6b16e
270bdb3
8c0755b
9891627
c1e8c9b
f07dc95
fb9f810
7d2d744
cc45124
ebb007f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,9 +40,9 @@ func TestGetFinishTask(t *testing.T) { | |
panic(err) | ||
} | ||
go func(l net.Listener) { | ||
s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1) | ||
s, e := NewService(&InMemStore{}, chunkPerTask, time.Second, 1) | ||
if err != nil { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. e?err? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @gongweibao 这里应该是开了go vet shadow检测出这个 |
||
panic(err) | ||
panic(e) | ||
} | ||
|
||
server := rpc.NewServer() | ||
|
@@ -89,21 +89,22 @@ func TestGetFinishTask(t *testing.T) { | |
ch := make(chan string, 1) | ||
ch <- addr | ||
go c.monitorMaster(ch) | ||
err = c.SetDataset([]string{path}) | ||
req := SetDatasetRequest{} | ||
req.GlobPaths = []string{path} | ||
req.NumPasses = 10 | ||
err = c.SetDataset(req) | ||
if err != nil { | ||
panic(err) | ||
t.Fatal(err) | ||
} | ||
|
||
checkOnePass := func(i int) { | ||
var tasks []Task | ||
for idx := 0; idx < totalTask; idx++ { | ||
task, err := c.getTask() | ||
if err != nil { | ||
t.Fatalf("Error: %v, pass: %d\n", err, i) | ||
task, e := c.getTask() | ||
if e != nil { | ||
t.Fatalf("Error: %v, pass: %d\n", e, i) | ||
} | ||
tasks = append(tasks, task) | ||
} | ||
|
||
_, err = c.getTask() | ||
if err == nil { | ||
t.Fatalf("Should get error, pass: %d\n", i) | ||
|
@@ -120,9 +121,9 @@ func TestGetFinishTask(t *testing.T) { | |
} | ||
|
||
tasks = tasks[1:] | ||
task, err := c.getTask() | ||
if err != nil { | ||
t.Fatal(err) | ||
task, e := c.getTask() | ||
if e.Error() != "no more available task" { | ||
t.Fatal(e) | ||
} | ||
tasks = append(tasks, task) | ||
|
||
|
@@ -134,7 +135,29 @@ func TestGetFinishTask(t *testing.T) { | |
} | ||
} | ||
|
||
for i := 0; i < 10; i++ { | ||
for i := 0; i < req.NumPasses-1; i++ { | ||
checkOnePass(i) | ||
} | ||
// last pass: check that all tasks of all passes are finished | ||
for idx := 0; idx < totalTask; idx++ { | ||
task, e := c.getTask() | ||
if e != nil { | ||
t.Fatalf("Error: %v\n", e) | ||
} | ||
err = c.taskFinished(task.Meta.ID) | ||
if idx < totalTask-1 { | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
} else { | ||
// FIXME: use error string to identify error | ||
if err.Error() != "all task done" { | ||
t.Fatal(err) | ||
} | ||
} | ||
} | ||
_, e := c.getTask() | ||
if e == nil || e.Error() != "all task done" { | ||
t.Error(e) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -67,6 +67,9 @@ type Service struct { | |
mu sync.Mutex | ||
initDone bool | ||
taskQueues taskQueues | ||
|
||
numPasses int | ||
currPass int | ||
} | ||
|
||
func partition(chunks []Chunk, chunksPerTask int) []taskEntry { | ||
|
@@ -106,6 +109,8 @@ func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failur | |
s.taskQueues.Pending = make(map[int]taskEntry) | ||
s.ready = make(chan struct{}) | ||
s.store = store | ||
s.numPasses = 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I could be wrong, but from the code I have read, people don't usually explicitly initialize zero values (Go already does it). It's still correct to initialize them, but it adds a few more lines of code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be initialized to 1? Training always needs at least 1 pass. |
||
s.currPass = 0 | ||
recovered, err := s.recover() | ||
if err != nil { | ||
return nil, err | ||
|
@@ -228,11 +233,19 @@ func readChunks(globPaths []string) ([]Chunk, error) { | |
return chunks, nil | ||
} | ||
|
||
// SetDatasetRequest is a request for setting the dataset and the number of passes | ||
type SetDatasetRequest struct { | ||
GlobPaths []string | ||
NumPasses int | ||
} | ||
|
||
// SetDataset sets dataset to dispatch for the master server. | ||
// | ||
// SetDataset can be called multiple times. But only the first call will | ||
// be honored. | ||
func (s *Service) SetDataset(globPaths []string, dummy *int) error { | ||
func (s *Service) SetDataset(request SetDatasetRequest, dummy *int) error { | ||
globPaths := request.GlobPaths | ||
s.numPasses = request.NumPasses | ||
if len(globPaths) == 0 { | ||
return errors.New("no dataset specified") | ||
} | ||
|
@@ -315,6 +328,25 @@ func (s *Service) logFields() log.Fields { | |
} | ||
} | ||
|
||
// updateQueuePasses checks whether all tasks are done and, if so, moves them back to todo to start another pass | ||
// IMPORTANT: must run under lock | ||
func (s *Service) updateQueuePasses() error { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looking at the name from an English-language perspective, "update pass" only means "change pass"; it does not explicitly mean "moving to the next pass", maybe |
||
if len(s.taskQueues.Todo) == 0 && len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Done) > 0 { | ||
if s.currPass >= s.numPasses-1 { | ||
// FIXME: stop task dispatching by return error | ||
return errors.New("all task done") | ||
} | ||
log.WithFields(s.logFields()).Infof("adding new pass %d", s.currPass) | ||
s.taskQueues.Todo = s.taskQueues.Done | ||
s.taskQueues.Done = nil | ||
if len(s.taskQueues.Failed) > 0 { | ||
s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Failed...) | ||
} | ||
s.currPass++ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里已经进入了下一个pass,为何还是 另外,既然进入下一个pass的前提条件是 |
||
} | ||
return errors.New("no more available task") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we can add a const error string const NoMoreAvailableTask = "no more available task" |
||
} | ||
|
||
// GetTask gets a new task from the service. | ||
func (s *Service) GetTask(dummy int, task *Task) error { | ||
select { | ||
|
@@ -323,7 +355,6 @@ func (s *Service) GetTask(dummy int, task *Task) error { | |
|
||
s.mu.Lock() | ||
defer s.mu.Unlock() | ||
|
||
if len(s.taskQueues.Todo) == 0 { | ||
if len(s.taskQueues.Done) == 0 { | ||
if len(s.taskQueues.Pending) == 0 { | ||
|
@@ -344,9 +375,7 @@ func (s *Service) GetTask(dummy int, task *Task) error { | |
log.WithFields(s.logFields()).Warningln("No more available task.") | ||
return err | ||
} | ||
s.taskQueues.Todo = s.taskQueues.Done | ||
s.taskQueues.Done = nil | ||
log.WithFields(s.logFields()).Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.") | ||
return s.updateQueuePasses() | ||
} | ||
|
||
t := s.taskQueues.Todo[0] | ||
|
@@ -386,18 +415,18 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { | |
delete(s.taskQueues.Pending, taskID) | ||
|
||
log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID) | ||
|
||
if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 { | ||
log.WithFields(s.logFields()).Infoln("No more todo and pending task, start a new pass.") | ||
s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...) | ||
s.taskQueues.Done = nil | ||
} | ||
// update queues if pass finishes | ||
errPass := s.updateQueuePasses() | ||
|
||
err := s.snapshot() | ||
if err != nil { | ||
log.Errorln(err) | ||
} | ||
return err | ||
// return updateQueuePasses errors | ||
if errPass.Error() == "no more available task" { | ||
return nil | ||
} | ||
return errPass | ||
} | ||
|
||
// TaskFailed tells the service that a task is failed. | ||
|
@@ -416,5 +445,10 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error { | |
} | ||
|
||
s.processFailedTask(t, meta.Epoch) | ||
return nil | ||
// update queues if pass finishes | ||
errPass := s.updateQueuePasses() | ||
if errPass.Error() == "no more available task" { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里不是RPC的client端,还是直接比较 var errNoMoreTask = errors.New("no more available task")
func updateQueuePasses() {
// ...
return errNoMoreTask
}
if errPass == errNoMoreTask {
// ...
} |
||
return nil | ||
} | ||
return errPass | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果可以过的话,要不要还是把
golint
和gotype
(gotype可以检查编译是否通过)加上?