diff --git a/client/client.go b/client/client.go index 9c4f9db22261..800f97ed82c3 100644 --- a/client/client.go +++ b/client/client.go @@ -141,6 +141,9 @@ type Client interface { // SetExternalTimestamp sets external timestamp SetExternalTimestamp(ctx context.Context, timestamp uint64) error + // GetServiceDiscovery returns ServiceDiscovery + GetServiceDiscovery() ServiceDiscovery + // TSOClient is the TSO client. TSOClient // MetaStorageClient is the meta storage client. diff --git a/client/errs/errno.go b/client/errs/errno.go index 0f93ebf14723..c095bbe4b4af 100644 --- a/client/errs/errno.go +++ b/client/errs/errno.go @@ -50,6 +50,7 @@ var ( ErrClientGetMember = errors.Normalize("get member failed", errors.RFCCodeText("PD:client:ErrClientGetMember")) ErrClientGetClusterInfo = errors.Normalize("get cluster info failed", errors.RFCCodeText("PD:client:ErrClientGetClusterInfo")) ErrClientUpdateMember = errors.Normalize("update member failed, %v", errors.RFCCodeText("PD:client:ErrUpdateMember")) + ErrClientNoAvailableMember = errors.Normalize("no available member", errors.RFCCodeText("PD:client:ErrClientNoAvailableMember")) ErrClientProtoUnmarshal = errors.Normalize("failed to unmarshal proto", errors.RFCCodeText("PD:proto:ErrClientProtoUnmarshal")) ErrClientGetMultiResponse = errors.Normalize("get invalid value response %v, must only one", errors.RFCCodeText("PD:client:ErrClientGetMultiResponse")) ErrClientGetServingEndpoint = errors.Normalize("get serving endpoint failed", errors.RFCCodeText("PD:client:ErrClientGetServingEndpoint")) diff --git a/client/go.mod b/client/go.mod index fcb8fd9bfe52..eb49eb674d88 100644 --- a/client/go.mod +++ b/client/go.mod @@ -13,7 +13,6 @@ require ( github.com/pingcap/kvproto v0.0.0-20231222062942-c0c73f41d0b2 github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 github.com/prometheus/client_golang v1.11.1 - github.com/prometheus/client_model v0.2.0 github.com/stretchr/testify v1.8.2 go.uber.org/atomic v1.10.0 go.uber.org/goleak v1.1.11 @@ -32,6 +31,7 @@ require ( github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/common v0.26.0 // indirect github.com/prometheus/procfs v0.6.0 // indirect go.uber.org/multierr v1.11.0 // indirect diff --git a/client/http/client.go b/client/http/client.go index ac14de50b332..f16a3abed89e 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -19,17 +19,16 @@ import ( "context" "crypto/tls" "encoding/json" - "fmt" "io" "net/http" "os" - "strings" - "sync" "time" "github.com/pingcap/errors" "github.com/pingcap/log" "github.com/prometheus/client_golang/prometheus" + pd "github.com/tikv/pd/client" + "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/retry" "go.uber.org/zap" ) @@ -58,9 +57,7 @@ type clientInner struct { ctx context.Context cancel context.CancelFunc - sync.RWMutex - pdAddrs []string - leaderAddrIdx int + sd pd.ServiceDiscovery // source is used to mark the source of the client creation, // it will also be used in the caller ID of the inner client. @@ -72,12 +69,11 @@ type clientInner struct { executionDuration *prometheus.HistogramVec } -func newClientInner(source string) *clientInner { - ctx, cancel := context.WithCancel(context.Background()) - return &clientInner{ctx: ctx, cancel: cancel, leaderAddrIdx: -1, source: source} +func newClientInner(ctx context.Context, cancel context.CancelFunc, source string) *clientInner { + return &clientInner{ctx: ctx, cancel: cancel, source: source} } -func (ci *clientInner) init() { +func (ci *clientInner) init(sd pd.ServiceDiscovery) { // Init the HTTP client if it's not configured. if ci.cli == nil { ci.cli = &http.Client{Timeout: defaultTimeout} @@ -87,8 +83,7 @@ func (ci *clientInner) init() { ci.cli.Transport = transport } } - // Start the members info updater daemon. - go ci.membersInfoUpdater(ci.ctx) + ci.sd = sd } func (ci *clientInner) close() { @@ -98,33 +93,6 @@ func (ci *clientInner) close() { } } -// getPDAddrs returns the current PD addresses and the index of the leader address. -func (ci *clientInner) getPDAddrs() ([]string, int) { - ci.RLock() - defer ci.RUnlock() - return ci.pdAddrs, ci.leaderAddrIdx -} - -func (ci *clientInner) setPDAddrs(pdAddrs []string, leaderAddrIdx int) { - ci.Lock() - defer ci.Unlock() - // Normalize the addresses with correct scheme prefix. - var scheme string - if ci.tlsConf == nil { - scheme = httpScheme - } else { - scheme = httpsScheme - } - for i, addr := range pdAddrs { - if strings.HasPrefix(addr, httpScheme) { - continue - } - pdAddrs[i] = fmt.Sprintf("%s://%s", scheme, addr) - } - ci.pdAddrs = pdAddrs - ci.leaderAddrIdx = leaderAddrIdx -} - func (ci *clientInner) reqCounter(name, status string) { if ci.requestCounter == nil { return @@ -151,32 +119,19 @@ func (ci *clientInner) requestWithRetry( err error ) execFunc := func() error { - var ( - addr string - pdAddrs, leaderAddrIdx = ci.getPDAddrs() - ) - // Try to send the request to the PD leader first. - if leaderAddrIdx != -1 { - addr = pdAddrs[leaderAddrIdx] - statusCode, err = ci.doRequest(ctx, addr, reqInfo, headerOpts...) - if err == nil || noNeedRetry(statusCode) { - return err - } - log.Debug("[pd] request leader addr failed", - zap.String("source", ci.source), zap.Int("leader-idx", leaderAddrIdx), zap.String("addr", addr), zap.Error(err)) + // It will try to send the request to the PD leader first and then try to send the request to the other PD followers. + clients := ci.sd.GetAllServiceClients() + if len(clients) == 0 { + return errs.ErrClientNoAvailableMember } - // Try to send the request to the other PD followers. - for idx := 0; idx < len(pdAddrs); idx++ { - if idx == leaderAddrIdx { - continue - } - addr = ci.pdAddrs[idx] + for _, cli := range clients { + addr := cli.GetHTTPAddress() statusCode, err = ci.doRequest(ctx, addr, reqInfo, headerOpts...) if err == nil || noNeedRetry(statusCode) { - break + return err } - log.Debug("[pd] request follower addr failed", - zap.String("source", ci.source), zap.Int("idx", idx), zap.String("addr", addr), zap.Error(err)) + log.Debug("[pd] request addr failed", + zap.String("source", ci.source), zap.Bool("is-leader", cli.IsConnectedToLeader()), zap.String("addr", addr), zap.Error(err)) } return err } @@ -278,73 +233,6 @@ func (ci *clientInner) doRequest( return resp.StatusCode, nil } -func (ci *clientInner) membersInfoUpdater(ctx context.Context) { - ci.updateMembersInfo(ctx) - log.Info("[pd] http client member info updater started", zap.String("source", ci.source)) - ticker := time.NewTicker(defaultMembersInfoUpdateInterval) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - log.Info("[pd] http client member info updater stopped", zap.String("source", ci.source)) - return - case <-ticker.C: - ci.updateMembersInfo(ctx) - } - } -} - -func (ci *clientInner) updateMembersInfo(ctx context.Context) { - var membersInfo MembersInfo - err := ci.requestWithRetry(ctx, newRequestInfo(). - WithCallerID(fmt.Sprintf("%s-%s", ci.source, defaultInnerCallerID)). - WithName(getMembersName). - WithURI(membersPrefix). - WithMethod(http.MethodGet). - WithResp(&membersInfo)) - if err != nil { - log.Error("[pd] http client get members info failed", zap.String("source", ci.source), zap.Error(err)) - return - } - if len(membersInfo.Members) == 0 { - log.Error("[pd] http client get empty members info", zap.String("source", ci.source)) - return - } - var ( - newPDAddrs []string - newLeaderAddrIdx int = -1 - ) - for _, member := range membersInfo.Members { - if membersInfo.Leader != nil && member.GetMemberId() == membersInfo.Leader.GetMemberId() { - newLeaderAddrIdx = len(newPDAddrs) - } - newPDAddrs = append(newPDAddrs, member.GetClientUrls()...) - } - // Prevent setting empty addresses. - if len(newPDAddrs) == 0 { - log.Error("[pd] http client get empty member addresses", zap.String("source", ci.source)) - return - } - oldPDAddrs, oldLeaderAddrIdx := ci.getPDAddrs() - ci.setPDAddrs(newPDAddrs, newLeaderAddrIdx) - // Log the member info change if it happens. - var oldPDLeaderAddr, newPDLeaderAddr string - if oldLeaderAddrIdx != -1 { - oldPDLeaderAddr = oldPDAddrs[oldLeaderAddrIdx] - } - if newLeaderAddrIdx != -1 { - newPDLeaderAddr = newPDAddrs[newLeaderAddrIdx] - } - oldMemberNum, newMemberNum := len(oldPDAddrs), len(newPDAddrs) - if oldPDLeaderAddr != newPDLeaderAddr || oldMemberNum != newMemberNum { - log.Info("[pd] http client members info changed", zap.String("source", ci.source), - zap.Int("old-member-num", oldMemberNum), zap.Int("new-member-num", newMemberNum), - zap.Strings("old-addrs", oldPDAddrs), zap.Strings("new-addrs", newPDAddrs), - zap.Int("old-leader-addr-idx", oldLeaderAddrIdx), zap.Int("new-leader-addr-idx", newLeaderAddrIdx), - zap.String("old-leader-addr", oldPDLeaderAddr), zap.String("new-leader-addr", newPDLeaderAddr)) - } -} - type client struct { inner *clientInner @@ -397,19 +285,36 @@ func WithLoggerRedirection(logLevel, fileName string) ClientOption { return func(c *client) {} } +// NewClientWithServiceDiscovery creates a PD HTTP client with the given PD service discovery. +func NewClientWithServiceDiscovery( + source string, + sd pd.ServiceDiscovery, + opts ...ClientOption, +) Client { + ctx, cancel := context.WithCancel(context.Background()) + c := &client{inner: newClientInner(ctx, cancel, source), callerID: defaultCallerID} + // Apply the options first. + for _, opt := range opts { + opt(c) + } + c.inner.init(sd) + return c +} + // NewClient creates a PD HTTP client with the given PD addresses and TLS config. func NewClient( source string, pdAddrs []string, opts ...ClientOption, ) Client { - c := &client{inner: newClientInner(source), callerID: defaultCallerID} + ctx, cancel := context.WithCancel(context.Background()) + c := &client{inner: newClientInner(ctx, cancel, source), callerID: defaultCallerID} // Apply the options first. for _, opt := range opts { opt(c) } - c.inner.setPDAddrs(pdAddrs, -1) - c.inner.init() + sd := pd.NewDefaultPDServiceDiscovery(ctx, cancel, pdAddrs, c.inner.tlsConf) + c.inner.init(sd) return c } @@ -466,16 +371,17 @@ func (c *client) request(ctx context.Context, reqInfo *requestInfo, headerOpts . headerOpts...) } -// UpdateMembersInfo updates the members info of the PD cluster in the inner client. -// Exported for testing. -func (c *client) UpdateMembersInfo() { - c.inner.updateMembersInfo(c.inner.ctx) +// requestChecker is used to check the HTTP request sent by the client. +type requestChecker func(req *http.Request) error + +// RoundTrip implements the `http.RoundTripper` interface. +func (rc requestChecker) RoundTrip(req *http.Request) (resp *http.Response, err error) { + return &http.Response{StatusCode: http.StatusOK}, rc(req) } -// setLeaderAddrIdx sets the index of the leader address in the inner client. -// only used for testing. -func (c *client) setLeaderAddrIdx(idx int) { - c.inner.Lock() - defer c.inner.Unlock() - c.inner.leaderAddrIdx = idx +// NewHTTPClientWithRequestChecker returns a http client with checker. +func NewHTTPClientWithRequestChecker(checker requestChecker) *http.Client { + return &http.Client{ + Transport: checker, + } } diff --git a/client/http/client_test.go b/client/http/client_test.go index 558a9b12a982..b9fcb5a75e0f 100644 --- a/client/http/client_test.go +++ b/client/http/client_test.go @@ -16,55 +16,19 @@ package http import ( "context" - "crypto/tls" "net/http" "strings" "testing" "time" - "github.com/pingcap/errors" - "github.com/prometheus/client_golang/prometheus" - dto "github.com/prometheus/client_model/go" "github.com/stretchr/testify/require" "github.com/tikv/pd/client/retry" "go.uber.org/atomic" ) -func TestPDAddrNormalization(t *testing.T) { - re := require.New(t) - c := NewClient("test-http-pd-addr", []string{"127.0.0.1"}) - pdAddrs, leaderAddrIdx := c.(*client).inner.getPDAddrs() - re.Len(pdAddrs, 1) - re.Equal(-1, leaderAddrIdx) - re.Contains(pdAddrs[0], httpScheme) - c.Close() - c = NewClient("test-https-pd-addr", []string{"127.0.0.1"}, WithTLSConfig(&tls.Config{})) - pdAddrs, leaderAddrIdx = c.(*client).inner.getPDAddrs() - re.Len(pdAddrs, 1) - re.Equal(-1, leaderAddrIdx) - re.Contains(pdAddrs[0], httpsScheme) - c.Close() -} - -// requestChecker is used to check the HTTP request sent by the client. -type requestChecker struct { - checker func(req *http.Request) error -} - -// RoundTrip implements the `http.RoundTripper` interface. -func (rc *requestChecker) RoundTrip(req *http.Request) (resp *http.Response, err error) { - return &http.Response{StatusCode: http.StatusOK}, rc.checker(req) -} - -func newHTTPClientWithRequestChecker(checker func(req *http.Request) error) *http.Client { - return &http.Client{ - Transport: &requestChecker{checker: checker}, - } -} - func TestPDAllowFollowerHandleHeader(t *testing.T) { re := require.New(t) - httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error { + httpClient := NewHTTPClientWithRequestChecker(func(req *http.Request) error { var expectedVal string if req.URL.Path == HotHistory { expectedVal = "true" @@ -85,7 +49,7 @@ func TestPDAllowFollowerHandleHeader(t *testing.T) { func TestCallerID(t *testing.T) { re := require.New(t) expectedVal := atomic.NewString(defaultCallerID) - httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error { + httpClient := NewHTTPClientWithRequestChecker(func(req *http.Request) error { val := req.Header.Get(xCallerIDKey) // Exclude the request sent by the inner client. if !strings.Contains(val, defaultInnerCallerID) && val != expectedVal.Load() { @@ -101,77 +65,6 @@ func TestCallerID(t *testing.T) { c.Close() } -func TestRedirectWithMetrics(t *testing.T) { - re := require.New(t) - - pdAddrs := []string{"127.0.0.1", "172.0.0.1", "192.0.0.1"} - metricCnt := prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "check", - }, []string{"name", ""}) - - // 1. Test all followers failed, need to send all followers. - httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error { - if req.URL.Path == Schedulers { - return errors.New("mock error") - } - return nil - }) - c := NewClient("test-http-pd-redirect", pdAddrs, WithHTTPClient(httpClient), WithMetrics(metricCnt, nil)) - pdAddrs, leaderAddrIdx := c.(*client).inner.getPDAddrs() - re.Equal(-1, leaderAddrIdx) - c.CreateScheduler(context.Background(), "test", 0) - var out dto.Metric - failureCnt, err := c.(*client).inner.requestCounter.GetMetricWithLabelValues([]string{createSchedulerName, networkErrorStatus}...) - re.NoError(err) - failureCnt.Write(&out) - re.Equal(float64(3), out.Counter.GetValue()) - c.Close() - - // 2. Test the Leader success, just need to send to leader. - httpClient = newHTTPClientWithRequestChecker(func(req *http.Request) error { - // mock leader success. - if !strings.Contains(pdAddrs[0], req.Host) { - return errors.New("mock error") - } - return nil - }) - c = NewClient("test-http-pd-redirect", pdAddrs, WithHTTPClient(httpClient), WithMetrics(metricCnt, nil)) - // force to update members info. - c.(*client).setLeaderAddrIdx(0) - c.CreateScheduler(context.Background(), "test", 0) - successCnt, err := c.(*client).inner.requestCounter.GetMetricWithLabelValues([]string{createSchedulerName, ""}...) - re.NoError(err) - successCnt.Write(&out) - re.Equal(float64(1), out.Counter.GetValue()) - c.Close() - - // 3. Test when the leader fails, needs to be sent to the follower in order, - // and returns directly if one follower succeeds - httpClient = newHTTPClientWithRequestChecker(func(req *http.Request) error { - // mock leader failure. - if strings.Contains(pdAddrs[0], req.Host) { - return errors.New("mock error") - } - return nil - }) - c = NewClient("test-http-pd-redirect", pdAddrs, WithHTTPClient(httpClient), WithMetrics(metricCnt, nil)) - // force to update members info. - c.(*client).setLeaderAddrIdx(0) - c.CreateScheduler(context.Background(), "test", 0) - successCnt, err = c.(*client).inner.requestCounter.GetMetricWithLabelValues([]string{createSchedulerName, ""}...) - re.NoError(err) - successCnt.Write(&out) - // only one follower success - re.Equal(float64(2), out.Counter.GetValue()) - failureCnt, err = c.(*client).inner.requestCounter.GetMetricWithLabelValues([]string{createSchedulerName, networkErrorStatus}...) - re.NoError(err) - failureCnt.Write(&out) - // leader failure - re.Equal(float64(4), out.Counter.GetValue()) - c.Close() -} - func TestWithBackoffer(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) diff --git a/client/pd_service_discovery.go b/client/pd_service_discovery.go index 914afa2d3fc3..28bb8bbd6611 100644 --- a/client/pd_service_discovery.go +++ b/client/pd_service_discovery.go @@ -17,6 +17,7 @@ package pd import ( "context" "crypto/tls" + "fmt" "reflect" "sort" "strings" @@ -44,6 +45,9 @@ const ( serviceModeUpdateInterval = 3 * time.Second updateMemberTimeout = time.Second // Use a shorter timeout to recover faster from network isolation. updateMemberBackOffBaseTime = 100 * time.Millisecond + + httpScheme = "http" + httpsScheme = "https" ) // MemberHealthCheckInterval might be changed in the unit to shorten the testing time. @@ -96,6 +100,9 @@ type ServiceDiscovery interface { // If the leader ServiceClient meets network problem, // it returns a follower/secondary ServiceClient which can forward the request to leader. GetServiceClient() ServiceClient + // GetAllServiceClients tries to get all ServiceClient. + // If the leader is not nil, it will put the leader service client first in the slice. + GetAllServiceClients() []ServiceClient // GetOrCreateGRPCConn returns the corresponding grpc client connection of the given addr GetOrCreateGRPCConn(addr string) (*grpc.ClientConn, error) // ScheduleCheckMemberChanged is used to trigger a check to see if there is any membership change @@ -119,6 +126,8 @@ type ServiceDiscovery interface { type ServiceClient interface { // GetAddress returns the address information of the PD server. GetAddress() string + // GetHTTPAddress returns the address with HTTP scheme of the PD server. + GetHTTPAddress() string // GetClientConn returns the gRPC connection of the service client GetClientConn() *grpc.ClientConn // BuildGRPCTargetContext builds a context object with a gRPC context. @@ -140,20 +149,43 @@ var ( ) type pdServiceClient struct { - addr string - conn *grpc.ClientConn - isLeader bool - leaderAddr string + addr string + httpAddress string + conn *grpc.ClientConn + isLeader bool + leaderAddr string networkFailure atomic.Bool } -func newPDServiceClient(addr, leaderAddr string, conn *grpc.ClientConn, isLeader bool) ServiceClient { +func newPDServiceClient(addr, leaderAddr string, tlsCfg *tls.Config, conn *grpc.ClientConn, isLeader bool) ServiceClient { + var httpAddress string + if tlsCfg == nil { + if strings.HasPrefix(addr, httpsScheme) { + addr = strings.TrimPrefix(addr, httpsScheme) + httpAddress = fmt.Sprintf("%s%s", httpScheme, addr) + } else if strings.HasPrefix(addr, httpScheme) { + httpAddress = addr + } else { + httpAddress = fmt.Sprintf("%s://%s", httpScheme, addr) + } + } else { + if strings.HasPrefix(addr, httpsScheme) { + httpAddress = addr + } else if strings.HasPrefix(addr, httpScheme) { + addr = strings.TrimPrefix(addr, httpScheme) + httpAddress = fmt.Sprintf("%s%s", httpsScheme, addr) + } else { + httpAddress = fmt.Sprintf("%s://%s", httpsScheme, addr) + } + } + cli := &pdServiceClient{ - addr: addr, - conn: conn, - isLeader: isLeader, - leaderAddr: leaderAddr, + addr: addr, + httpAddress: httpAddress, + conn: conn, + isLeader: isLeader, + leaderAddr: leaderAddr, } if conn == nil { cli.networkFailure.Store(true) @@ -169,6 +201,14 @@ func (c *pdServiceClient) GetAddress() string { return c.addr } +// GetHTTPAddress implements ServiceClient. +func (c *pdServiceClient) GetHTTPAddress() string { + if c == nil { + return "" + } + return c.httpAddress +} + // BuildGRPCTargetContext implements ServiceClient. func (c *pdServiceClient) BuildGRPCTargetContext(ctx context.Context, toLeader bool) context.Context { if c == nil || c.isLeader { @@ -325,11 +365,11 @@ func (c *pdServiceBalancer) set(clients []ServiceClient) { } c.totalNode = len(clients) head := &pdServiceBalancerNode{ - pdServiceAPIClient: newPDServiceAPIClient(clients[0], c.errFn).(*pdServiceAPIClient), + pdServiceAPIClient: newPDServiceAPIClient(clients[c.totalNode-1], c.errFn).(*pdServiceAPIClient), } head.next = head last := head - for i := 1; i < c.totalNode; i++ { + for i := c.totalNode - 2; i >= 0; i-- { next := &pdServiceBalancerNode{ pdServiceAPIClient: newPDServiceAPIClient(clients[i], c.errFn).(*pdServiceAPIClient), next: head, @@ -392,10 +432,12 @@ type pdServiceDiscovery struct { isInitialized bool urls atomic.Value // Store as []string - // PD leader URL + // PD leader leader atomic.Value // Store as pdServiceClient - // PD follower URLs - followers sync.Map // Store as map[string]pdServiceClient + // PD follower + followers sync.Map // Store as map[string]pdServiceClient + // PD leader and PD followers + all atomic.Value // Store as []pdServiceClient apiCandidateNodes [apiKindCount]*pdServiceBalancer // PD follower URLs. Only for tso. followerAddresses atomic.Value // Store as []string @@ -432,6 +474,26 @@ type pdServiceDiscovery struct { option *option } +// NewDefaultPDServiceDiscovery returns a new default PD service discovery-based client. +func NewDefaultPDServiceDiscovery( + ctx context.Context, cancel context.CancelFunc, + urls []string, tlsCfg *tls.Config, +) *pdServiceDiscovery { + var wg sync.WaitGroup + pdsd := &pdServiceDiscovery{ + checkMembershipCh: make(chan struct{}, 1), + ctx: ctx, + cancel: cancel, + wg: &wg, + apiCandidateNodes: [apiKindCount]*pdServiceBalancer{newPDServiceBalancer(emptyErrorFn), newPDServiceBalancer(regionAPIErrorFn)}, + keyspaceID: defaultKeyspaceID, + tlsCfg: tlsCfg, + option: newOption(), + } + pdsd.urls.Store(urls) + return pdsd +} + // newPDServiceDiscovery returns a new PD service discovery-based client. func newPDServiceDiscovery( ctx context.Context, cancel context.CancelFunc, @@ -732,6 +794,16 @@ func (c *pdServiceDiscovery) GetServiceClient() ServiceClient { return leaderClient } +// GetAllServiceClients implments ServiceDiscovery +func (c *pdServiceDiscovery) GetAllServiceClients() []ServiceClient { + all := c.all.Load() + if all == nil { + return nil + } + ret := all.([]ServiceClient) + return append(ret[:0:0], ret...) +} + // ScheduleCheckMemberChanged is used to check if there is any membership // change among the leader and the followers. func (c *pdServiceDiscovery) ScheduleCheckMemberChanged() { @@ -964,7 +1036,7 @@ func (c *pdServiceDiscovery) switchLeader(addrs []string) (bool, error) { // If gRPC connect is created successfully or leader is new, still saves. if addr != oldLeader.GetAddress() || newConn != nil { // Set PD leader and Global TSO Allocator (which is also the PD leader) - leaderClient := newPDServiceClient(addr, addr, newConn, true) + leaderClient := newPDServiceClient(addr, addr, c.tlsCfg, newConn, true) c.leader.Store(leaderClient) } // Run callbacks @@ -1001,7 +1073,7 @@ func (c *pdServiceDiscovery) updateFollowers(members []*pdpb.Member, leader *pdp log.Warn("[pd] failed to connect follower", zap.String("follower", addr), errs.ZapError(err)) continue } - follower := newPDServiceClient(addr, leader.GetClientUrls()[0], conn, false) + follower := newPDServiceClient(addr, leader.GetClientUrls()[0], c.tlsCfg, conn, false) c.followers.Store(addr, follower) changed = true } @@ -1009,7 +1081,7 @@ func (c *pdServiceDiscovery) updateFollowers(members []*pdpb.Member, leader *pdp } else { changed = true conn, err := c.GetOrCreateGRPCConn(addr) - follower := newPDServiceClient(addr, leader.GetClientUrls()[0], conn, false) + follower := newPDServiceClient(addr, leader.GetClientUrls()[0], c.tlsCfg, conn, false) if err != nil || conn == nil { log.Warn("[pd] failed to connect follower", zap.String("follower", addr), errs.ZapError(err)) } @@ -1037,14 +1109,15 @@ func (c *pdServiceDiscovery) updateServiceClient(members []*pdpb.Member, leader } // If error is not nil, still updates candidates. clients := make([]ServiceClient, 0) - c.followers.Range(func(_, value any) bool { - clients = append(clients, value.(*pdServiceClient)) - return true - }) leaderClient := c.getLeaderServiceClient() if leaderClient != nil { clients = append(clients, leaderClient) } + c.followers.Range(func(_, value any) bool { + clients = append(clients, value.(*pdServiceClient)) + return true + }) + c.all.Store(clients) // create candidate services for all kinds of request. for i := 0; i < int(apiKindCount); i++ { c.apiCandidateNodes[i].set(clients) diff --git a/client/pd_service_discovery_test.go b/client/pd_service_discovery_test.go index 98c5daed561c..1dc73af1f5f5 100644 --- a/client/pd_service_discovery_test.go +++ b/client/pd_service_discovery_test.go @@ -16,6 +16,7 @@ package pd import ( "context" + "crypto/tls" "errors" "log" "net" @@ -26,6 +27,7 @@ import ( "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/tikv/pd/client/grpcutil" "github.com/tikv/pd/client/testutil" @@ -137,8 +139,8 @@ func (suite *serviceClientTestSuite) SetupSuite() { leaderConn, err1 := grpc.Dial(suite.leaderServer.addr, grpc.WithInsecure()) //nolint followerConn, err2 := grpc.Dial(suite.followerServer.addr, grpc.WithInsecure()) //nolint if err1 == nil && err2 == nil { - suite.followerClient = newPDServiceClient(suite.followerServer.addr, suite.leaderServer.addr, followerConn, false) - suite.leaderClient = newPDServiceClient(suite.leaderServer.addr, suite.leaderServer.addr, leaderConn, true) + suite.followerClient = newPDServiceClient(suite.followerServer.addr, suite.leaderServer.addr, nil, followerConn, false) + suite.leaderClient = newPDServiceClient(suite.leaderServer.addr, suite.leaderServer.addr, nil, leaderConn, true) suite.followerServer.server.leaderConn = suite.leaderClient.GetClientConn() suite.followerServer.server.leaderAddr = suite.leaderClient.GetAddress() return @@ -169,6 +171,8 @@ func (suite *serviceClientTestSuite) TestServiceClient() { re.Equal(follower.GetAddress(), followerAddress) re.Equal(leader.GetAddress(), leaderAddress) + re.Equal(follower.GetHTTPAddress(), "http://"+followerAddress) + re.Equal(leader.GetHTTPAddress(), "http://"+leaderAddress) re.True(follower.Available()) re.True(leader.Available()) @@ -293,3 +297,19 @@ func (suite *serviceClientTestSuite) TestServiceClientBalancer() { re.Equal(int32(0), suite.followerServer.server.getHandleCount()) re.Equal(int32(5), suite.followerServer.server.getForwardCount()) } + +func TestHTTPScheme(t *testing.T) { + re := require.New(t) + cli := newPDServiceClient("127.0.0.1:2379", "127.0.0.1:2379", nil, nil, false) + re.Equal("http://127.0.0.1:2379", cli.GetHTTPAddress()) + cli = newPDServiceClient("https://127.0.0.1:2379", "127.0.0.1:2379", nil, nil, false) + re.Equal("http://127.0.0.1:2379", cli.GetHTTPAddress()) + cli = newPDServiceClient("http://127.0.0.1:2379", "127.0.0.1:2379", nil, nil, false) + re.Equal("http://127.0.0.1:2379", cli.GetHTTPAddress()) + cli = newPDServiceClient("127.0.0.1:2379", "127.0.0.1:2379", &tls.Config{}, nil, false) + re.Equal("https://127.0.0.1:2379", cli.GetHTTPAddress()) + cli = newPDServiceClient("https://127.0.0.1:2379", "127.0.0.1:2379", &tls.Config{}, nil, false) + re.Equal("https://127.0.0.1:2379", cli.GetHTTPAddress()) + cli = newPDServiceClient("http://127.0.0.1:2379", "127.0.0.1:2379", &tls.Config{}, nil, false) + re.Equal("https://127.0.0.1:2379", cli.GetHTTPAddress()) +} diff --git a/client/tso_service_discovery.go b/client/tso_service_discovery.go index d439af85b611..3d7c0745f498 100644 --- a/client/tso_service_discovery.go +++ b/client/tso_service_discovery.go @@ -378,6 +378,11 @@ func (c *tsoServiceDiscovery) GetServiceClient() ServiceClient { return c.apiSvcDiscovery.GetServiceClient() } +// GetAllServiceClients implements ServiceDiscovery +func (c *tsoServiceDiscovery) GetAllServiceClients() []ServiceClient { + return c.apiSvcDiscovery.GetAllServiceClients() +} + // getPrimaryAddr returns the primary address. func (c *tsoServiceDiscovery) getPrimaryAddr() string { c.keyspaceGroupSD.RLock() diff --git a/tests/integrations/client/http_client_test.go b/tests/integrations/client/http_client_test.go index 54337ed470e8..8c9e7722ef7f 100644 --- a/tests/integrations/client/http_client_test.go +++ b/tests/integrations/client/http_client_test.go @@ -16,14 +16,19 @@ package client_test import ( "context" + "errors" "math" "net/http" "sort" + "strings" "testing" "time" + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + pdCli "github.com/tikv/pd/client" pd "github.com/tikv/pd/client/http" "github.com/tikv/pd/client/retry" "github.com/tikv/pd/pkg/core" @@ -41,6 +46,7 @@ type httpClientTestSuite struct { cancelFunc context.CancelFunc cluster *tests.TestCluster client pd.Client + sd pdCli.ServiceDiscovery } func TestHTTPClientTestSuite(t *testing.T) { @@ -74,7 +80,9 @@ func (suite *httpClientTestSuite) SetupSuite() { for _, s := range testServers { endpoints = append(endpoints, s.GetConfig().AdvertiseClientUrls) } - suite.client = pd.NewClient("pd-http-client-it", endpoints) + cli := setupCli(re, suite.ctx, endpoints) + suite.sd = cli.GetServiceDiscovery() + suite.client = pd.NewClientWithServiceDiscovery("pd-http-client-it", suite.sd) } func (suite *httpClientTestSuite) TearDownSuite() { @@ -475,10 +483,11 @@ func (suite *httpClientTestSuite) TestTransferLeader() { re.NoError(err) re.NotEqual(leader.GetName(), newLeader) // Force to update the members info. - suite.client.(interface{ UpdateMembersInfo() }).UpdateMembersInfo() - leader, err = suite.client.GetLeader(suite.ctx) - re.NoError(err) - re.Equal(newLeader, leader.GetName()) + testutil.Eventually(re, func() bool { + leader, err = suite.client.GetLeader(suite.ctx) + re.NoError(err) + return newLeader == leader.GetName() + }) members, err = suite.client.GetMembers(suite.ctx) re.NoError(err) re.Len(members.Members, 2) @@ -505,3 +514,61 @@ func (suite *httpClientTestSuite) TestWithBackoffer() { re.ErrorContains(err, http.StatusText(http.StatusNotFound)) re.Nil(rule) } + +func (suite *httpClientTestSuite) TestRedirectWithMetrics() { + re := suite.Require() + metricCnt := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "check", + }, []string{"name", ""}) + // 1. Test all followers failed, need to send all followers. + httpClient := pd.NewHTTPClientWithRequestChecker(func(req *http.Request) error { + if req.URL.Path == pd.Schedulers { + return errors.New("mock error") + } + return nil + }) + c := pd.NewClientWithServiceDiscovery("pd-http-client-it", suite.sd, pd.WithHTTPClient(httpClient), pd.WithMetrics(metricCnt, nil)) + c.CreateScheduler(context.Background(), "test", 0) + var out dto.Metric + failureCnt, err := metricCnt.GetMetricWithLabelValues([]string{"CreateScheduler", "network error"}...) + re.NoError(err) + failureCnt.Write(&out) + re.Equal(float64(2), out.Counter.GetValue()) + c.Close() + + leader := suite.sd.GetServingAddr() + httpClient = pd.NewHTTPClientWithRequestChecker(func(req *http.Request) error { + // mock leader success. + if !strings.Contains(leader, req.Host) { + return errors.New("mock error") + } + return nil + }) + c = pd.NewClientWithServiceDiscovery("pd-http-client-it", suite.sd, pd.WithHTTPClient(httpClient), pd.WithMetrics(metricCnt, nil)) + c.CreateScheduler(context.Background(), "test", 0) + successCnt, err := metricCnt.GetMetricWithLabelValues([]string{"CreateScheduler", ""}...) + re.NoError(err) + successCnt.Write(&out) + re.Equal(float64(1), out.Counter.GetValue()) + c.Close() + + httpClient = pd.NewHTTPClientWithRequestChecker(func(req *http.Request) error { + // mock leader success. + if strings.Contains(leader, req.Host) { + return errors.New("mock error") + } + return nil + }) + c = pd.NewClientWithServiceDiscovery("pd-http-client-it", suite.sd, pd.WithHTTPClient(httpClient), pd.WithMetrics(metricCnt, nil)) + c.CreateScheduler(context.Background(), "test", 0) + successCnt, err = metricCnt.GetMetricWithLabelValues([]string{"CreateScheduler", ""}...) + re.NoError(err) + successCnt.Write(&out) + re.Equal(float64(2), out.Counter.GetValue()) + failureCnt, err = metricCnt.GetMetricWithLabelValues([]string{"CreateScheduler", "network error"}...) + re.NoError(err) + failureCnt.Write(&out) + re.Equal(float64(3), out.Counter.GetValue()) + c.Close() +} diff --git a/tests/integrations/go.mod b/tests/integrations/go.mod index ca412ad696f0..06db591af100 100644 --- a/tests/integrations/go.mod +++ b/tests/integrations/go.mod @@ -18,6 +18,8 @@ require ( github.com/pingcap/failpoint v0.0.0-20220801062533-2eaa32854a6c github.com/pingcap/kvproto v0.0.0-20231226064240-4f28b82c7860 github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 + github.com/prometheus/client_golang v1.11.1 + github.com/prometheus/client_model v0.4.0 github.com/stretchr/testify v1.8.4 github.com/tikv/pd v0.0.0-00010101000000-000000000000 github.com/tikv/pd/client v0.0.0-00010101000000-000000000000 @@ -132,8 +134,6 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/power-devops/perfstat v0.0.0-20221212215047-62379fc7944b // indirect - github.com/prometheus/client_golang v1.11.1 // indirect - github.com/prometheus/client_model v0.4.0 // indirect github.com/prometheus/common v0.26.0 // indirect github.com/prometheus/procfs v0.6.0 // indirect github.com/rs/cors v1.7.0 // indirect diff --git a/tools/pd-api-bench/main.go b/tools/pd-api-bench/main.go index d7f8209dd3bf..56e7ee761b22 100644 --- a/tools/pd-api-bench/main.go +++ b/tools/pd-api-bench/main.go @@ -161,7 +161,8 @@ func main() { } httpClis := make([]pdHttp.Client, cfg.Client) for i := int64(0); i < cfg.Client; i++ { - httpClis[i] = pdHttp.NewClient("tools-api-bench", []string{cfg.PDAddr}, pdHttp.WithTLSConfig(loadTLSConfig(cfg))) + sd := pdClis[i].GetServiceDiscovery() + httpClis[i] = pdHttp.NewClientWithServiceDiscovery("tools-api-bench", sd, pdHttp.WithTLSConfig(loadTLSConfig(cfg))) } err = cases.InitCluster(ctx, pdClis[0], httpClis[0]) if err != nil {