diff --git a/pkg/integration_test/cluster.go b/pkg/integration_test/cluster.go index cfe9c057e1a4..9cd9e2497586 100644 --- a/pkg/integration_test/cluster.go +++ b/pkg/integration_test/cluster.go @@ -117,6 +117,12 @@ func (s *testServer) GetClusterID() uint64 { return s.server.ClusterID() } +func (s *testServer) GetLeader() *pdpb.Member { + s.RLock() + defer s.RUnlock() + return s.server.GetLeader() +} + func (s *testServer) GetClusterVersion() semver.Version { s.RLock() defer s.RUnlock() diff --git a/pkg/integration_test/leader_watch_test.go b/pkg/integration_test/leader_watch_test.go new file mode 100644 index 000000000000..bcaeaeb917ad --- /dev/null +++ b/pkg/integration_test/leader_watch_test.go @@ -0,0 +1,58 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "context" + "time" + + gofail "github.com/etcd-io/gofail/runtime" + . "github.com/pingcap/check" + "github.com/pingcap/pd/pkg/testutil" +) + +func (s *integrationTestSuite) TestWatcher(c *C) { + c.Parallel() + cluster, err := newTestCluster(1) + c.Assert(err, IsNil) + defer cluster.Destroy() + + err = cluster.RunInitialServers() + c.Assert(err, IsNil) + cluster.WaitLeader() + pd1 := cluster.GetServer(cluster.GetLeader()) + c.Assert(pd1, NotNil) + + pd2, err := cluster.Join() + c.Assert(err, IsNil) + err = pd2.Run(context.TODO()) + c.Assert(err, IsNil) + cluster.WaitLeader() + + time.Sleep(5 * time.Second) + pd3, err := cluster.Join() + c.Assert(err, IsNil) + gofail.Enable("github.com/pingcap/pd/server/delayWatcher", `sleep("20s")`) + err = pd3.Run(context.Background()) + c.Assert(err, IsNil) + time.Sleep(200 * time.Millisecond) + c.Assert(pd3.GetLeader().GetName(), Equals, pd1.GetConfig().Name) + pd1.Stop() + cluster.WaitLeader() + c.Assert(pd2.GetLeader().GetName(), Equals, pd2.GetConfig().Name) + testutil.WaitUntil(c, func(c *C) bool { + return c.Check(pd3.GetLeader().GetName(), Equals, pd2.GetConfig().Name) + }) + c.Succeed() +} diff --git a/server/leader.go b/server/leader.go index b135f8b16b7e..d1ca21565c22 100644 --- a/server/leader.go +++ b/server/leader.go @@ -82,7 +82,7 @@ func (s *Server) leaderLoop() { continue } - leader, err := getLeader(s.client, s.getLeaderPath()) + leader, rev, err := getLeader(s.client, s.getLeaderPath()) if err != nil { log.Errorf("get leader err %v", err) time.Sleep(200 * time.Millisecond) @@ -100,7 +100,7 @@ func (s *Server) leaderLoop() { } } else { log.Infof("leader is %s, watch it", leader) - s.watchLeader(leader) + s.watchLeader(leader, rev) log.Info("leader changed, try to campaign leader") } } @@ -157,17 +157,17 @@ func (s *Server) etcdLeaderLoop() { } // getLeader gets server leader from etcd. -func getLeader(c *clientv3.Client, leaderPath string) (*pdpb.Member, error) { +func getLeader(c *clientv3.Client, leaderPath string) (*pdpb.Member, int64, error) { leader := &pdpb.Member{} - ok, err := getProtoMsg(c, leaderPath, leader) + ok, rev, err := getProtoMsgWithModRev(c, leaderPath, leader) if err != nil { - return nil, err + return nil, 0, err } if !ok { - return nil, nil + return nil, 0, nil } - return leader, nil + return leader, rev, nil } // GetEtcdLeader returns the etcd leader ID. @@ -289,7 +289,7 @@ func (s *Server) campaignLeader() error { } } -func (s *Server) watchLeader(leader *pdpb.Member) { +func (s *Server) watchLeader(leader *pdpb.Member, revision int64) { s.leader.Store(leader) defer s.leader.Store(&pdpb.Member{}) @@ -309,7 +309,8 @@ func (s *Server) watchLeader(leader *pdpb.Member) { } for { - rch := watcher.Watch(ctx, s.getLeaderPath()) + // gofail: var delayWatcher struct{} + rch := watcher.Watch(ctx, s.getLeaderPath(), clientv3.WithRev(revision)) for wresp := range rch { if wresp.Canceled { return diff --git a/server/tso_test.go b/server/tso_test.go index 9214e55fb853..61a02e6df27b 100644 --- a/server/tso_test.go +++ b/server/tso_test.go @@ -182,7 +182,7 @@ func (s *testTimeFallBackSuite) TestTimeFallBack(c *C) { func mustGetLeader(c *C, client *clientv3.Client, leaderPath string) *pdpb.Member { for i := 0; i < 20; i++ { - leader, err := getLeader(client, leaderPath) + leader, _, err := getLeader(client, leaderPath) c.Assert(err, IsNil) if leader != nil { return leader diff --git a/server/util.go b/server/util.go index b1fcfc05eed9..624f35e1390a 100644 --- a/server/util.go +++ b/server/util.go @@ -83,6 +83,17 @@ func CheckPDVersion(opt *scheduleOption) { // A helper function to get value with key from etcd. // TODO: return the value revision for outer use. func getValue(c *clientv3.Client, key string, opts ...clientv3.OpOption) ([]byte, error) { + resp, err := get(c, key, opts...) + if err != nil { + return nil, err + } + if resp == nil { + return nil, nil + } + return resp.Kvs[0].Value, nil +} + +func get(c *clientv3.Client, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { resp, err := kvGet(c, key, opts...) if err != nil { return nil, err @@ -93,26 +104,23 @@ func getValue(c *clientv3.Client, key string, opts ...clientv3.OpOption) ([]byte } else if n > 1 { return nil, errors.Errorf("invalid get value resp %v, must only one", resp.Kvs) } - - return resp.Kvs[0].Value, nil + return resp, nil } // Return boolean to indicate whether the key exists or not. -// TODO: return the value revision for outer use. -func getProtoMsg(c *clientv3.Client, key string, msg proto.Message, opts ...clientv3.OpOption) (bool, error) { - value, err := getValue(c, key, opts...) +func getProtoMsgWithModRev(c *clientv3.Client, key string, msg proto.Message, opts ...clientv3.OpOption) (bool, int64, error) { + resp, err := get(c, key, opts...) if err != nil { - return false, err + return false, 0, err } - if value == nil { - return false, nil + if resp == nil { + return false, 0, nil } - + value := resp.Kvs[0].Value if err = proto.Unmarshal(value, msg); err != nil { - return false, errors.WithStack(err) + return false, 0, errors.WithStack(err) } - - return true, nil + return true, resp.Kvs[0].ModRevision, nil } func initOrGetClusterID(c *clientv3.Client, key string) (uint64, error) {