Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Done] Sync master client between passes and fix recordio split #2948

Merged
merged 22 commits into from
Jul 27, 2017
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
- id: remove-crlf
files: (?!.*third_party)^.*$ | (?!.*book)^.*$
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
sha: v0.16.2
hooks:
- id: yapf
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
- repo: https://github.com/pre-commit/pre-commit-hooks
sha: 5bf6c09bfa1297d3692cadd621ef95f1284e33c0
sha: v0.9.1
hooks:
- id: check-added-large-files
- id: check-merge-conflict
Expand All @@ -18,13 +18,15 @@
files: (?!.*third_party)^.*$ | (?!.*book)^.*$
- id: end-of-file-fixer
- repo: https://github.com/PaddlePaddle/clang-format-pre-commit-hook.git
sha: 28c0ea8a67a3e2dbbf4822ef44e85b63a0080a29
sha: 32a6f751000a9b3d22020203b848b487e829a229
hooks:
- id: clang-formater
- repo: https://github.com/PaddlePaddle/pre-commit-golang
sha: 16398aeccf263adaf53b2495eed0406347d76281
sha: 762c7c84af6ddce7cbcbe030217856fd0b5aec46
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果可以过的话,要不要还是把 golint、gotype(gotype 可以检查编译是否通过)加上?

hooks:
- id: go-fmt
types: [go]
- id: gometalinter
types: [go]
- id: go-fmt
types:
- go
- id: gometalinter
types:
- go
9 changes: 5 additions & 4 deletions go/master/c/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,16 @@ func paddle_release_master_client(client C.paddle_master_client) {
}

//export paddle_set_dataset
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int, passes C.int) C.int {
c := get(client)
var paths []string
var request master.SetDatasetRequest
request.NumPasses = int(passes)
for i := 0; i < int(size); i++ {
ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(i)*unsafe.Sizeof(*path)))
str := C.GoString(*ptr)
paths = append(paths, str)
request.GlobPaths = append(request.GlobPaths, str)
}
err := c.SetDataset(paths)
err := c.SetDataset(request)
if err != nil {
log.Errorln(err)
return C.PADDLE_MASTER_ERROR
Expand Down
10 changes: 5 additions & 5 deletions go/master/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ func (c *Client) getRecords() {
}

for _, chunk := range t.Chunks {
f, err := os.Open(chunk.Path)
if err != nil {
log.Errorln(err)
f, e := os.Open(chunk.Path)
if e != nil {
log.Errorln(e)
continue
}

Expand Down Expand Up @@ -106,8 +106,8 @@ func (c *Client) monitorMaster(addrCh <-chan string) {
//
// SetDataset can be called multiple times from different nodes. But
// only the first call will be honored.
func (c *Client) SetDataset(globPaths []string) error {
return c.conn.Call("Service.SetDataset", globPaths, nil)
func (c *Client) SetDataset(request SetDatasetRequest) error {
return c.conn.Call("Service.SetDataset", request, nil)
}

// getTask gets a new task from the master server.
Expand Down
49 changes: 36 additions & 13 deletions go/master/client_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ func TestGetFinishTask(t *testing.T) {
panic(err)
}
go func(l net.Listener) {
s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
s, e := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e?err?

Copy link
Contributor

@helinwang helinwang Jul 20, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gongweibao 这里应该是开了go vet shadow检测出这个err shadow了外面的err,所以改成e了。虽然go vet shadow报出来的情况不一定是错误,让它不报错需要改变量名,但开了它可以避免shadow出错的情况(这种情况出bug之后不容易发现),所以我觉得开go vet shadow是个不错的选择。

panic(err)
panic(e)
}

server := rpc.NewServer()
Expand Down Expand Up @@ -89,21 +89,22 @@ func TestGetFinishTask(t *testing.T) {
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
err = c.SetDataset([]string{path})
req := SetDatasetRequest{}
req.GlobPaths = []string{path}
req.NumPasses = 10
err = c.SetDataset(req)
if err != nil {
panic(err)
t.Fatal(err)
}

checkOnePass := func(i int) {
var tasks []Task
for idx := 0; idx < totalTask; idx++ {
task, err := c.getTask()
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
task, e := c.getTask()
if e != nil {
t.Fatalf("Error: %v, pass: %d\n", e, i)
}
tasks = append(tasks, task)
}

_, err = c.getTask()
if err == nil {
t.Fatalf("Should get error, pass: %d\n", i)
Expand All @@ -120,9 +121,9 @@ func TestGetFinishTask(t *testing.T) {
}

tasks = tasks[1:]
task, err := c.getTask()
if err != nil {
t.Fatal(err)
task, e := c.getTask()
if e.Error() != "no more available task" {
t.Fatal(e)
}
tasks = append(tasks, task)

Expand All @@ -134,7 +135,29 @@ func TestGetFinishTask(t *testing.T) {
}
}

for i := 0; i < 10; i++ {
for i := 0; i < req.NumPasses-1; i++ {
checkOnePass(i)
}
// last pass: check that all tasks of all passes finish
for idx := 0; idx < totalTask; idx++ {
task, e := c.getTask()
if e != nil {
t.Fatalf("Error: %v\n", e)
}
err = c.taskFinished(task.Meta.ID)
if idx < totalTask-1 {
if err != nil {
t.Fatal(err)
}
} else {
// FIXME: use error string to identify error
if err.Error() != "all task done" {
t.Fatal(err)
}
}
}
_, e := c.getTask()
if e == nil || e.Error() != "all task done" {
t.Error(e)
}
}
30 changes: 13 additions & 17 deletions go/master/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,27 +76,23 @@ func TestNextRecord(t *testing.T) {
curAddr := make(chan string, 1)
curAddr <- fmt.Sprintf(":%d", p)
c := master.NewClient(curAddr, 10)
err = c.SetDataset([]string{path})
req := master.SetDatasetRequest{}
req.GlobPaths = []string{path}
req.NumPasses = 50
err = c.SetDataset(req)
if err != nil {
panic(err)
t.Fatal(err)
}
for pass := 0; pass < req.NumPasses; pass++ {

for pass := 0; pass < 50; pass++ {
received := make(map[byte]bool)
for i := 0; i < total; i++ {
r, err := c.NextRecord()
if err != nil {
t.Fatal(pass, i, "Read error:", err)
}

if len(r) != 1 {
t.Fatal(pass, i, "Length should be 1.", r)
}
r, err := c.NextRecord()
if err != nil {
t.Fatal(pass, "Read error:", err)
}

if received[r[0]] {
t.Fatal(pass, i, "Received duplicate.", received, r)
}
received[r[0]] = true
if len(r) != 1 {
t.Fatal(pass, "Length should be 1.", r)
}

}
}
60 changes: 47 additions & 13 deletions go/master/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ type Service struct {
mu sync.Mutex
initDone bool
taskQueues taskQueues

numPasses int
currPass int
}

func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
Expand Down Expand Up @@ -106,6 +109,8 @@ func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failur
s.taskQueues.Pending = make(map[int]taskEntry)
s.ready = make(chan struct{})
s.store = store
s.numPasses = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could be wrong, but from the code I have read, people don't usually initialize zero values explicitly (Go already does it). It's still correct to initialize them, but it adds a few more lines of code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it be initialized with 1? Training always needs at least 1 pass.

s.currPass = 0
recovered, err := s.recover()
if err != nil {
return nil, err
Expand Down Expand Up @@ -228,11 +233,19 @@ func readChunks(globPaths []string) ([]Chunk, error) {
return chunks, nil
}

// SetDatasetRequest is the request for setting the dataset and the
// number of passes.
type SetDatasetRequest struct {
	// GlobPaths holds the glob paths of the dataset; SetDataset rejects
	// the request when it is empty.
	GlobPaths []string
	// NumPasses is the number of passes; the service stops dispatching
	// tasks once its current pass reaches NumPasses-1 and all tasks are done.
	NumPasses int
}

// SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be called multiple times. But only the first call will
// be honored.
func (s *Service) SetDataset(globPaths []string, dummy *int) error {
func (s *Service) SetDataset(request SetDatasetRequest, dummy *int) error {
globPaths := request.GlobPaths
s.numPasses = request.NumPasses
if len(globPaths) == 0 {
return errors.New("no dataset specified")
}
Expand Down Expand Up @@ -315,6 +328,25 @@ func (s *Service) logFields() log.Fields {
}
}

// updateQueuePasses checks if all tasks are done and, if so, moves them back to todo to start another pass
// IMPORTANT: must run under lock
func (s *Service) updateQueuePasses() error {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the name from an English-language perspective, "update pass" only means "change pass" and does not explicitly mean "moving to the next pass"; maybe nextPass would describe the intention better?

if len(s.taskQueues.Todo) == 0 && len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Done) > 0 {
if s.currPass >= s.numPasses-1 {
// FIXME: stop task dispatching by return error
return errors.New("all task done")
}
log.WithFields(s.logFields()).Infof("adding new pass %d", s.currPass)
s.taskQueues.Todo = s.taskQueues.Done
s.taskQueues.Done = nil
if len(s.taskQueues.Failed) > 0 {
s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Failed...)
}
s.currPass++
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里已经进入了下一个pass,为何还是return errors.New("no more available task")

另外,既然进入下一个pass的前提条件是len(s.taskQueues.Pending) == 0,是由pending task complete / finish导致的,那感觉不应该跟GetTask相关(现在GetTask调用了updateQueuePasses),而是只跟TaskFinished / TaskFailed有关。

}
return errors.New("no more available task")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can add a const error string

const NoMoreAvailableTask = "no more available task"

}

// GetTask gets a new task from the service.
func (s *Service) GetTask(dummy int, task *Task) error {
select {
Expand All @@ -323,7 +355,6 @@ func (s *Service) GetTask(dummy int, task *Task) error {

s.mu.Lock()
defer s.mu.Unlock()

if len(s.taskQueues.Todo) == 0 {
if len(s.taskQueues.Done) == 0 {
if len(s.taskQueues.Pending) == 0 {
Expand All @@ -344,9 +375,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
log.WithFields(s.logFields()).Warningln("No more available task.")
return err
}
s.taskQueues.Todo = s.taskQueues.Done
s.taskQueues.Done = nil
log.WithFields(s.logFields()).Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
return s.updateQueuePasses()
}

t := s.taskQueues.Todo[0]
Expand Down Expand Up @@ -386,18 +415,18 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
delete(s.taskQueues.Pending, taskID)

log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID)

if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 {
log.WithFields(s.logFields()).Infoln("No more todo and pending task, start a new pass.")
s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...)
s.taskQueues.Done = nil
}
// update queues if pass finishes
errPass := s.updateQueuePasses()

err := s.snapshot()
if err != nil {
log.Errorln(err)
}
return err
// return updateQueuePasses errors
if errPass.Error() == "no more available task" {
return nil
}
return errPass
}

// TaskFailed tells the service that a task is failed.
Expand All @@ -416,5 +445,10 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
}

s.processFailedTask(t, meta.Epoch)
return nil
// update queues if pass finishes
errPass := s.updateQueuePasses()
if errPass.Error() == "no more available task" {
Copy link
Contributor

@helinwang helinwang Jul 19, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不是RPC的client端,还是直接比较err更好:

var errNoMoreTask = errors.New("no more available task")

func updateQueuePasses() {
  // ...
  return errNoMoreTask
}

if errPass == errNoMoreTask {
  // ...
}

return nil
}
return errPass
}
11 changes: 7 additions & 4 deletions go/pserver/client/c/test/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def cloud_reader():
print "connecting to master, etcd endpoints: ", etcd_endpoint
master_client = master.client(etcd_endpoint, 5, 64)
master_client.set_dataset(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*-of-*"])
["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*"])
while 1:
r, e = master_client.next_record()
if not r:
Expand All @@ -27,10 +27,12 @@ def main():
# network config
x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
y_predict = paddle.layer.fc(input=x,
param_attr=paddle.attr.Param(name='w'),
param_attr=paddle.attr.Param(
name='w', learning_rate=1e-3),
size=1,
act=paddle.activation.Linear(),
bias_attr=paddle.attr.Param(name='b'))
bias_attr=paddle.attr.Param(
name='b', learning_rate=1e-3))
y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
cost = paddle.layer.mse_cost(input=y_predict, label=y)

Expand All @@ -40,7 +42,6 @@ def main():
# create optimizer of new remote updater to pserver
optimizer = paddle.optimizer.Momentum(momentum=0)

print "etcd endoint: ", etcd_endpoint
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=optimizer,
Expand All @@ -51,6 +52,8 @@ def main():
# event_handler to print training and testing info
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
# FIXME: for cloud data reader, pass number is managed by master
# should print the server side pass number
if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f" % (
event.pass_id, event.batch_id, event.cost)
Expand Down
2 changes: 1 addition & 1 deletion go/pserver/client/etcd_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ func (p *EtcdClient) List() []Server {
for {
for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
cancel()
psKey := pserver.PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
resp, err := p.client.Get(ctx, psKey)
cancel()
if err != nil {
log.Infof("Get psKey= %s error, %v", psKey, err)
time.Sleep(p.timeout)
Expand Down
Loading