From 0c3303e765f2c3580b3ec80ea89982e20cfe5219 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Thu, 7 Jul 2022 12:35:02 +0800 Subject: [PATCH] client, tests: make client check if the cluster ID matches (#5281) close tikv/pd#5278 Let the client check if the cluster ID matches during the initialization and updating. Signed-off-by: JmPotato --- client/base_client.go | 34 +++++++++++++++++++++++++++++----- client/client.go | 2 ++ client/errs/errno.go | 1 + tests/client/client_test.go | 33 +++++++++++++++++++++++++++++++++ 4 files changed, 65 insertions(+), 5 deletions(-) diff --git a/client/base_client.go b/client/base_client.go index a6b4064af6c..16639bc2218 100644 --- a/client/base_client.go +++ b/client/base_client.go @@ -245,22 +245,46 @@ func (c *baseClient) gcAllocatorLeaderAddr(curAllocatorMap map[string]*pdpb.Memb func (c *baseClient) initClusterID() error { ctx, cancel := context.WithCancel(c.ctx) defer cancel() + var clusterID uint64 for _, u := range c.GetURLs() { members, err := c.getMembers(ctx, u, c.option.timeout) if err != nil || members.GetHeader() == nil { log.Warn("[pd] failed to get cluster id", zap.String("url", u), errs.ZapError(err)) continue } - c.clusterID = members.GetHeader().GetClusterId() - return nil + if clusterID == 0 { + clusterID = members.GetHeader().GetClusterId() + continue + } + failpoint.Inject("skipClusterIDCheck", func() { + failpoint.Continue() + }) + // All URLs passed in should have the same cluster ID. + if members.GetHeader().GetClusterId() != clusterID { + return errors.WithStack(errUnmatchedClusterID) + } + } + // Failed to init the cluster ID. + if clusterID == 0 { + return errors.WithStack(errFailInitClusterID) } - return errors.WithStack(errFailInitClusterID) + c.clusterID = clusterID + return nil } func (c *baseClient) updateMember() error { - for _, u := range c.GetURLs() { + for i, u := range c.GetURLs() { + failpoint.Inject("skipFirstUpdateMember", func() { + if i == 0 { + failpoint.Continue() + } + }) members, err := c.getMembers(c.ctx, u, updateMemberTimeout) - + // Check the cluster ID. + if err == nil && members.GetHeader().GetClusterId() != c.clusterID { + err = errs.ErrClientUpdateMember.FastGenByArgs("cluster id does not match") + } + // Check the TSO Allocator Leader. var errTSO error if err == nil { if members.GetLeader() == nil || len(members.GetLeader().GetClientUrls()) == 0 { diff --git a/client/client.go b/client/client.go index 07598910af4..8d9ed44c77d 100644 --- a/client/client.go +++ b/client/client.go @@ -328,6 +328,8 @@ const ( var LeaderHealthCheckInterval = time.Second var ( + // errUnmatchedClusterID is returned when found a PD with a different cluster ID. + errUnmatchedClusterID = errors.New("[pd] unmatched cluster id") // errFailInitClusterID is returned when failed to load clusterID from all supplied PD addresses. errFailInitClusterID = errors.New("[pd] failed to get cluster id") // errClosing is returned when request is canceled when client is closing. diff --git a/client/errs/errno.go b/client/errs/errno.go index 118b92f4127..ec49a46a7be 100644 --- a/client/errs/errno.go +++ b/client/errs/errno.go @@ -30,6 +30,7 @@ var ( ErrClientGetTSO = errors.Normalize("get TSO failed, %v", errors.RFCCodeText("PD:client:ErrClientGetTSO")) ErrClientGetLeader = errors.Normalize("get leader from %v error", errors.RFCCodeText("PD:client:ErrClientGetLeader")) ErrClientGetMember = errors.Normalize("get member failed", errors.RFCCodeText("PD:client:ErrClientGetMember")) + ErrClientUpdateMember = errors.Normalize("update member failed, %v", errors.RFCCodeText("PD:client:ErrUpdateMember")) ) // grpcutil errors diff --git a/tests/client/client_test.go b/tests/client/client_test.go index 3c79fbbdbfd..0dca158f75e 100644 --- a/tests/client/client_test.go +++ b/tests/client/client_test.go @@ -63,6 +63,39 @@ type client interface { GetAllocatorLeaderURLs() map[string]string } +func TestClientClusterIDCheck(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Create the cluster #1. + cluster1, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) + defer cluster1.Destroy() + endpoints1 := runServer(re, cluster1) + // Create the cluster #2. + cluster2, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) + defer cluster2.Destroy() + endpoints2 := runServer(re, cluster2) + // Try to create a client with the mixed endpoints. + _, err = pd.NewClientWithContext( + ctx, append(endpoints1, endpoints2...), + pd.SecurityOption{}, pd.WithMaxErrorRetry(1), + ) + re.Error(err) + re.Contains(err.Error(), "unmatched cluster id") + // updateMember should fail due to unmatched cluster ID found. + re.NoError(failpoint.Enable("github.com/tikv/pd/client/skipClusterIDCheck", `return(true)`)) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/skipFirstUpdateMember", `return(true)`)) + _, err = pd.NewClientWithContext(ctx, []string{endpoints1[0], endpoints2[0]}, + pd.SecurityOption{}, pd.WithMaxErrorRetry(1), + ) + re.Error(err) + re.Contains(err.Error(), "ErrClientGetLeader") + re.NoError(failpoint.Disable("github.com/tikv/pd/client/skipFirstUpdateMember")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/skipClusterIDCheck")) +} + func TestClientLeaderChange(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background())