Skip to content

Commit

Permalink
client, tests: make client check if the cluster ID matches (#5281)
Browse files Browse the repository at this point in the history
close #5278

Let the client check if the cluster ID matches during the initialization and updating.

Signed-off-by: JmPotato <ghzpotato@gmail.com>
  • Loading branch information
JmPotato authored Jul 7, 2022
1 parent de5cf5b commit 0c3303e
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 5 deletions.
34 changes: 29 additions & 5 deletions client/base_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,22 +245,46 @@ func (c *baseClient) gcAllocatorLeaderAddr(curAllocatorMap map[string]*pdpb.Memb
func (c *baseClient) initClusterID() error {
ctx, cancel := context.WithCancel(c.ctx)
defer cancel()
var clusterID uint64
for _, u := range c.GetURLs() {
members, err := c.getMembers(ctx, u, c.option.timeout)
if err != nil || members.GetHeader() == nil {
log.Warn("[pd] failed to get cluster id", zap.String("url", u), errs.ZapError(err))
continue
}
c.clusterID = members.GetHeader().GetClusterId()
return nil
if clusterID == 0 {
clusterID = members.GetHeader().GetClusterId()
continue
}
failpoint.Inject("skipClusterIDCheck", func() {
failpoint.Continue()
})
// All URLs passed in should have the same cluster ID.
if members.GetHeader().GetClusterId() != clusterID {
return errors.WithStack(errUnmatchedClusterID)
}
}
// Failed to init the cluster ID.
if clusterID == 0 {
return errors.WithStack(errFailInitClusterID)
}
return errors.WithStack(errFailInitClusterID)
c.clusterID = clusterID
return nil
}

func (c *baseClient) updateMember() error {
for _, u := range c.GetURLs() {
for i, u := range c.GetURLs() {
failpoint.Inject("skipFirstUpdateMember", func() {
if i == 0 {
failpoint.Continue()
}
})
members, err := c.getMembers(c.ctx, u, updateMemberTimeout)

// Check the cluster ID.
if err == nil && members.GetHeader().GetClusterId() != c.clusterID {
err = errs.ErrClientUpdateMember.FastGenByArgs("cluster id does not match")
}
// Check the TSO Allocator Leader.
var errTSO error
if err == nil {
if members.GetLeader() == nil || len(members.GetLeader().GetClientUrls()) == 0 {
Expand Down
2 changes: 2 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ const (
var LeaderHealthCheckInterval = time.Second

var (
// errUnmatchedClusterID is returned when found a PD with a different cluster ID.
errUnmatchedClusterID = errors.New("[pd] unmatched cluster id")
// errFailInitClusterID is returned when failed to load clusterID from all supplied PD addresses.
errFailInitClusterID = errors.New("[pd] failed to get cluster id")
// errClosing is returned when request is canceled when client is closing.
Expand Down
1 change: 1 addition & 0 deletions client/errs/errno.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ var (
ErrClientGetTSO = errors.Normalize("get TSO failed, %v", errors.RFCCodeText("PD:client:ErrClientGetTSO"))
ErrClientGetLeader = errors.Normalize("get leader from %v error", errors.RFCCodeText("PD:client:ErrClientGetLeader"))
ErrClientGetMember = errors.Normalize("get member failed", errors.RFCCodeText("PD:client:ErrClientGetMember"))
ErrClientUpdateMember = errors.Normalize("update member failed, %v", errors.RFCCodeText("PD:client:ErrUpdateMember"))
)

// grpcutil errors
Expand Down
33 changes: 33 additions & 0 deletions tests/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,39 @@ type client interface {
GetAllocatorLeaderURLs() map[string]string
}

func TestClientClusterIDCheck(t *testing.T) {
re := require.New(t)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create the cluster #1.
cluster1, err := tests.NewTestCluster(ctx, 3)
re.NoError(err)
defer cluster1.Destroy()
endpoints1 := runServer(re, cluster1)
// Create the cluster #2.
cluster2, err := tests.NewTestCluster(ctx, 3)
re.NoError(err)
defer cluster2.Destroy()
endpoints2 := runServer(re, cluster2)
// Try to create a client with the mixed endpoints.
_, err = pd.NewClientWithContext(
ctx, append(endpoints1, endpoints2...),
pd.SecurityOption{}, pd.WithMaxErrorRetry(1),
)
re.Error(err)
re.Contains(err.Error(), "unmatched cluster id")
// updateMember should fail due to unmatched cluster ID found.
re.NoError(failpoint.Enable("github.com/tikv/pd/client/skipClusterIDCheck", `return(true)`))
re.NoError(failpoint.Enable("github.com/tikv/pd/client/skipFirstUpdateMember", `return(true)`))
_, err = pd.NewClientWithContext(ctx, []string{endpoints1[0], endpoints2[0]},
pd.SecurityOption{}, pd.WithMaxErrorRetry(1),
)
re.Error(err)
re.Contains(err.Error(), "ErrClientGetLeader")
re.NoError(failpoint.Disable("github.com/tikv/pd/client/skipFirstUpdateMember"))
re.NoError(failpoint.Disable("github.com/tikv/pd/client/skipClusterIDCheck"))
}

func TestClientLeaderChange(t *testing.T) {
re := require.New(t)
ctx, cancel := context.WithCancel(context.Background())
Expand Down

0 comments on commit 0c3303e

Please sign in to comment.