Pserver Save state #2716
go/pserver/service_test.go
Outdated
@@ -79,6 +79,8 @@ func TestServiceFull(t *testing.T) {
	if !reflect.DeepEqual(param1, p) {
		t.FailNow()
	}
	var dummy int
	s.Save("", &dummy)
Passing in nil is fine: s.Save("", nil). I used s.Save("", &dummy) before, but later realized that it's fine to pass in nil :)
Done.
@@ -166,3 +168,7 @@ func TestBlockUntilInitialized(t *testing.T) {

	wg.Wait()
}

func TestCheckpointSpeed(t *testing.T) {
Speed can be tested with a benchmark. Here is an example: https://dave.cheney.net/2013/06/30/how-to-write-benchmarks-in-go
Left a TODO here; this will be tested after reaching an agreement on @Yancey1989's recover logic.
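For reference, a minimal sketch of what such a benchmark could look like. The paramCheckpoint type and encodeCheckpoint function are hypothetical stand-ins, not code from this PR; the gob encoding is only an assumption about what the checkpoint step does. In the repository, BenchmarkCheckpoint would live in a _test.go file and run with go test -bench; here testing.Benchmark is called from main so the sketch runs on its own.

```go
package main

import (
	"bytes"
	"encoding/gob"
	"fmt"
	"testing"
)

// paramCheckpoint is a hypothetical stand-in for the data a pserver persists.
type paramCheckpoint struct {
	Name    string
	Content []byte
	State   []byte
}

// encodeCheckpoint serializes the checkpoints with encoding/gob.
func encodeCheckpoint(cps []paramCheckpoint) ([]byte, error) {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(cps); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

// BenchmarkCheckpoint has the signature that `go test -bench` looks for.
func BenchmarkCheckpoint(b *testing.B) {
	cps := []paramCheckpoint{{Name: "w", Content: make([]byte, 1<<20)}}
	b.ResetTimer() // exclude the setup above from the measurement
	for i := 0; i < b.N; i++ {
		if _, err := encodeCheckpoint(cps); err != nil {
			b.Fatal(err)
		}
	}
}

func main() {
	// testing.Benchmark runs the benchmark outside of `go test`.
	fmt.Println(testing.Benchmark(BenchmarkCheckpoint))
}
```

The benchmark framework picks b.N itself, so the loop body should contain only the operation being measured.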
go/pserver/service.go
Outdated
@@ -38,6 +52,7 @@ type Parameter struct {
type ParameterWithConfig struct {
	Param Parameter
	Config []byte // parameter configuration in Proto Buffer format
	State []byte // parameter training state
ParameterWithConfig is the data sent from the trainer to the pserver. But State is saved and loaded by the pserver and is not related to the trainer, so State should not be part of this struct.
Maybe:
type checkpoint struct {
	Uuid string
	Md5sum string
	Timestamp string
	ParameterWithConfig // this is called an embedded field
	State []byte
}
Embedded field: https://golang.org/ref/spec#Struct_types
Split into info and data parts. Fix done.
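To illustrate the embedded-field suggestion above, here is a small sketch. The field set of Parameter is simplified and assumed, and newCheckpoint is a hypothetical helper; only the checkpoint struct layout follows the reviewer's proposal.

```go
package main

import "fmt"

// Parameter is a simplified stand-in; the real struct may have other fields.
type Parameter struct {
	Name    string
	Content []byte
}

// ParameterWithConfig is what the trainer sends to the pserver.
type ParameterWithConfig struct {
	Param  Parameter
	Config []byte
}

// checkpoint keeps pserver-only data (State) out of ParameterWithConfig.
type checkpoint struct {
	Uuid                string
	Md5sum              string
	Timestamp           string
	ParameterWithConfig // embedded field: Param and Config are promoted
	State               []byte
}

// newCheckpoint is a hypothetical constructor used only for this sketch.
func newCheckpoint(uuid string, pc ParameterWithConfig, state []byte) checkpoint {
	return checkpoint{Uuid: uuid, ParameterWithConfig: pc, State: state}
}

func main() {
	pc := ParameterWithConfig{Param: Parameter{Name: "w"}}
	ck := newCheckpoint("hypothetical-uuid", pc, []byte{1})
	// ck.Param is shorthand for ck.ParameterWithConfig.Param.
	fmt.Println(ck.Param.Name, len(ck.State))
}
```

Because the fields are promoted, code that reads ck.Param or ck.Config does not need to know the data is stored in an embedded struct.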
go/pserver/service.go
Outdated
@@ -142,8 +177,51 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {

// Save tells the parameter server to save parameters.
func (s *Service) Save(path string, dummy *int) error {
Save was intended for saving the model, but we no longer use pservers to save the model. Can you rename Save to checkpoint? Also, I think that at least for the first implementation, checkpoint should not be exposed as an RPC method to the trainer; instead, the pservers should checkpoint periodically. So can you make this a private function: func (s *Service) checkpoint(path string) error? (Note that we don't need the parameter dummy *int anymore if it's not used for RPC.)
fix Done.
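A rough sketch of the shape discussed above: an unexported checkpoint method driven by a ticker instead of an RPC call. The Service here is a stripped-down stand-in, and runCheckpoints, count, and the saved counter are hypothetical names introduced only for illustration.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// Service is a hypothetical, minimal stand-in for the pserver service.
type Service struct {
	mu    sync.Mutex
	saved int // number of checkpoints taken, for illustration only
}

// checkpoint is unexported, so net/rpc will not expose it to the trainer.
func (s *Service) checkpoint(path string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.saved++ // a real implementation would write state under path
	return nil
}

// count reports how many checkpoints have been taken.
func (s *Service) count() int {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.saved
}

// runCheckpoints calls checkpoint once per interval until done is closed.
func (s *Service) runCheckpoints(path string, interval time.Duration, done <-chan struct{}) {
	t := time.NewTicker(interval)
	defer t.Stop()
	for {
		select {
		case <-t.C:
			if err := s.checkpoint(path); err != nil {
				fmt.Println("checkpoint failed:", err)
			}
		case <-done:
			return
		}
	}
}

func main() {
	s := &Service{}
	done := make(chan struct{})
	go s.runCheckpoints("./checkpoints/", 10*time.Millisecond, done)
	time.Sleep(55 * time.Millisecond)
	close(done)
	fmt.Println("checkpoints taken:", s.count())
}
```

Dropping the dummy *int parameter is possible here because the method no longer has to match the net/rpc method signature.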
go/pserver/service.go
Outdated
log.Infof("parameter checkpoint %s", ckbytes)

if _, err = os.Stat(ck.Uuid); os.IsNotExist(err) {
	log.Info("checkpoint not exists.")
checkpoint not exists. -> checkpoint does not exist.
fix done.
go/pserver/service.go
Outdated
	log.Info("checkpoint not exists.")
} else {
	err = os.Remove(ck.Uuid)
	log.Infof("remove %s", ck.Uuid)
remove %s -> checkpoint %s already exists, removing.
fix done.
go/pserver/service.go
Outdated
	log.Infof("remove %s", ck.Uuid)
}
f, err := os.Create(ck.Uuid)
defer f.Close()
defer f.Close() will close the file when this function returns, not when the for loop moves on to the next iteration, and the for loop may be very long. So perhaps call f.Close() at the end of each loop iteration.
fix done.
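One way to address the defer-in-loop issue, sketched under assumptions: instead of calling f.Close() by hand at the end of the loop body as suggested above, this variant moves the per-file work into a helper so that defer runs once per file. writeAll, writeOne, and run are hypothetical names, not functions from this PR.

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// writeAll writes each blob to its own file under dir.
// Every file is closed before the next iteration starts.
func writeAll(dir string, blobs map[string][]byte) error {
	for name, data := range blobs {
		if err := writeOne(filepath.Join(dir, name), data); err != nil {
			return err
		}
	}
	return nil
}

// writeOne scopes the deferred Close to a single file.
func writeOne(path string, data []byte) (err error) {
	f, err := os.Create(path)
	if err != nil {
		return err // check the error before touching f
	}
	defer func() {
		// Report a Close error only if the write itself succeeded.
		if cerr := f.Close(); err == nil {
			err = cerr
		}
	}()
	_, err = f.Write(data)
	return err
}

// run writes two files into a temporary directory and counts them.
func run() (int, error) {
	dir, err := os.MkdirTemp("", "ckpt")
	if err != nil {
		return 0, err
	}
	defer os.RemoveAll(dir)
	blobs := map[string][]byte{"a": []byte("1"), "b": []byte("2")}
	if err := writeAll(dir, blobs); err != nil {
		return 0, err
	}
	entries, err := os.ReadDir(dir)
	if err != nil {
		return 0, err
	}
	return len(entries), nil
}

func main() {
	n, err := run()
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println("files written:", n)
}
```

Either approach keeps the number of open file descriptors bounded; the helper version also keeps the Close call correct when the loop body returns early.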
go/pserver/service.go
Outdated
@@ -14,6 +24,10 @@ const (
	Uninitialized = "pserver not fully initialized"
)

const (
	checkpoint_path = "./checkpoints/"
Go's naming convention is camelCase, not snake_case.
checkpointPath needs to be an argument (flag.String) passed to the go/cmd/pserver program, since k8s will configure the path.
Done.
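A minimal sketch of passing the checkpoint path as a flag. The flag name and default match the later revision of go/cmd/pserver/pserver.go shown in this review; the config struct and parseFlags helper are hypothetical, and a flag.FlagSet is used (rather than the package-level flag functions) only so the parsing can be exercised with arbitrary arguments.

```go
package main

import (
	"flag"
	"fmt"
	"os"
)

// config is a hypothetical holder for the parsed settings.
type config struct {
	checkpointPath string
}

// parseFlags reads the checkpoint path from command-line arguments.
func parseFlags(args []string) (config, error) {
	fs := flag.NewFlagSet("pserver", flag.ContinueOnError)
	path := fs.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
	if err := fs.Parse(args); err != nil {
		return config{}, err
	}
	return config{checkpointPath: *path}, nil
}

func main() {
	cfg, err := parseFlags(os.Args[1:])
	if err != nil {
		os.Exit(2)
	}
	fmt.Println("checkpoint path:", cfg.checkpointPath)
}
```

With this in place a deployment can override the location, for example with -checkpoint-path=/mnt/ckpt, without rebuilding the binary.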
go/pserver/service.go
Outdated
	log.Errorln(err)
}
// TODO: according design doc, need to save Uuid to etcd in json format
// {\"Uuid\": [UUID], \"md5\", \"MD5 sum\", \"Timestamp\": xxxx}
I think the design doc mentioned using etcd to save checkpoint information as well. Maybe add a TODO?
Add etcd saving logic. fix done.
go/pserver/service.go
Outdated
}

//serialize ParameterWithConfig to byte stream
func GetBytes(content ...interface{}) ([]byte, error) {
Making content ...interface{} adds more complexity to the code: since it's an interface type that we need to encode, we have to call gob.Register. That makes the code harder to understand (people need to look up what gob.Register does) and harder to maintain (whenever a new type is added for GetBytes to use, the maintainer needs to add a gob.Register call as well, which is hard to track).
Since we only need to call GetBytes twice here, and the function does not have much code, maybe just put it inline (and remove gob.Register)?
fix Done.
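To show why the concrete-type approach needs no gob.Register, here is a small round-trip sketch. The parameterCheckpoint fields are assumed for illustration, and the encode/decode helpers exist only to make the sketch testable; the reviewer's suggestion was to write the few encoder lines inline.

```go
package main

import (
	"bytes"
	"encoding/gob"
	"fmt"
)

// parameterCheckpoint is a hypothetical concrete type to be serialized.
type parameterCheckpoint struct {
	Name    string
	Content []byte
	State   []byte
}

// encode serializes a concrete struct; no gob.Register call is required
// because the static type passed to Encode is not an interface.
func encode(cp parameterCheckpoint) ([]byte, error) {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(cp); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

// decode reverses encode into the same concrete type.
func decode(b []byte) (parameterCheckpoint, error) {
	var cp parameterCheckpoint
	err := gob.NewDecoder(bytes.NewReader(b)).Decode(&cp)
	return cp, err
}

func main() {
	b, err := encode(parameterCheckpoint{Name: "w", State: []byte("momentum")})
	if err != nil {
		fmt.Println("encode error:", err)
		return
	}
	cp, err := decode(b)
	if err != nil {
		fmt.Println("decode error:", err)
		return
	}
	fmt.Println(cp.Name, string(cp.State))
}
```

gob.Register only becomes necessary when the value being encoded is held in an interface-typed field or variable, which is what content ...interface{} forced.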
go/pserver/service.go
Outdated
@@ -52,14 +67,34 @@ type Service struct {
	optMap map[string]*optimizer
}

type checkpoint struct {
	Uuid string
Add the struct tag `json:"uuid"` at the end of the line, so we can use json.Marshal to produce the JSON format.
agree with another PR. fix Done.
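A short sketch of the struct-tag suggestion above. The checkpointMeta type, its field types, and the lowercase key names other than "uuid" are assumptions for illustration; the design doc may specify different keys.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// checkpointMeta is a hypothetical record of the kind that would go to etcd.
type checkpointMeta struct {
	UUID      string `json:"uuid"`
	MD5       string `json:"md5"`
	Timestamp int64  `json:"timestamp"`
}

// marshalMeta renders the metadata as a JSON string.
func marshalMeta(m checkpointMeta) (string, error) {
	b, err := json.Marshal(m)
	if err != nil {
		return "", err
	}
	return string(b), nil
}

func main() {
	s, err := marshalMeta(checkpointMeta{UUID: "abc", MD5: "d41d", Timestamp: 1})
	if err != nil {
		fmt.Println("marshal error:", err)
		return
	}
	fmt.Println(s) // {"uuid":"abc","md5":"d41d","timestamp":1}
}
```

Without the tags, json.Marshal would use the Go field names (UUID, MD5, Timestamp) as keys.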
go/pserver/service.go
Outdated
	err = os.Remove(ck.Uuid)
	log.Infof("remove %s", ck.Uuid)
}
f, err := os.Create(ck.Uuid)
This will create a separate file for each parameter. Following the design doc, shouldn't we have only one checkpoint file, named by its UUID?
fix done.
LGTM++ except small comments.
go/cmd/pserver/pserver.go
Outdated
@@ -20,6 +20,8 @@ func main() {
		"comma separated endpoint string for pserver to connect to etcd")
	etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
	numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
	checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
	checkpointInterval := flag.Int("checkpoint-interval", 10, "save checkpoint per interval seconds")
A default of 10 seconds may be too frequent?
Yeah, a fixed interval is not suitable for every training job, since the time consumed is always determined by the amount of training data. A round count may be better here.
Changed it to 10 min (600 seconds).
go/pserver/service.go
Outdated
}

// Checkpoint is the pserver shard persist in file
type Checkpoint []parameterCheckPoint
An exported type that is a slice of an unexported type may be inconvenient to use.
Done.
fix #2566