diff --git a/client/client.go b/client/client.go index bdde1a1b675..6781182a44b 100644 --- a/client/client.go +++ b/client/client.go @@ -33,28 +33,15 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/tikv/pd/client/caller" "github.com/tikv/pd/client/clients/metastorage" + "github.com/tikv/pd/client/constants" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/metrics" "github.com/tikv/pd/client/opt" + sd "github.com/tikv/pd/client/servicediscovery" "github.com/tikv/pd/client/utils/tlsutil" "go.uber.org/zap" ) -const ( - // defaultKeyspaceID is the default key space id. - // Valid keyspace id range is [0, 0xFFFFFF](uint24max, or 16777215) - // ​0 is reserved for default keyspace with the name "DEFAULT", It's initialized - // when PD bootstrap and reserved for users who haven't been assigned keyspace. - defaultKeyspaceID = uint32(0) - maxKeyspaceID = uint32(0xFFFFFF) - // nullKeyspaceID is used for api v1 or legacy path where is keyspace agnostic. - nullKeyspaceID = uint32(0xFFFFFFFF) - // defaultKeySpaceGroupID is the default key space group id. - // We also reserved 0 for the keyspace group for the same purpose. - defaultKeySpaceGroupID = uint32(0) - defaultKeyspaceName = "DEFAULT" -) - // Region contains information of a region's meta and its peers. type Region struct { Meta *metapb.Region @@ -175,7 +162,7 @@ type Client interface { // syncing leader from server. GetLeaderURL() string // GetServiceDiscovery returns ServiceDiscovery - GetServiceDiscovery() ServiceDiscovery + GetServiceDiscovery() sd.ServiceDiscovery // UpdateOption updates the client option. UpdateOption(option opt.DynamicOption, value any) error @@ -184,19 +171,6 @@ type Client interface { Close() } -var ( - // errUnmatchedClusterID is returned when found a PD with a different cluster ID. - errUnmatchedClusterID = errors.New("[pd] unmatched cluster id") - // errFailInitClusterID is returned when failed to load clusterID from all supplied PD addresses. - errFailInitClusterID = errors.New("[pd] failed to get cluster id") - // errClosing is returned when request is canceled when client is closing. - errClosing = errors.New("[pd] closing") - // errTSOLength is returned when the number of response timestamps is inconsistent with request. - errTSOLength = errors.New("[pd] tso length in rpc response is incorrect") - // errInvalidRespHeader is returned when the response doesn't contain service mode info unexpectedly. - errNoServiceModeReturned = errors.New("[pd] no service mode returned") -) - var _ Client = (*client)(nil) // serviceModeKeeper is for service mode switching. @@ -206,7 +180,7 @@ type serviceModeKeeper struct { sync.RWMutex serviceMode pdpb.ServiceMode tsoClient *tsoClient - tsoSvcDiscovery ServiceDiscovery + tsoSvcDiscovery sd.ServiceDiscovery } func (k *serviceModeKeeper) close() { @@ -289,7 +263,7 @@ func NewClientWithContext( security SecurityOption, opts ...opt.ClientOption, ) (Client, error) { return createClientWithKeyspace(ctx, callerComponent, - nullKeyspaceID, svrAddrs, security, opts...) + constants.NullKeyspaceID, svrAddrs, security, opts...) } // NewClientWithKeyspace creates a client with context and the specified keyspace id. @@ -300,9 +274,9 @@ func NewClientWithKeyspace( keyspaceID uint32, svrAddrs []string, security SecurityOption, opts ...opt.ClientOption, ) (Client, error) { - if keyspaceID < defaultKeyspaceID || keyspaceID > maxKeyspaceID { + if keyspaceID < constants.DefaultKeyspaceID || keyspaceID > constants.MaxKeyspaceID { return nil, errors.Errorf("invalid keyspace id %d. It must be in the range of [%d, %d]", - keyspaceID, defaultKeyspaceID, maxKeyspaceID) + keyspaceID, constants.DefaultKeyspaceID, constants.MaxKeyspaceID) } return createClientWithKeyspace(ctx, callerComponent, keyspaceID, svrAddrs, security, opts...) @@ -392,7 +366,7 @@ type apiContextV2 struct { // NewAPIContextV2 creates a API context with the specified keyspace name for V2. func NewAPIContextV2(keyspaceName string) APIContext { if len(keyspaceName) == 0 { - keyspaceName = defaultKeyspaceName + keyspaceName = constants.DefaultKeyspaceName } return &apiContextV2{keyspaceName: keyspaceName} } @@ -452,7 +426,7 @@ func newClientWithKeyspaceName( inner: &innerClient{ // Create a PD service discovery with null keyspace id, then query the real id with the keyspace name, // finally update the keyspace id to the PD service discovery for the following interactions. - keyspaceID: nullKeyspaceID, + keyspaceID: constants.NullKeyspaceID, updateTokenConnectionCh: make(chan struct{}, 1), ctx: clientCtx, cancel: clientCancel, @@ -511,7 +485,7 @@ func (c *client) GetLeaderURL() string { } // GetServiceDiscovery returns the client-side service discovery object -func (c *client) GetServiceDiscovery() ServiceDiscovery { +func (c *client) GetServiceDiscovery() sd.ServiceDiscovery { return c.inner.pdSvcDiscovery } @@ -1277,17 +1251,6 @@ func (c *client) scatterRegionsWithOptions(ctx context.Context, regionsID []uint return resp, nil } -const ( - httpSchemePrefix = "http://" - httpsSchemePrefix = "https://" -) - -func trimHTTPPrefix(str string) string { - str = strings.TrimPrefix(str, httpSchemePrefix) - str = strings.TrimPrefix(str, httpsSchemePrefix) - return str -} - // LoadGlobalConfig implements the RPCClient interface. func (c *client) LoadGlobalConfig(ctx context.Context, names []string, configPath string) ([]GlobalConfigItem, int64, error) { ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout) diff --git a/client/client_test.go b/client/client_test.go index 234bb2da10a..8b4cc2242ca 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -20,14 +20,12 @@ import ( "time" "github.com/pingcap/errors" - "github.com/pingcap/kvproto/pkg/pdpb" "github.com/stretchr/testify/require" "github.com/tikv/pd/client/caller" "github.com/tikv/pd/client/opt" "github.com/tikv/pd/client/utils/testutil" "github.com/tikv/pd/client/utils/tsoutil" "go.uber.org/goleak" - "google.golang.org/grpc" ) func TestMain(m *testing.M) { @@ -43,36 +41,6 @@ func TestTSLessEqual(t *testing.T) { re.True(tsoutil.TSLessEqual(9, 6, 9, 8)) } -func TestUpdateURLs(t *testing.T) { - re := require.New(t) - members := []*pdpb.Member{ - {Name: "pd4", ClientUrls: []string{"tmp://pd4"}}, - {Name: "pd1", ClientUrls: []string{"tmp://pd1"}}, - {Name: "pd3", ClientUrls: []string{"tmp://pd3"}}, - {Name: "pd2", ClientUrls: []string{"tmp://pd2"}}, - } - getURLs := func(ms []*pdpb.Member) (urls []string) { - for _, m := range ms { - urls = append(urls, m.GetClientUrls()[0]) - } - return - } - cli := &pdServiceDiscovery{option: opt.NewOption()} - cli.urls.Store([]string{}) - cli.updateURLs(members[1:]) - re.Equal(getURLs([]*pdpb.Member{members[1], members[3], members[2]}), cli.GetServiceURLs()) - cli.updateURLs(members[1:]) - re.Equal(getURLs([]*pdpb.Member{members[1], members[3], members[2]}), cli.GetServiceURLs()) - cli.updateURLs(members) - re.Equal(getURLs([]*pdpb.Member{members[1], members[3], members[2], members[0]}), cli.GetServiceURLs()) - cli.updateURLs(members[1:]) - re.Equal(getURLs([]*pdpb.Member{members[1], members[3], members[2]}), cli.GetServiceURLs()) - cli.updateURLs(members[2:]) - re.Equal(getURLs([]*pdpb.Member{members[3], members[2]}), cli.GetServiceURLs()) - cli.updateURLs(members[3:]) - re.Equal(getURLs([]*pdpb.Member{members[3]}), cli.GetServiceURLs()) -} - const testClientURL = "tmp://test.url:5255" func TestClientCtx(t *testing.T) { @@ -95,25 +63,6 @@ func TestClientWithRetry(t *testing.T) { re.Less(time.Since(start), time.Second*10) } -func TestGRPCDialOption(t *testing.T) { - re := require.New(t) - start := time.Now() - ctx, cancel := context.WithTimeout(context.TODO(), 500*time.Millisecond) - defer cancel() - cli := &pdServiceDiscovery{ - checkMembershipCh: make(chan struct{}, 1), - ctx: ctx, - cancel: cancel, - tlsCfg: nil, - option: opt.NewOption(), - } - cli.urls.Store([]string{testClientURL}) - cli.option.GRPCDialOptions = []grpc.DialOption{grpc.WithBlock()} - err := cli.updateMember() - re.Error(err) - re.Greater(time.Since(start), 500*time.Millisecond) -} - func TestTsoRequestWait(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) diff --git a/client/constants/constants.go b/client/constants/constants.go new file mode 100644 index 00000000000..10963dd10b6 --- /dev/null +++ b/client/constants/constants.go @@ -0,0 +1,32 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package constants + +const ( + // DefaultKeyspaceID is the default keyspace ID. + // Valid keyspace id range is [0, 0xFFFFFF](uint24max, or 16777215) + // ​0 is reserved for default keyspace with the name "DEFAULT", It's initialized + // when PD bootstrap and reserved for users who haven't been assigned keyspace. + DefaultKeyspaceID = uint32(0) + // MaxKeyspaceID is the maximum keyspace ID. + MaxKeyspaceID = uint32(0xFFFFFF) + // NullKeyspaceID is used for API v1 or legacy path where is keyspace agnostic. + NullKeyspaceID = uint32(0xFFFFFFFF) + // DefaultKeyspaceGroupID is the default key space group id. + // We also reserved 0 for the keyspace group for the same purpose. + DefaultKeyspaceGroupID = uint32(0) + // DefaultKeyspaceName is the default keyspace name. + DefaultKeyspaceName = "DEFAULT" +) diff --git a/client/errs/errno.go b/client/errs/errno.go index 95c6bffdfa4..df8b677525a 100644 --- a/client/errs/errno.go +++ b/client/errs/errno.go @@ -36,6 +36,20 @@ const ( NotPrimaryErr = "not primary" ) +// internal errors +var ( + // ErrUnmatchedClusterID is returned when found a PD with a different cluster ID. + ErrUnmatchedClusterID = errors.New("[pd] unmatched cluster id") + // ErrFailInitClusterID is returned when failed to load clusterID from all supplied PD addresses. + ErrFailInitClusterID = errors.New("[pd] failed to get cluster id") + // ErrClosing is returned when request is canceled when client is closing. + ErrClosing = errors.New("[pd] closing") + // ErrTSOLength is returned when the number of response timestamps is inconsistent with request. + ErrTSOLength = errors.New("[pd] tso length in rpc response is incorrect") + // ErrNoServiceModeReturned is returned when the response doesn't contain service mode info unexpectedly. + ErrNoServiceModeReturned = errors.New("[pd] no service mode returned") +) + // client errors var ( ErrClientGetProtoClient = errors.Normalize("failed to get proto client", errors.RFCCodeText("PD:client:ErrClientGetProtoClient")) diff --git a/client/errs/errs.go b/client/errs/errs.go index da333efda4c..67a5dd8ec92 100644 --- a/client/errs/errs.go +++ b/client/errs/errs.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/errors" "go.uber.org/zap" "go.uber.org/zap/zapcore" + "google.golang.org/grpc/codes" ) // IsLeaderChange will determine whether there is a leader/primary change. @@ -38,6 +39,11 @@ func IsLeaderChange(err error) bool { strings.Contains(errMsg, NotPrimaryErr) } +// IsNetworkError returns true if the error is a network error. +func IsNetworkError(code codes.Code) bool { + return code == codes.Unavailable || code == codes.DeadlineExceeded +} + // ZapError is used to make the log output easier. func ZapError(err error, causeError ...error) zap.Field { if err == nil { diff --git a/client/http/client.go b/client/http/client.go index 123ca616422..9c522d87286 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -26,9 +26,9 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/log" "github.com/prometheus/client_golang/prometheus" - pd "github.com/tikv/pd/client" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/retry" + sd "github.com/tikv/pd/client/servicediscovery" "go.uber.org/zap" ) @@ -56,7 +56,7 @@ type clientInner struct { ctx context.Context cancel context.CancelFunc - sd pd.ServiceDiscovery + sd sd.ServiceDiscovery // source is used to mark the source of the client creation, // it will also be used in the caller ID of the inner client. @@ -74,7 +74,7 @@ func newClientInner(ctx context.Context, cancel context.CancelFunc, source strin return &clientInner{ctx: ctx, cancel: cancel, source: source} } -func (ci *clientInner) init(sd pd.ServiceDiscovery) { +func (ci *clientInner) init(sd sd.ServiceDiscovery) { // Init the HTTP client if it's not configured. if ci.cli == nil { ci.cli = &http.Client{Timeout: defaultTimeout} @@ -305,7 +305,7 @@ func WithMetrics( // NewClientWithServiceDiscovery creates a PD HTTP client with the given PD service discovery. func NewClientWithServiceDiscovery( source string, - sd pd.ServiceDiscovery, + sd sd.ServiceDiscovery, opts ...ClientOption, ) Client { ctx, cancel := context.WithCancel(context.Background()) @@ -330,7 +330,7 @@ func NewClient( for _, opt := range opts { opt(c) } - sd := pd.NewDefaultPDServiceDiscovery(ctx, cancel, pdAddrs, c.inner.tlsConf) + sd := sd.NewDefaultPDServiceDiscovery(ctx, cancel, pdAddrs, c.inner.tlsConf) if err := sd.Init(); err != nil { log.Error("[pd] init service discovery failed", zap.String("source", source), zap.Strings("pd-addrs", pdAddrs), zap.Error(err)) @@ -430,7 +430,7 @@ func newClientWithMockServiceDiscovery( for _, opt := range opts { opt(c) } - sd := pd.NewMockPDServiceDiscovery(pdAddrs, c.inner.tlsConf) + sd := sd.NewMockPDServiceDiscovery(pdAddrs, c.inner.tlsConf) if err := sd.Init(); err != nil { log.Error("[pd] init mock service discovery failed", zap.String("source", source), zap.Strings("pd-addrs", pdAddrs), zap.Error(err)) diff --git a/client/inner_client.go b/client/inner_client.go index 62fcd84dd5d..467d6b66352 100644 --- a/client/inner_client.go +++ b/client/inner_client.go @@ -12,6 +12,7 @@ import ( "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/metrics" "github.com/tikv/pd/client/opt" + sd "github.com/tikv/pd/client/servicediscovery" "go.uber.org/zap" "google.golang.org/grpc" ) @@ -24,7 +25,7 @@ const ( type innerClient struct { keyspaceID uint32 svrUrls []string - pdSvcDiscovery *pdServiceDiscovery + pdSvcDiscovery sd.ServiceDiscovery tokenDispatcher *tokenDispatcher // For service mode switching. @@ -40,8 +41,8 @@ type innerClient struct { option *opt.Option } -func (c *innerClient) init(updateKeyspaceIDCb updateKeyspaceIDFunc) error { - c.pdSvcDiscovery = newPDServiceDiscovery( +func (c *innerClient) init(updateKeyspaceIDCb sd.UpdateKeyspaceIDFunc) error { + c.pdSvcDiscovery = sd.NewPDServiceDiscovery( c.ctx, c.cancel, &c.wg, c.setServiceMode, updateKeyspaceIDCb, c.keyspaceID, c.svrUrls, c.tlsCfg, c.option) if err := c.setup(); err != nil { @@ -83,14 +84,14 @@ func (c *innerClient) resetTSOClientLocked(mode pdpb.ServiceMode) { // Re-create a new TSO client. var ( newTSOCli *tsoClient - newTSOSvcDiscovery ServiceDiscovery + newTSOSvcDiscovery sd.ServiceDiscovery ) switch mode { case pdpb.ServiceMode_PD_SVC_MODE: newTSOCli = newTSOClient(c.ctx, c.option, c.pdSvcDiscovery, &pdTSOStreamBuilderFactory{}) case pdpb.ServiceMode_API_SVC_MODE: - newTSOSvcDiscovery = newTSOServiceDiscovery( + newTSOSvcDiscovery = sd.NewTSOServiceDiscovery( c.ctx, c, c.pdSvcDiscovery, c.keyspaceID, c.tlsCfg, c.option) // At this point, the keyspace group isn't known yet. Starts from the default keyspace group, @@ -152,7 +153,7 @@ func (c *innerClient) close() { c.pdSvcDiscovery.Close() if c.tokenDispatcher != nil { - tokenErr := errors.WithStack(errClosing) + tokenErr := errors.WithStack(errs.ErrClosing) c.tokenDispatcher.tokenBatchController.revokePendingTokenRequest(tokenErr) c.tokenDispatcher.dispatcherCancel() } @@ -179,10 +180,10 @@ func (c *innerClient) setup() error { // getClientAndContext returns the leader pd client and the original context. If leader is unhealthy, it returns // follower pd client and the context which holds forward information. -func (c *innerClient) getRegionAPIClientAndContext(ctx context.Context, allowFollower bool) (ServiceClient, context.Context) { - var serviceClient ServiceClient +func (c *innerClient) getRegionAPIClientAndContext(ctx context.Context, allowFollower bool) (sd.ServiceClient, context.Context) { + var serviceClient sd.ServiceClient if allowFollower { - serviceClient = c.pdSvcDiscovery.getServiceClientByKind(regionAPIKind) + serviceClient = c.pdSvcDiscovery.GetServiceClientByKind(sd.UniversalAPIKind) if serviceClient != nil { return serviceClient, serviceClient.BuildGRPCTargetContext(ctx, !allowFollower) } @@ -202,7 +203,7 @@ func (c *innerClient) gRPCErrorHandler(err error) { } func (c *innerClient) getOrCreateGRPCConn() (*grpc.ClientConn, error) { - cc, err := c.pdSvcDiscovery.GetOrCreateGRPCConn(c.pdSvcDiscovery.getLeaderURL()) + cc, err := c.pdSvcDiscovery.GetOrCreateGRPCConn(c.pdSvcDiscovery.GetServingURL()) if err != nil { return nil, err } diff --git a/client/mock_pd_service_discovery.go b/client/servicediscovery/mock_pd_service_discovery.go similarity index 92% rename from client/mock_pd_service_discovery.go rename to client/servicediscovery/mock_pd_service_discovery.go index 16462b0b1e6..87b74ae2136 100644 --- a/client/mock_pd_service_discovery.go +++ b/client/servicediscovery/mock_pd_service_discovery.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pd +package servicediscovery import ( "crypto/tls" @@ -62,6 +62,9 @@ func (*mockPDServiceDiscovery) GetClusterID() uint64 { return 0 } // GetKeyspaceID implements the ServiceDiscovery interface. func (*mockPDServiceDiscovery) GetKeyspaceID() uint32 { return 0 } +// SetKeyspaceID implements the ServiceDiscovery interface. +func (*mockPDServiceDiscovery) SetKeyspaceID(uint32) {} + // GetKeyspaceGroupID implements the ServiceDiscovery interface. func (*mockPDServiceDiscovery) GetKeyspaceGroupID() uint32 { return 0 } @@ -83,6 +86,9 @@ func (*mockPDServiceDiscovery) GetBackupURLs() []string { return nil } // GetServiceClient implements the ServiceDiscovery interface. func (*mockPDServiceDiscovery) GetServiceClient() ServiceClient { return nil } +// GetServiceClientByKind implements the ServiceDiscovery interface. +func (*mockPDServiceDiscovery) GetServiceClientByKind(APIKind) ServiceClient { return nil } + // GetOrCreateGRPCConn implements the ServiceDiscovery interface. func (*mockPDServiceDiscovery) GetOrCreateGRPCConn(string) (*grpc.ClientConn, error) { return nil, nil diff --git a/client/pd_service_discovery.go b/client/servicediscovery/pd_service_discovery.go similarity index 89% rename from client/pd_service_discovery.go rename to client/servicediscovery/pd_service_discovery.go index 0bdc6868c65..2d106559b76 100644 --- a/client/pd_service_discovery.go +++ b/client/servicediscovery/pd_service_discovery.go @@ -12,13 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pd +package servicediscovery import ( "context" "crypto/tls" "fmt" - "net/url" "reflect" "sort" "strings" @@ -30,10 +29,12 @@ import ( "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" + "github.com/tikv/pd/client/constants" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/opt" "github.com/tikv/pd/client/retry" "github.com/tikv/pd/client/utils/grpcutil" + "github.com/tikv/pd/client/utils/tlsutil" "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -42,23 +43,29 @@ import ( ) const ( - memberUpdateInterval = time.Minute - serviceModeUpdateInterval = 3 * time.Second - updateMemberTimeout = time.Second // Use a shorter timeout to recover faster from network isolation. - updateMemberBackOffBaseTime = 100 * time.Millisecond - - httpScheme = "http" - httpsScheme = "https" + // MemberUpdateInterval is the interval to update the member list. + MemberUpdateInterval = time.Minute + // UpdateMemberTimeout is the timeout to update the member list. + // Use a shorter timeout to recover faster from network isolation. + UpdateMemberTimeout = time.Second + // UpdateMemberBackOffBaseTime is the base time to back off when updating the member list. + UpdateMemberBackOffBaseTime = 100 * time.Millisecond + + serviceModeUpdateInterval = 3 * time.Second ) // MemberHealthCheckInterval might be changed in the unit to shorten the testing time. var MemberHealthCheckInterval = time.Second -type apiKind int +// APIKind defines how this API should be handled. +type APIKind int const ( - forwardAPIKind apiKind = iota - regionAPIKind + // ForwardAPIKind means this API should be forwarded from the followers to the leader. + ForwardAPIKind APIKind = iota + // UniversalAPIKind means this API can be handled by both the leader and the followers. + UniversalAPIKind + apiKindCount ) @@ -80,6 +87,8 @@ type ServiceDiscovery interface { GetClusterID() uint64 // GetKeyspaceID returns the ID of the keyspace GetKeyspaceID() uint32 + // SetKeyspaceID sets the ID of the keyspace + SetKeyspaceID(id uint32) // GetKeyspaceGroupID returns the ID of the keyspace group GetKeyspaceGroupID() uint32 // GetServiceURLs returns the URLs of the servers providing the service @@ -101,6 +110,8 @@ type ServiceDiscovery interface { // If the leader ServiceClient meets network problem, // it returns a follower/secondary ServiceClient which can forward the request to leader. GetServiceClient() ServiceClient + // GetServiceClientByKind tries to get the ServiceClient with the given API kind. + GetServiceClientByKind(kind APIKind) ServiceClient // GetAllServiceClients tries to get all ServiceClient. // If the leader is not nil, it will put the leader service client first in the slice. GetAllServiceClients() []ServiceClient @@ -221,17 +232,13 @@ func (c *pdServiceClient) checkNetworkAvailable(ctx context.Context) { } }) rpcErr, ok := status.FromError(err) - if (ok && isNetworkError(rpcErr.Code())) || resp.GetStatus() != healthpb.HealthCheckResponse_SERVING { + if (ok && errs.IsNetworkError(rpcErr.Code())) || resp.GetStatus() != healthpb.HealthCheckResponse_SERVING { c.networkFailure.Store(true) } else { c.networkFailure.Store(false) } } -func isNetworkError(code codes.Code) bool { - return code == codes.Unavailable || code == codes.DeadlineExceeded -} - // GetClientConn implements ServiceClient. func (c *pdServiceClient) GetClientConn() *grpc.ClientConn { if c == nil { @@ -383,18 +390,19 @@ func (c *pdServiceBalancer) get() (ret ServiceClient) { return } -type updateKeyspaceIDFunc func() error +// UpdateKeyspaceIDFunc is the function type for updating the keyspace ID. +type UpdateKeyspaceIDFunc func() error type tsoLeaderURLUpdatedFunc func(string) error -// tsoEventSource subscribes to events related to changes in the TSO leader/primary from the service discovery. -type tsoEventSource interface { +// TSOEventSource subscribes to events related to changes in the TSO leader/primary from the service discovery. +type TSOEventSource interface { // SetTSOLeaderURLUpdatedCallback adds a callback which will be called when the TSO leader/primary is updated. SetTSOLeaderURLUpdatedCallback(callback tsoLeaderURLUpdatedFunc) } var ( _ ServiceDiscovery = (*pdServiceDiscovery)(nil) - _ tsoEventSource = (*pdServiceDiscovery)(nil) + _ TSOEventSource = (*pdServiceDiscovery)(nil) ) // pdServiceDiscovery is the service discovery client of PD/API service which is quorum based @@ -433,7 +441,7 @@ type pdServiceDiscovery struct { cancel context.CancelFunc closeOnce sync.Once - updateKeyspaceIDFunc updateKeyspaceIDFunc + updateKeyspaceIDFunc UpdateKeyspaceIDFunc keyspaceID uint32 tlsCfg *tls.Config // Client option. @@ -444,20 +452,20 @@ type pdServiceDiscovery struct { func NewDefaultPDServiceDiscovery( ctx context.Context, cancel context.CancelFunc, urls []string, tlsCfg *tls.Config, -) *pdServiceDiscovery { +) ServiceDiscovery { var wg sync.WaitGroup - return newPDServiceDiscovery(ctx, cancel, &wg, nil, nil, defaultKeyspaceID, urls, tlsCfg, opt.NewOption()) + return NewPDServiceDiscovery(ctx, cancel, &wg, nil, nil, constants.DefaultKeyspaceID, urls, tlsCfg, opt.NewOption()) } -// newPDServiceDiscovery returns a new PD service discovery-based client. -func newPDServiceDiscovery( +// NewPDServiceDiscovery returns a new PD service discovery-based client. +func NewPDServiceDiscovery( ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, serviceModeUpdateCb func(pdpb.ServiceMode), - updateKeyspaceIDFunc updateKeyspaceIDFunc, + updateKeyspaceIDFunc UpdateKeyspaceIDFunc, keyspaceID uint32, urls []string, tlsCfg *tls.Config, option *opt.Option, -) *pdServiceDiscovery { +) ServiceDiscovery { pdsd := &pdServiceDiscovery{ checkMembershipCh: make(chan struct{}, 1), ctx: ctx, @@ -470,7 +478,7 @@ func newPDServiceDiscovery( tlsCfg: tlsCfg, option: option, } - urls = addrsToURLs(urls, tlsCfg) + urls = tlsutil.AddrsToURLs(urls, tlsCfg) pdsd.urls.Store(urls) return pdsd } @@ -493,7 +501,7 @@ func (c *pdServiceDiscovery) Init() error { // We need to update the keyspace ID before we discover and update the service mode // so that TSO in API mode can be initialized with the correct keyspace ID. - if c.keyspaceID == nullKeyspaceID && c.updateKeyspaceIDFunc != nil { + if c.keyspaceID == constants.NullKeyspaceID && c.updateKeyspaceIDFunc != nil { if err := c.initRetry(c.updateKeyspaceIDFunc); err != nil { return err } @@ -534,10 +542,10 @@ func (c *pdServiceDiscovery) updateMemberLoop() { ctx, cancel := context.WithCancel(c.ctx) defer cancel() - ticker := time.NewTicker(memberUpdateInterval) + ticker := time.NewTicker(MemberUpdateInterval) defer ticker.Stop() - bo := retry.InitialBackoffer(updateMemberBackOffBaseTime, updateMemberTimeout, updateMemberBackOffBaseTime) + bo := retry.InitialBackoffer(UpdateMemberBackOffBaseTime, UpdateMemberTimeout, UpdateMemberBackOffBaseTime) for { select { case <-ctx.Done(): @@ -546,9 +554,6 @@ func (c *pdServiceDiscovery) updateMemberLoop() { case <-ticker.C: case <-c.checkMembershipCh: } - failpoint.Inject("skipUpdateMember", func() { - failpoint.Continue() - }) if err := bo.Exec(ctx, c.updateMember); err != nil { log.Error("[pd] failed to update member", zap.Strings("urls", c.GetServiceURLs()), errs.ZapError(err)) } @@ -663,7 +668,7 @@ func (c *pdServiceDiscovery) SetKeyspaceID(keyspaceID uint32) { // GetKeyspaceGroupID returns the ID of the keyspace group func (*pdServiceDiscovery) GetKeyspaceGroupID() uint32 { // PD/API service only supports the default keyspace group - return defaultKeySpaceGroupID + return constants.DefaultKeyspaceGroupID } // DiscoverMicroservice discovers the microservice with the specified type and returns the server urls. @@ -733,8 +738,8 @@ func (c *pdServiceDiscovery) getLeaderServiceClient() *pdServiceClient { return leader.(*pdServiceClient) } -// getServiceClientByKind returns ServiceClient of the specific kind. -func (c *pdServiceDiscovery) getServiceClientByKind(kind apiKind) ServiceClient { +// GetServiceClientByKind returns ServiceClient of the specific kind. +func (c *pdServiceDiscovery) GetServiceClientByKind(kind APIKind) ServiceClient { client := c.apiCandidateNodes[kind].get() if client == nil { return nil @@ -746,7 +751,7 @@ func (c *pdServiceDiscovery) getServiceClientByKind(kind apiKind) ServiceClient func (c *pdServiceDiscovery) GetServiceClient() ServiceClient { leaderClient := c.getLeaderServiceClient() if c.option.EnableForwarding && !leaderClient.Available() { - if followerClient := c.getServiceClientByKind(forwardAPIKind); followerClient != nil { + if followerClient := c.GetServiceClientByKind(ForwardAPIKind); followerClient != nil { log.Debug("[pd] use follower client", zap.String("url", followerClient.GetURL())) return followerClient } @@ -833,17 +838,14 @@ func (c *pdServiceDiscovery) initClusterID() error { clusterID = members.GetHeader().GetClusterId() continue } - failpoint.Inject("skipClusterIDCheck", func() { - failpoint.Continue() - }) // All URLs passed in should have the same cluster ID. if members.GetHeader().GetClusterId() != clusterID { - return errors.WithStack(errUnmatchedClusterID) + return errors.WithStack(errs.ErrUnmatchedClusterID) } } // Failed to init the cluster ID. if clusterID == 0 { - return errors.WithStack(errFailInitClusterID) + return errors.WithStack(errs.ErrFailInitClusterID) } c.clusterID = clusterID return nil @@ -869,7 +871,7 @@ func (c *pdServiceDiscovery) checkServiceModeChanged() error { return err } if clusterInfo == nil || len(clusterInfo.ServiceModes) == 0 { - return errors.WithStack(errNoServiceModeReturned) + return errors.WithStack(errs.ErrNoServiceModeReturned) } if c.serviceModeUpdateCb != nil { c.serviceModeUpdateCb(clusterInfo.ServiceModes[0]) @@ -878,14 +880,8 @@ func (c *pdServiceDiscovery) checkServiceModeChanged() error { } func (c *pdServiceDiscovery) updateMember() error { - for i, url := range c.GetServiceURLs() { - failpoint.Inject("skipFirstUpdateMember", func() { - if i == 0 { - failpoint.Continue() - } - }) - - members, err := c.getMembers(c.ctx, url, updateMemberTimeout) + for _, url := range c.GetServiceURLs() { + members, err := c.getMembers(c.ctx, url, UpdateMemberTimeout) // Check the cluster ID. updatedClusterID := members.GetHeader().GetClusterId() if err == nil && updatedClusterID != c.clusterID { @@ -1016,7 +1012,7 @@ func (c *pdServiceDiscovery) updateFollowers(members []*pdpb.Member, leaderID ui followerURLs = append(followerURLs, member.GetClientUrls()...) // FIXME: How to safely compare urls(also for leader)? For now, only allows one client url. - url := pickMatchedURL(member.GetClientUrls(), c.tlsCfg) + url := tlsutil.PickMatchedURL(member.GetClientUrls(), c.tlsCfg) if client, ok := c.followers.Load(url); ok { if client.(*pdServiceClient).GetClientConn() == nil { conn, err := c.GetOrCreateGRPCConn(url) @@ -1053,7 +1049,7 @@ func (c *pdServiceDiscovery) updateFollowers(members []*pdpb.Member, leaderID ui func (c *pdServiceDiscovery) updateServiceClient(members []*pdpb.Member, leader *pdpb.Member) error { // FIXME: How to safely compare leader urls? For now, only allows one client url. - leaderURL := pickMatchedURL(leader.GetClientUrls(), c.tlsCfg) + leaderURL := tlsutil.PickMatchedURL(leader.GetClientUrls(), c.tlsCfg) leaderChanged, err := c.switchLeader(leaderURL) followerChanged := c.updateFollowers(members, leader.GetMemberId(), leaderURL) // don't need to recreate balancer if no changes. @@ -1082,51 +1078,3 @@ func (c *pdServiceDiscovery) updateServiceClient(members []*pdpb.Member, leader func (c *pdServiceDiscovery) GetOrCreateGRPCConn(url string) (*grpc.ClientConn, error) { return grpcutil.GetOrCreateGRPCConn(c.ctx, &c.clientConns, url, c.tlsCfg, c.option.GRPCDialOptions...) } - -func addrsToURLs(addrs []string, tlsCfg *tls.Config) []string { - // Add default schema "http://" to addrs. - urls := make([]string, 0, len(addrs)) - for _, addr := range addrs { - urls = append(urls, modifyURLScheme(addr, tlsCfg)) - } - return urls -} - -func modifyURLScheme(url string, tlsCfg *tls.Config) string { - if tlsCfg == nil { - if strings.HasPrefix(url, httpsSchemePrefix) { - url = httpSchemePrefix + strings.TrimPrefix(url, httpsSchemePrefix) - } else if !strings.HasPrefix(url, httpSchemePrefix) { - url = httpSchemePrefix + url - } - } else { - if strings.HasPrefix(url, httpSchemePrefix) { - url = httpsSchemePrefix + strings.TrimPrefix(url, httpSchemePrefix) - } else if !strings.HasPrefix(url, httpsSchemePrefix) { - url = httpsSchemePrefix + url - } - } - return url -} - -// pickMatchedURL picks the matched URL based on the TLS config. -// Note: please make sure the URLs are valid. -func pickMatchedURL(urls []string, tlsCfg *tls.Config) string { - for _, uStr := range urls { - u, err := url.Parse(uStr) - if err != nil { - continue - } - if tlsCfg != nil && u.Scheme == httpsScheme { - return uStr - } - if tlsCfg == nil && u.Scheme == httpScheme { - return uStr - } - } - ret := modifyURLScheme(urls[0], tlsCfg) - log.Warn("[pd] no matched url found", zap.Strings("urls", urls), - zap.Bool("tls-enabled", tlsCfg != nil), - zap.String("attempted-url", ret)) - return ret -} diff --git a/client/pd_service_discovery_test.go b/client/servicediscovery/pd_service_discovery_test.go similarity index 69% rename from client/pd_service_discovery_test.go rename to client/servicediscovery/pd_service_discovery_test.go index e553b087d34..45faf2aa7f1 100644 --- a/client/pd_service_discovery_test.go +++ b/client/servicediscovery/pd_service_discovery_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pd +package servicediscovery import ( "context" @@ -30,8 +30,10 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/tikv/pd/client/errs" + "github.com/tikv/pd/client/opt" "github.com/tikv/pd/client/utils/grpcutil" "github.com/tikv/pd/client/utils/testutil" + "github.com/tikv/pd/client/utils/tlsutil" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" pb "google.golang.org/grpc/examples/helloworld/helloworld" @@ -142,12 +144,12 @@ func (suite *serviceClientTestSuite) SetupSuite() { followerConn, err2 := grpc.Dial(suite.followerServer.addr, grpc.WithTransportCredentials(insecure.NewCredentials())) if err1 == nil && err2 == nil { suite.followerClient = newPDServiceClient( - modifyURLScheme(suite.followerServer.addr, nil), - modifyURLScheme(suite.leaderServer.addr, nil), + tlsutil.ModifyURLScheme(suite.followerServer.addr, nil), + tlsutil.ModifyURLScheme(suite.leaderServer.addr, nil), followerConn, false) suite.leaderClient = newPDServiceClient( - modifyURLScheme(suite.leaderServer.addr, nil), - modifyURLScheme(suite.leaderServer.addr, nil), + tlsutil.ModifyURLScheme(suite.leaderServer.addr, nil), + tlsutil.ModifyURLScheme(suite.leaderServer.addr, nil), leaderConn, true) suite.followerServer.server.leaderConn = suite.leaderClient.GetClientConn() suite.followerServer.server.leaderAddr = suite.leaderClient.GetURL() @@ -173,8 +175,8 @@ func (suite *serviceClientTestSuite) TearDownSuite() { func (suite *serviceClientTestSuite) TestServiceClient() { re := suite.Require() - leaderAddress := modifyURLScheme(suite.leaderServer.addr, nil) - followerAddress := modifyURLScheme(suite.followerServer.addr, nil) + leaderAddress := tlsutil.ModifyURLScheme(suite.leaderServer.addr, nil) + followerAddress := tlsutil.ModifyURLScheme(suite.followerServer.addr, nil) follower := suite.followerClient leader := suite.leaderClient @@ -188,12 +190,12 @@ func (suite *serviceClientTestSuite) TestServiceClient() { re.False(follower.IsConnectedToLeader()) re.True(leader.IsConnectedToLeader()) - re.NoError(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork1", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/servicediscovery/unreachableNetwork1", "return(true)")) follower.(*pdServiceClient).checkNetworkAvailable(suite.ctx) leader.(*pdServiceClient).checkNetworkAvailable(suite.ctx) re.False(follower.Available()) re.False(leader.Available()) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork1")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/servicediscovery/unreachableNetwork1")) follower.(*pdServiceClient).checkNetworkAvailable(suite.ctx) leader.(*pdServiceClient).checkNetworkAvailable(suite.ctx) @@ -235,7 +237,7 @@ func (suite *serviceClientTestSuite) TestServiceClient() { followerAPIClient := newPDServiceAPIClient(follower, regionAPIErrorFn) leaderAPIClient := newPDServiceAPIClient(leader, regionAPIErrorFn) - re.NoError(failpoint.Enable("github.com/tikv/pd/client/fastCheckAvailable", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/servicediscovery/fastCheckAvailable", "return(true)")) re.True(followerAPIClient.Available()) re.True(leaderAPIClient.Available()) @@ -267,7 +269,7 @@ func (suite *serviceClientTestSuite) TestServiceClient() { re.True(followerAPIClient.Available()) re.True(leaderAPIClient.Available()) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/fastCheckAvailable")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/servicediscovery/fastCheckAvailable")) } func (suite *serviceClientTestSuite) TestServiceClientBalancer() { @@ -308,17 +310,17 @@ func (suite *serviceClientTestSuite) TestServiceClientBalancer() { func TestServiceClientScheme(t *testing.T) { re := require.New(t) - cli := newPDServiceClient(modifyURLScheme("127.0.0.1:2379", nil), modifyURLScheme("127.0.0.1:2379", nil), nil, false) + cli := newPDServiceClient(tlsutil.ModifyURLScheme("127.0.0.1:2379", nil), tlsutil.ModifyURLScheme("127.0.0.1:2379", nil), nil, false) re.Equal("http://127.0.0.1:2379", cli.GetURL()) - cli = newPDServiceClient(modifyURLScheme("https://127.0.0.1:2379", nil), modifyURLScheme("127.0.0.1:2379", nil), nil, false) + cli = newPDServiceClient(tlsutil.ModifyURLScheme("https://127.0.0.1:2379", nil), tlsutil.ModifyURLScheme("127.0.0.1:2379", nil), nil, false) re.Equal("http://127.0.0.1:2379", cli.GetURL()) - cli = newPDServiceClient(modifyURLScheme("http://127.0.0.1:2379", nil), modifyURLScheme("127.0.0.1:2379", nil), nil, false) + cli = newPDServiceClient(tlsutil.ModifyURLScheme("http://127.0.0.1:2379", nil), tlsutil.ModifyURLScheme("127.0.0.1:2379", nil), nil, false) re.Equal("http://127.0.0.1:2379", cli.GetURL()) - cli = newPDServiceClient(modifyURLScheme("127.0.0.1:2379", &tls.Config{}), modifyURLScheme("127.0.0.1:2379", &tls.Config{}), nil, false) + cli = newPDServiceClient(tlsutil.ModifyURLScheme("127.0.0.1:2379", &tls.Config{}), tlsutil.ModifyURLScheme("127.0.0.1:2379", &tls.Config{}), nil, false) re.Equal("https://127.0.0.1:2379", cli.GetURL()) - cli = newPDServiceClient(modifyURLScheme("https://127.0.0.1:2379", &tls.Config{}), modifyURLScheme("127.0.0.1:2379", &tls.Config{}), nil, false) + cli = newPDServiceClient(tlsutil.ModifyURLScheme("https://127.0.0.1:2379", &tls.Config{}), tlsutil.ModifyURLScheme("127.0.0.1:2379", &tls.Config{}), nil, false) re.Equal("https://127.0.0.1:2379", cli.GetURL()) - cli = newPDServiceClient(modifyURLScheme("http://127.0.0.1:2379", &tls.Config{}), modifyURLScheme("127.0.0.1:2379", &tls.Config{}), nil, false) + cli = newPDServiceClient(tlsutil.ModifyURLScheme("http://127.0.0.1:2379", &tls.Config{}), tlsutil.ModifyURLScheme("127.0.0.1:2379", &tls.Config{}), nil, false) re.Equal("https://127.0.0.1:2379", cli.GetURL()) } @@ -336,48 +338,97 @@ func TestSchemeFunction(t *testing.T) { "http://127.0.0.1:2379", "https://127.0.0.1:2379", } - urls := addrsToURLs(endpoints1, tlsCfg) + urls := tlsutil.AddrsToURLs(endpoints1, tlsCfg) for _, u := range urls { re.Equal("https://tc-pd:2379", u) } - urls = addrsToURLs(endpoints2, tlsCfg) + urls = tlsutil.AddrsToURLs(endpoints2, tlsCfg) for _, u := range urls { re.Equal("https://127.0.0.1:2379", u) } - urls = addrsToURLs(endpoints1, nil) + urls = tlsutil.AddrsToURLs(endpoints1, nil) for _, u := range urls { re.Equal("http://tc-pd:2379", u) } - urls = addrsToURLs(endpoints2, nil) + urls = tlsutil.AddrsToURLs(endpoints2, nil) for _, u := range urls { re.Equal("http://127.0.0.1:2379", u) } - re.Equal("https://127.0.0.1:2379", modifyURLScheme("https://127.0.0.1:2379", tlsCfg)) - re.Equal("https://127.0.0.1:2379", modifyURLScheme("http://127.0.0.1:2379", tlsCfg)) - re.Equal("https://127.0.0.1:2379", modifyURLScheme("127.0.0.1:2379", tlsCfg)) - re.Equal("https://tc-pd:2379", modifyURLScheme("tc-pd:2379", tlsCfg)) - re.Equal("http://127.0.0.1:2379", modifyURLScheme("https://127.0.0.1:2379", nil)) - re.Equal("http://127.0.0.1:2379", modifyURLScheme("http://127.0.0.1:2379", nil)) - re.Equal("http://127.0.0.1:2379", modifyURLScheme("127.0.0.1:2379", nil)) - re.Equal("http://tc-pd:2379", modifyURLScheme("tc-pd:2379", nil)) + re.Equal("https://127.0.0.1:2379", tlsutil.ModifyURLScheme("https://127.0.0.1:2379", tlsCfg)) + re.Equal("https://127.0.0.1:2379", tlsutil.ModifyURLScheme("http://127.0.0.1:2379", tlsCfg)) + re.Equal("https://127.0.0.1:2379", tlsutil.ModifyURLScheme("127.0.0.1:2379", tlsCfg)) + re.Equal("https://tc-pd:2379", tlsutil.ModifyURLScheme("tc-pd:2379", tlsCfg)) + re.Equal("http://127.0.0.1:2379", tlsutil.ModifyURLScheme("https://127.0.0.1:2379", nil)) + re.Equal("http://127.0.0.1:2379", tlsutil.ModifyURLScheme("http://127.0.0.1:2379", nil)) + re.Equal("http://127.0.0.1:2379", tlsutil.ModifyURLScheme("127.0.0.1:2379", nil)) + re.Equal("http://tc-pd:2379", tlsutil.ModifyURLScheme("tc-pd:2379", nil)) urls = []string{ "http://127.0.0.1:2379", "https://127.0.0.1:2379", } - re.Equal("https://127.0.0.1:2379", pickMatchedURL(urls, tlsCfg)) + re.Equal("https://127.0.0.1:2379", tlsutil.PickMatchedURL(urls, tlsCfg)) urls = []string{ "http://127.0.0.1:2379", } - re.Equal("https://127.0.0.1:2379", pickMatchedURL(urls, tlsCfg)) + re.Equal("https://127.0.0.1:2379", tlsutil.PickMatchedURL(urls, tlsCfg)) urls = []string{ "http://127.0.0.1:2379", "https://127.0.0.1:2379", } - re.Equal("http://127.0.0.1:2379", pickMatchedURL(urls, nil)) + re.Equal("http://127.0.0.1:2379", tlsutil.PickMatchedURL(urls, nil)) urls = []string{ "https://127.0.0.1:2379", } - re.Equal("http://127.0.0.1:2379", pickMatchedURL(urls, nil)) + re.Equal("http://127.0.0.1:2379", tlsutil.PickMatchedURL(urls, nil)) +} + +func TestUpdateURLs(t *testing.T) { + re := require.New(t) + members := []*pdpb.Member{ + {Name: "pd4", ClientUrls: []string{"tmp://pd4"}}, + {Name: "pd1", ClientUrls: []string{"tmp://pd1"}}, + {Name: "pd3", ClientUrls: []string{"tmp://pd3"}}, + {Name: "pd2", ClientUrls: []string{"tmp://pd2"}}, + } + getURLs := func(ms []*pdpb.Member) (urls []string) { + for _, m := range ms { + urls = append(urls, m.GetClientUrls()[0]) + } + return + } + cli := &pdServiceDiscovery{option: opt.NewOption()} + cli.urls.Store([]string{}) + cli.updateURLs(members[1:]) + re.Equal(getURLs([]*pdpb.Member{members[1], members[3], members[2]}), cli.GetServiceURLs()) + cli.updateURLs(members[1:]) + re.Equal(getURLs([]*pdpb.Member{members[1], members[3], members[2]}), cli.GetServiceURLs()) + cli.updateURLs(members) + re.Equal(getURLs([]*pdpb.Member{members[1], members[3], members[2], members[0]}), cli.GetServiceURLs()) + cli.updateURLs(members[1:]) + re.Equal(getURLs([]*pdpb.Member{members[1], members[3], members[2]}), cli.GetServiceURLs()) + cli.updateURLs(members[2:]) + re.Equal(getURLs([]*pdpb.Member{members[3], members[2]}), cli.GetServiceURLs()) + cli.updateURLs(members[3:]) + re.Equal(getURLs([]*pdpb.Member{members[3]}), cli.GetServiceURLs()) +} + +func TestGRPCDialOption(t *testing.T) { + re := require.New(t) + start := time.Now() + ctx, cancel := context.WithTimeout(context.TODO(), 500*time.Millisecond) + defer cancel() + cli := &pdServiceDiscovery{ + checkMembershipCh: make(chan struct{}, 1), + ctx: ctx, + cancel: cancel, + tlsCfg: nil, + option: opt.NewOption(), + } + cli.urls.Store([]string{"tmp://test.url:5255"}) + cli.option.GRPCDialOptions = []grpc.DialOption{grpc.WithBlock()} + err := cli.updateMember() + re.Error(err) + re.Greater(time.Since(start), 500*time.Millisecond) } diff --git a/client/tso_service_discovery.go b/client/servicediscovery/tso_service_discovery.go similarity index 95% rename from client/tso_service_discovery.go rename to client/servicediscovery/tso_service_discovery.go index 7d5b761e68c..81ef69c5545 100644 --- a/client/tso_service_discovery.go +++ b/client/servicediscovery/tso_service_discovery.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pd +package servicediscovery import ( "context" @@ -30,6 +30,7 @@ import ( "github.com/pingcap/kvproto/pkg/tsopb" "github.com/pingcap/log" "github.com/tikv/pd/client/clients/metastorage" + "github.com/tikv/pd/client/constants" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/opt" "github.com/tikv/pd/client/utils/grpcutil" @@ -52,8 +53,10 @@ const ( tsoQueryRetryInterval = 500 * time.Millisecond ) -var _ ServiceDiscovery = (*tsoServiceDiscovery)(nil) -var _ tsoEventSource = (*tsoServiceDiscovery)(nil) +var ( + _ ServiceDiscovery = (*tsoServiceDiscovery)(nil) + _ TSOEventSource = (*tsoServiceDiscovery)(nil) +) // keyspaceGroupSvcDiscovery is used for discovering the serving endpoints of the keyspace // group to which the keyspace belongs @@ -154,8 +157,8 @@ type tsoServiceDiscovery struct { option *opt.Option } -// newTSOServiceDiscovery returns a new client-side service discovery for the independent TSO service. -func newTSOServiceDiscovery( +// NewTSOServiceDiscovery returns a new client-side service discovery for the independent TSO service. +func NewTSOServiceDiscovery( ctx context.Context, metacli metastorage.Client, apiSvcDiscovery ServiceDiscovery, keyspaceID uint32, tlsCfg *tls.Config, option *opt.Option, ) ServiceDiscovery { @@ -179,7 +182,7 @@ func newTSOServiceDiscovery( c.tsoServerDiscovery = &tsoServerDiscovery{urls: make([]string, 0)} // Start with the default keyspace group. The actual keyspace group, to which the keyspace belongs, // will be discovered later. - c.defaultDiscoveryKey = fmt.Sprintf(tsoSvcDiscoveryFormat, c.clusterID, defaultKeySpaceGroupID) + c.defaultDiscoveryKey = fmt.Sprintf(tsoSvcDiscoveryFormat, c.clusterID, constants.DefaultKeyspaceGroupID) log.Info("created tso service discovery", zap.Uint64("cluster-id", c.clusterID), @@ -249,7 +252,7 @@ func (c *tsoServiceDiscovery) startCheckMemberLoop() { ctx, cancel := context.WithCancel(c.ctx) defer cancel() - ticker := time.NewTicker(memberUpdateInterval) + ticker := time.NewTicker(MemberUpdateInterval) defer ticker.Stop() for { @@ -279,13 +282,18 @@ func (c *tsoServiceDiscovery) GetKeyspaceID() uint32 { return c.keyspaceID.Load() } +// SetKeyspaceID sets the ID of the keyspace +func (c *tsoServiceDiscovery) SetKeyspaceID(keyspaceID uint32) { + c.keyspaceID.Store(keyspaceID) +} + // GetKeyspaceGroupID returns the ID of the keyspace group. If the keyspace group is unknown, // it returns the default keyspace group ID. func (c *tsoServiceDiscovery) GetKeyspaceGroupID() uint32 { c.keyspaceGroupSD.RLock() defer c.keyspaceGroupSD.RUnlock() if c.keyspaceGroupSD.group == nil { - return defaultKeySpaceGroupID + return constants.DefaultKeyspaceGroupID } return c.keyspaceGroupSD.group.Id } @@ -375,6 +383,11 @@ func (c *tsoServiceDiscovery) GetServiceClient() ServiceClient { return c.apiSvcDiscovery.GetServiceClient() } +// GetServiceClientByKind implements ServiceDiscovery +func (c *tsoServiceDiscovery) GetServiceClientByKind(kind APIKind) ServiceClient { + return c.apiSvcDiscovery.GetServiceClientByKind(kind) +} + // GetAllServiceClients implements ServiceDiscovery func (c *tsoServiceDiscovery) GetAllServiceClients() []ServiceClient { return c.apiSvcDiscovery.GetAllServiceClients() @@ -419,7 +432,7 @@ func (c *tsoServiceDiscovery) updateMember() error { keyspaceID := c.GetKeyspaceID() var keyspaceGroup *tsopb.KeyspaceGroup if len(tsoServerURL) > 0 { - keyspaceGroup, err = c.findGroupByKeyspaceID(keyspaceID, tsoServerURL, updateMemberTimeout) + keyspaceGroup, err = c.findGroupByKeyspaceID(keyspaceID, tsoServerURL, UpdateMemberTimeout) if err != nil { if c.tsoServerDiscovery.countFailure() { log.Error("[tso] failed to find the keyspace group", @@ -456,7 +469,7 @@ func (c *tsoServiceDiscovery) updateMember() error { } members[0].IsPrimary = true keyspaceGroup = &tsopb.KeyspaceGroup{ - Id: defaultKeySpaceGroupID, + Id: constants.DefaultKeyspaceGroupID, Members: members, } } @@ -541,7 +554,7 @@ func (c *tsoServiceDiscovery) findGroupByKeyspaceID( Header: &tsopb.RequestHeader{ ClusterId: c.clusterID, KeyspaceId: keyspaceID, - KeyspaceGroupId: defaultKeySpaceGroupID, + KeyspaceGroupId: constants.DefaultKeyspaceGroupID, }, KeyspaceId: keyspaceID, }) diff --git a/client/tso_client.go b/client/tso_client.go index cdd85dd2479..1d0a6385647 100644 --- a/client/tso_client.go +++ b/client/tso_client.go @@ -28,7 +28,9 @@ import ( "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/metrics" "github.com/tikv/pd/client/opt" + sd "github.com/tikv/pd/client/servicediscovery" "github.com/tikv/pd/client/utils/grpcutil" + "github.com/tikv/pd/client/utils/tlsutil" "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -72,7 +74,7 @@ type tsoClient struct { wg sync.WaitGroup option *opt.Option - svcDiscovery ServiceDiscovery + svcDiscovery sd.ServiceDiscovery tsoStreamBuilderFactory // leaderURL is the URL of the TSO leader. leaderURL atomic.Value @@ -86,7 +88,7 @@ type tsoClient struct { // newTSOClient returns a new TSO client. func newTSOClient( ctx context.Context, option *opt.Option, - svcDiscovery ServiceDiscovery, factory tsoStreamBuilderFactory, + svcDiscovery sd.ServiceDiscovery, factory tsoStreamBuilderFactory, ) *tsoClient { ctx, cancel := context.WithCancel(ctx) c := &tsoClient{ @@ -106,7 +108,7 @@ func newTSOClient( }, } - eventSrc := svcDiscovery.(tsoEventSource) + eventSrc := svcDiscovery.(sd.TSOEventSource) eventSrc.SetTSOLeaderURLUpdatedCallback(c.updateTSOLeaderURL) c.svcDiscovery.AddServiceURLsSwitchedCallback(c.scheduleUpdateTSOConnectionCtxs) @@ -115,7 +117,7 @@ func newTSOClient( func (c *tsoClient) getOption() *opt.Option { return c.option } -func (c *tsoClient) getServiceDiscovery() ServiceDiscovery { return c.svcDiscovery } +func (c *tsoClient) getServiceDiscovery() sd.ServiceDiscovery { return c.svcDiscovery } func (c *tsoClient) getDispatcher() *tsoDispatcher { return c.dispatcher.Load() @@ -303,7 +305,7 @@ func (c *tsoClient) tryConnectToTSO( // There is no need to wait for the transport layer timeout which can reduce the time of unavailability. // But it conflicts with the retry mechanism since we use the error code to decide if it is caused by network error. // And actually the `Canceled` error can be regarded as a kind of network error in some way. - if rpcErr, ok := status.FromError(err); ok && (isNetworkError(rpcErr.Code()) || rpcErr.Code() == codes.Canceled) { + if rpcErr, ok := status.FromError(err); ok && (errs.IsNetworkError(rpcErr.Code()) || rpcErr.Code() == codes.Canceled) { networkErrNum++ } } @@ -333,8 +335,8 @@ func (c *tsoClient) tryConnectToTSO( cctx = grpcutil.BuildForwardContext(cctx, forwardedHost) stream, err = c.tsoStreamBuilderFactory.makeBuilder(backupClientConn).build(cctx, cancel, c.option.Timeout) if err == nil { - forwardedHostTrim := trimHTTPPrefix(forwardedHost) - addr := trimHTTPPrefix(backupURL) + forwardedHostTrim := tlsutil.TrimHTTPPrefix(forwardedHost) + addr := tlsutil.TrimHTTPPrefix(backupURL) // the goroutine is used to check the network and change back to the original stream go c.checkLeader(ctx, cancel, forwardedHostTrim, addr, url, updateAndClear) metrics.RequestForwarded.WithLabelValues(forwardedHostTrim, addr).Set(1) @@ -440,8 +442,8 @@ func (c *tsoClient) tryConnectToTSOWithProxy( stream, err := tsoStreamBuilder.build(cctx, cancel, c.option.Timeout) if err == nil { if addr != leaderAddr { - forwardedHostTrim := trimHTTPPrefix(forwardedHost) - addrTrim := trimHTTPPrefix(addr) + forwardedHostTrim := tlsutil.TrimHTTPPrefix(forwardedHost) + addrTrim := tlsutil.TrimHTTPPrefix(addr) metrics.RequestForwarded.WithLabelValues(forwardedHostTrim, addrTrim).Set(1) } connectionCtxs.Store(addr, &tsoConnectionContext{cctx, cancel, addr, stream}) diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index c696dc26b36..1123e59dbdd 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -32,6 +32,7 @@ import ( "github.com/tikv/pd/client/metrics" "github.com/tikv/pd/client/opt" "github.com/tikv/pd/client/retry" + sd "github.com/tikv/pd/client/servicediscovery" "github.com/tikv/pd/client/utils/timerutil" "github.com/tikv/pd/client/utils/tsoutil" "go.uber.org/zap" @@ -70,7 +71,7 @@ type tsoInfo struct { type tsoServiceProvider interface { getOption() *opt.Option - getServiceDiscovery() ServiceDiscovery + getServiceDiscovery() sd.ServiceDiscovery updateConnectionCtxs(ctx context.Context, connectionCtxs *sync.Map) bool } @@ -179,7 +180,7 @@ func (td *tsoDispatcher) revokePendingRequests(err error) { func (td *tsoDispatcher) close() { td.cancel() - tsoErr := errors.WithStack(errClosing) + tsoErr := errors.WithStack(errs.ErrClosing) td.revokePendingRequests(tsoErr) } @@ -210,7 +211,7 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { // If you encounter this failure, please check the stack in the logs to see if it's a panic. log.Fatal("batched tso requests not cleared when exiting the tso dispatcher loop", zap.Any("panic", recover())) } - tsoErr := errors.WithStack(errClosing) + tsoErr := errors.WithStack(errs.ErrClosing) td.revokePendingRequests(tsoErr) wg.Done() }() @@ -233,7 +234,7 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { <-batchingTimer.C defer batchingTimer.Stop() - bo := retry.InitialBackoffer(updateMemberBackOffBaseTime, updateMemberTimeout, updateMemberBackOffBaseTime) + bo := retry.InitialBackoffer(sd.UpdateMemberBackOffBaseTime, sd.UpdateMemberTimeout, sd.UpdateMemberBackOffBaseTime) tsoBatchLoop: for { select { @@ -494,7 +495,7 @@ func (td *tsoDispatcher) connectionCtxsUpdater() { if enableTSOFollowerProxy && updateTicker.C == nil { // Because the TSO Follower Proxy is enabled, // the periodic check needs to be performed. - setNewUpdateTicker(time.NewTicker(memberUpdateInterval)) + setNewUpdateTicker(time.NewTicker(sd.MemberUpdateInterval)) } else if !enableTSOFollowerProxy && updateTicker.C != nil { // Because the TSO Follower Proxy is disabled, // the periodic check needs to be turned off. diff --git a/client/tso_dispatcher_test.go b/client/tso_dispatcher_test.go index 84bc6a4dc99..6cb963df3df 100644 --- a/client/tso_dispatcher_test.go +++ b/client/tso_dispatcher_test.go @@ -27,6 +27,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/tikv/pd/client/opt" + sd "github.com/tikv/pd/client/servicediscovery" "go.uber.org/zap/zapcore" ) @@ -47,8 +48,8 @@ func (m *mockTSOServiceProvider) getOption() *opt.Option { return m.option } -func (*mockTSOServiceProvider) getServiceDiscovery() ServiceDiscovery { - return NewMockPDServiceDiscovery([]string{mockStreamURL}, nil) +func (*mockTSOServiceProvider) getServiceDiscovery() sd.ServiceDiscovery { + return sd.NewMockPDServiceDiscovery([]string{mockStreamURL}, nil) } func (m *mockTSOServiceProvider) updateConnectionCtxs(ctx context.Context, connectionCtxs *sync.Map) bool { diff --git a/client/tso_stream.go b/client/tso_stream.go index 51ae5696dc4..ce3c513ac46 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -29,6 +29,7 @@ import ( "github.com/pingcap/kvproto/pkg/tsopb" "github.com/pingcap/log" "github.com/prometheus/client_golang/prometheus" + "github.com/tikv/pd/client/constants" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/metrics" "go.uber.org/zap" @@ -144,7 +145,7 @@ func (s pdTSOStreamAdapter) Recv() (tsoRequestResult, error) { physical: resp.GetTimestamp().GetPhysical(), logical: resp.GetTimestamp().GetLogical(), count: resp.GetCount(), - respKeyspaceGroupID: defaultKeySpaceGroupID, + respKeyspaceGroupID: constants.DefaultKeyspaceGroupID, }, nil } @@ -432,7 +433,7 @@ recvLoop: updateEstimatedLatency(currentReq.startTime, latency) if res.count != uint32(currentReq.count) { - finishWithErr = errors.WithStack(errTSOLength) + finishWithErr = errors.WithStack(errs.ErrTSOLength) break recvLoop } diff --git a/client/utils/tlsutil/url.go b/client/utils/tlsutil/url.go new file mode 100644 index 00000000000..ccc312d195b --- /dev/null +++ b/client/utils/tlsutil/url.go @@ -0,0 +1,88 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tlsutil + +import ( + "crypto/tls" + "net/url" + "strings" + + "github.com/pingcap/log" + "go.uber.org/zap" +) + +const ( + httpScheme = "http" + httpsScheme = "https" + httpSchemePrefix = "http://" + httpsSchemePrefix = "https://" +) + +// AddrsToURLs converts a list of addresses to a list of URLs. +func AddrsToURLs(addrs []string, tlsCfg *tls.Config) []string { + // Add default schema "http://" to addrs. + urls := make([]string, 0, len(addrs)) + for _, addr := range addrs { + urls = append(urls, ModifyURLScheme(addr, tlsCfg)) + } + return urls +} + +// ModifyURLScheme modifies the scheme of the URL based on the TLS config. +func ModifyURLScheme(url string, tlsCfg *tls.Config) string { + if tlsCfg == nil { + if strings.HasPrefix(url, httpsSchemePrefix) { + url = httpSchemePrefix + strings.TrimPrefix(url, httpsSchemePrefix) + } else if !strings.HasPrefix(url, httpSchemePrefix) { + url = httpSchemePrefix + url + } + } else { + if strings.HasPrefix(url, httpSchemePrefix) { + url = httpsSchemePrefix + strings.TrimPrefix(url, httpSchemePrefix) + } else if !strings.HasPrefix(url, httpsSchemePrefix) { + url = httpsSchemePrefix + url + } + } + return url +} + +// PickMatchedURL picks the matched URL based on the TLS config. +// Note: please make sure the URLs are valid. +func PickMatchedURL(urls []string, tlsCfg *tls.Config) string { + for _, uStr := range urls { + u, err := url.Parse(uStr) + if err != nil { + continue + } + if tlsCfg != nil && u.Scheme == httpsScheme { + return uStr + } + if tlsCfg == nil && u.Scheme == httpScheme { + return uStr + } + } + ret := ModifyURLScheme(urls[0], tlsCfg) + log.Warn("[pd] no matched url found", zap.Strings("urls", urls), + zap.Bool("tls-enabled", tlsCfg != nil), + zap.String("attempted-url", ret)) + return ret +} + +// TrimHTTPPrefix trims the HTTP/HTTPS prefix from the string. +func TrimHTTPPrefix(str string) string { + str = strings.TrimPrefix(str, httpSchemePrefix) + str = strings.TrimPrefix(str, httpsSchemePrefix) + return str +} diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go index bc69367c72a..79f981f3bb3 100644 --- a/tests/integrations/client/client_test.go +++ b/tests/integrations/client/client_test.go @@ -41,6 +41,7 @@ import ( "github.com/tikv/pd/client/caller" "github.com/tikv/pd/client/opt" "github.com/tikv/pd/client/retry" + sd "github.com/tikv/pd/client/servicediscovery" "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/mcs/utils/constant" "github.com/tikv/pd/pkg/mock/mockid" @@ -83,7 +84,7 @@ func TestClientLeaderChange(t *testing.T) { } cli := setupCli(ctx, re, endpointsWithWrongURL) defer cli.Close() - innerCli, ok := cli.(interface{ GetServiceDiscovery() pd.ServiceDiscovery }) + innerCli, ok := cli.(interface{ GetServiceDiscovery() sd.ServiceDiscovery }) re.True(ok) var ts1, ts2 uint64 @@ -324,7 +325,7 @@ func TestTSOFollowerProxy(t *testing.T) { func TestTSOFollowerProxyWithTSOService(t *testing.T) { re := require.New(t) - re.NoError(failpoint.Enable("github.com/tikv/pd/client/fastUpdateServiceMode", `return(true)`)) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/servicediscovery/fastUpdateServiceMode", `return(true)`)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cluster, err := tests.NewTestAPICluster(ctx, 1) @@ -346,7 +347,7 @@ func TestTSOFollowerProxyWithTSOService(t *testing.T) { // TSO service does not support the follower proxy, so enabling it should fail. err = cli.UpdateOption(opt.EnableTSOFollowerProxy, true) re.Error(err) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/fastUpdateServiceMode")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/servicediscovery/fastUpdateServiceMode")) } // TestUnavailableTimeAfterLeaderIsReady is used to test https://github.com/tikv/pd/issues/5207 @@ -451,7 +452,7 @@ func TestFollowerForwardAndHandleTestSuite(t *testing.T) { func (suite *followerForwardAndHandleTestSuite) SetupSuite() { re := suite.Require() suite.ctx, suite.clean = context.WithCancel(context.Background()) - pd.MemberHealthCheckInterval = 100 * time.Millisecond + sd.MemberHealthCheckInterval = 100 * time.Millisecond cluster, err := tests.NewTestCluster(suite.ctx, 3) re.NoError(err) suite.cluster = cluster @@ -497,13 +498,13 @@ func (suite *followerForwardAndHandleTestSuite) TestGetRegionByFollowerForwardin cli := setupCli(ctx, re, suite.endpoints, opt.WithForwardingOption(true)) defer cli.Close() - re.NoError(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork1", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/servicediscovery/unreachableNetwork1", "return(true)")) time.Sleep(200 * time.Millisecond) r, err := cli.GetRegion(context.Background(), []byte("a")) re.NoError(err) re.NotNil(r) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork1")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/servicediscovery/unreachableNetwork1")) time.Sleep(200 * time.Millisecond) r, err = cli.GetRegion(context.Background(), []byte("a")) re.NoError(err) @@ -719,7 +720,7 @@ func (suite *followerForwardAndHandleTestSuite) TestGetRegionFromFollower() { // because we can't check whether this request is processed by followers from response, // we can disable forward and make network problem for leader. - re.NoError(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork1", fmt.Sprintf("return(\"%s\")", leader.GetAddr()))) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/servicediscovery/unreachableNetwork1", fmt.Sprintf("return(\"%s\")", leader.GetAddr()))) time.Sleep(150 * time.Millisecond) cnt = 0 for range 100 { @@ -730,11 +731,11 @@ func (suite *followerForwardAndHandleTestSuite) TestGetRegionFromFollower() { re.Equal(resp.Meta.Id, suite.regionID) } re.Equal(100, cnt) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork1")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/servicediscovery/unreachableNetwork1")) // make network problem for follower. follower := cluster.GetServer(cluster.GetFollower()) - re.NoError(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork1", fmt.Sprintf("return(\"%s\")", follower.GetAddr()))) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/servicediscovery/unreachableNetwork1", fmt.Sprintf("return(\"%s\")", follower.GetAddr()))) time.Sleep(100 * time.Millisecond) cnt = 0 for range 100 { @@ -745,7 +746,7 @@ func (suite *followerForwardAndHandleTestSuite) TestGetRegionFromFollower() { re.Equal(resp.Meta.Id, suite.regionID) } re.Equal(100, cnt) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork1")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/servicediscovery/unreachableNetwork1")) // follower client failed will retry by leader service client. re.NoError(failpoint.Enable("github.com/tikv/pd/server/followerHandleError", "return(true)")) @@ -761,8 +762,8 @@ func (suite *followerForwardAndHandleTestSuite) TestGetRegionFromFollower() { re.NoError(failpoint.Disable("github.com/tikv/pd/server/followerHandleError")) // test after being healthy - re.NoError(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork1", fmt.Sprintf("return(\"%s\")", leader.GetAddr()))) - re.NoError(failpoint.Enable("github.com/tikv/pd/client/fastCheckAvailable", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/servicediscovery/unreachableNetwork1", fmt.Sprintf("return(\"%s\")", leader.GetAddr()))) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/servicediscovery/fastCheckAvailable", "return(true)")) time.Sleep(100 * time.Millisecond) cnt = 0 for range 100 { @@ -773,8 +774,8 @@ func (suite *followerForwardAndHandleTestSuite) TestGetRegionFromFollower() { re.Equal(resp.Meta.Id, suite.regionID) } re.Equal(100, cnt) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork1")) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/fastCheckAvailable")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/servicediscovery/unreachableNetwork1")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/servicediscovery/fastCheckAvailable")) } func (suite *followerForwardAndHandleTestSuite) TestGetTSFuture() { @@ -857,7 +858,7 @@ func setupCli(ctx context.Context, re *require.Assertions, endpoints []string, o return cli } -func waitLeader(re *require.Assertions, cli pd.ServiceDiscovery, leader *tests.TestServer) { +func waitLeader(re *require.Assertions, cli sd.ServiceDiscovery, leader *tests.TestServer) { testutil.Eventually(re, func() bool { cli.ScheduleCheckMemberChanged() return cli.GetServingURL() == leader.GetConfig().ClientUrls && leader.GetAddr() == cli.GetServingURL() @@ -1778,7 +1779,7 @@ func (suite *clientTestSuite) TestMemberUpdateBackOff() { endpoints := runServer(re, cluster) cli := setupCli(ctx, re, endpoints) defer cli.Close() - innerCli, ok := cli.(interface{ GetServiceDiscovery() pd.ServiceDiscovery }) + innerCli, ok := cli.(interface{ GetServiceDiscovery() sd.ServiceDiscovery }) re.True(ok) leader := cluster.GetLeader() @@ -1801,7 +1802,7 @@ func (suite *clientTestSuite) TestMemberUpdateBackOff() { re.NoError(failpoint.Disable("github.com/tikv/pd/client/retry/backOffExecute")) } -func waitLeaderChange(re *require.Assertions, cluster *tests.TestCluster, old string, cli pd.ServiceDiscovery) string { +func waitLeaderChange(re *require.Assertions, cluster *tests.TestCluster, old string, cli sd.ServiceDiscovery) string { var leader string testutil.Eventually(re, func() bool { cli.ScheduleCheckMemberChanged() diff --git a/tests/integrations/mcs/resourcemanager/resource_manager_test.go b/tests/integrations/mcs/resourcemanager/resource_manager_test.go index 5688ea8a8ac..15bec8ea8fd 100644 --- a/tests/integrations/mcs/resourcemanager/resource_manager_test.go +++ b/tests/integrations/mcs/resourcemanager/resource_manager_test.go @@ -34,6 +34,7 @@ import ( pd "github.com/tikv/pd/client" "github.com/tikv/pd/client/caller" "github.com/tikv/pd/client/resource_group/controller" + sd "github.com/tikv/pd/client/servicediscovery" "github.com/tikv/pd/pkg/mcs/resourcemanager/server" "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/pkg/utils/typeutil" @@ -142,7 +143,7 @@ func (suite *resourceManagerClientTestSuite) SetupSuite() { } func waitLeader(re *require.Assertions, cli pd.Client, leaderAddr string) { - innerCli, ok := cli.(interface{ GetServiceDiscovery() pd.ServiceDiscovery }) + innerCli, ok := cli.(interface{ GetServiceDiscovery() sd.ServiceDiscovery }) re.True(ok) re.NotNil(innerCli) testutil.Eventually(re, func() bool { diff --git a/tests/integrations/mcs/tso/server_test.go b/tests/integrations/mcs/tso/server_test.go index 8624454aec3..a78d61bf429 100644 --- a/tests/integrations/mcs/tso/server_test.go +++ b/tests/integrations/mcs/tso/server_test.go @@ -245,7 +245,7 @@ func NewAPIServerForward(re *require.Assertions) APIServerForward { re.NoError(suite.pdLeader.BootstrapCluster()) suite.addRegions() - re.NoError(failpoint.Enable("github.com/tikv/pd/client/usePDServiceMode", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/servicediscovery/usePDServiceMode", "return(true)")) suite.pdClient, err = pd.NewClientWithContext(context.Background(), caller.TestComponent, []string{suite.backendEndpoints}, pd.SecurityOption{}, opt.WithMaxErrorRetry(1)) @@ -267,7 +267,7 @@ func (suite *APIServerForward) ShutDown() { } suite.cluster.Destroy() suite.cancel() - re.NoError(failpoint.Disable("github.com/tikv/pd/client/usePDServiceMode")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/servicediscovery/usePDServiceMode")) } func TestForwardTSORelated(t *testing.T) { @@ -593,7 +593,7 @@ func (suite *CommonTestSuite) TestBootstrapDefaultKeyspaceGroup() { // If `EnableTSODynamicSwitching` is disabled, the PD should not provide TSO service after the TSO server is stopped. func TestTSOServiceSwitch(t *testing.T) { re := require.New(t) - re.NoError(failpoint.Enable("github.com/tikv/pd/client/fastUpdateServiceMode", `return(true)`)) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/servicediscovery/fastUpdateServiceMode", `return(true)`)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -666,7 +666,7 @@ func TestTSOServiceSwitch(t *testing.T) { // Verify PD is now providing TSO service and timestamps are monotonically increasing re.NoError(checkTSOMonotonic(ctx, pdClient, &globalLastTS, 10)) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/fastUpdateServiceMode")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/servicediscovery/fastUpdateServiceMode")) } func checkTSOMonotonic(ctx context.Context, pdClient pd.Client, globalLastTS *uint64, count int) error { diff --git a/tests/integrations/tso/client_test.go b/tests/integrations/tso/client_test.go index 121a61b1986..422d578326a 100644 --- a/tests/integrations/tso/client_test.go +++ b/tests/integrations/tso/client_test.go @@ -31,6 +31,7 @@ import ( pd "github.com/tikv/pd/client" "github.com/tikv/pd/client/caller" "github.com/tikv/pd/client/opt" + sd "github.com/tikv/pd/client/servicediscovery" "github.com/tikv/pd/client/utils/testutil" bs "github.com/tikv/pd/pkg/basicserver" "github.com/tikv/pd/pkg/mcs/utils/constant" @@ -154,7 +155,7 @@ func (suite *tsoClientTestSuite) SetupTest() { caller.TestComponent, suite.getBackendEndpoints(), pd.SecurityOption{}, opt.WithForwardingOption(true)) re.NoError(err) - innerClient, ok := client.(interface{ GetServiceDiscovery() pd.ServiceDiscovery }) + innerClient, ok := client.(interface{ GetServiceDiscovery() sd.ServiceDiscovery }) re.True(ok) re.Equal(constant.NullKeyspaceID, innerClient.GetServiceDiscovery().GetKeyspaceID()) re.Equal(constant.DefaultKeyspaceGroupID, innerClient.GetServiceDiscovery().GetKeyspaceGroupID()) @@ -266,11 +267,11 @@ func (suite *tsoClientTestSuite) TestDiscoverTSOServiceWithLegacyPath() { failpointValue := fmt.Sprintf(`return(%d)`, keyspaceID) // Simulate the case that the server has lower version than the client and returns no tso addrs // in the GetClusterInfo RPC. - re.NoError(failpoint.Enable("github.com/tikv/pd/client/serverReturnsNoTSOAddrs", `return(true)`)) - re.NoError(failpoint.Enable("github.com/tikv/pd/client/unexpectedCallOfFindGroupByKeyspaceID", failpointValue)) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/servicediscovery/serverReturnsNoTSOAddrs", `return(true)`)) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/servicediscovery/unexpectedCallOfFindGroupByKeyspaceID", failpointValue)) defer func() { - re.NoError(failpoint.Disable("github.com/tikv/pd/client/serverReturnsNoTSOAddrs")) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/unexpectedCallOfFindGroupByKeyspaceID")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/servicediscovery/serverReturnsNoTSOAddrs")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/servicediscovery/unexpectedCallOfFindGroupByKeyspaceID")) }() ctx, cancel := context.WithCancel(suite.ctx) @@ -320,14 +321,14 @@ func (suite *tsoClientTestSuite) TestGetMinTS() { } wg.Wait() - re.NoError(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork1", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/servicediscovery/unreachableNetwork1", "return(true)")) time.Sleep(time.Second) testutil.Eventually(re, func() bool { var err error _, _, err = suite.clients[0].GetMinTS(suite.ctx) return err == nil }) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork1")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/servicediscovery/unreachableNetwork1")) } // More details can be found in this issue: https://github.com/tikv/pd/issues/4884 @@ -487,10 +488,10 @@ func TestMixedTSODeployment(t *testing.T) { re := require.New(t) re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/tso/fastUpdatePhysicalInterval", "return(true)")) - re.NoError(failpoint.Enable("github.com/tikv/pd/client/skipUpdateServiceMode", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/servicediscovery/skipUpdateServiceMode", "return(true)")) defer func() { re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/tso/fastUpdatePhysicalInterval")) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/skipUpdateServiceMode")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/servicediscovery/skipUpdateServiceMode")) }() ctx, cancel := context.WithCancel(context.Background()) @@ -550,7 +551,7 @@ func TestUpgradingAPIandTSOClusters(t *testing.T) { backendEndpoints := pdLeader.GetAddr() // Create a pd client in PD mode to let the API leader to forward requests to the TSO cluster. - re.NoError(failpoint.Enable("github.com/tikv/pd/client/usePDServiceMode", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/servicediscovery/usePDServiceMode", "return(true)")) pdClient, err := pd.NewClientWithContext(context.Background(), caller.TestComponent, []string{backendEndpoints}, pd.SecurityOption{}, opt.WithMaxErrorRetry(1)) @@ -579,7 +580,7 @@ func TestUpgradingAPIandTSOClusters(t *testing.T) { tsoCluster.Destroy() apiCluster.Destroy() cancel() - re.NoError(failpoint.Disable("github.com/tikv/pd/client/usePDServiceMode")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/servicediscovery/usePDServiceMode")) } func checkTSO( diff --git a/tools/pd-simulator/simulator/client.go b/tools/pd-simulator/simulator/client.go index 39c2633cec5..4de2ea52f88 100644 --- a/tools/pd-simulator/simulator/client.go +++ b/tools/pd-simulator/simulator/client.go @@ -26,8 +26,8 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" - pd "github.com/tikv/pd/client" pdHttp "github.com/tikv/pd/client/http" + sd "github.com/tikv/pd/client/servicediscovery" "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/utils/typeutil" sc "github.com/tikv/pd/tools/pd-simulator/simulator/config" @@ -64,7 +64,7 @@ var ( // PDHTTPClient is a client for PD HTTP API. PDHTTPClient pdHttp.Client // SD is a service discovery for PD. - SD pd.ServiceDiscovery + SD sd.ServiceDiscovery clusterID atomic.Uint64 ) diff --git a/tools/pd-simulator/simulator/drive.go b/tools/pd-simulator/simulator/drive.go index 22d4175ecc6..c8c325cfca6 100644 --- a/tools/pd-simulator/simulator/drive.go +++ b/tools/pd-simulator/simulator/drive.go @@ -30,8 +30,8 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/prometheus/client_golang/prometheus/promhttp" - pd "github.com/tikv/pd/client" pdHttp "github.com/tikv/pd/client/http" + sd "github.com/tikv/pd/client/servicediscovery" "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/utils/typeutil" "github.com/tikv/pd/tools/pd-simulator/simulator/cases" @@ -163,7 +163,7 @@ func (d *Driver) allocID() error { func (d *Driver) updateNodesClient() error { urls := strings.Split(d.pdAddr, ",") ctx, cancel := context.WithCancel(context.Background()) - SD = pd.NewDefaultPDServiceDiscovery(ctx, cancel, urls, nil) + SD = sd.NewDefaultPDServiceDiscovery(ctx, cancel, urls, nil) if err := SD.Init(); err != nil { return err }