Skip to content

Commit

Permalink
client/http: introduce caller ID into the HTTP client (#7490)
Browse files Browse the repository at this point in the history
ref #7300

Introduce caller ID into the HTTP client.

Signed-off-by: JmPotato <ghzpotato@gmail.com>
  • Loading branch information
JmPotato authored Dec 4, 2023
1 parent 259435d commit 080af97
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 17 deletions.
66 changes: 49 additions & 17 deletions client/http/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
)

const (
defaultCallerID = "pd-http-client"
httpScheme = "http"
httpsScheme = "https"
networkErrorStatus = "network error"
Expand Down Expand Up @@ -79,6 +80,8 @@ type Client interface {
GetMinResolvedTSByStoresIDs(context.Context, []uint64) (uint64, map[uint64]uint64, error)

/* Client-related methods */
// WithCallerID sets and returns a new client with the given caller ID.
WithCallerID(string) Client
// WithRespHandler sets and returns a new client with the given HTTP response handler.
// This allows the caller to customize how the response is handled, including error handling logic.
// Additionally, it is important for the caller to handle the content of the response body properly
Expand All @@ -89,11 +92,20 @@ type Client interface {

var _ Client = (*client)(nil)

type client struct {
// clientInner is the inner implementation of the PD HTTP client, which will
// implement some internal logics, such as HTTP client, service discovery, etc.
type clientInner struct {
pdAddrs []string
tlsConf *tls.Config
cli *http.Client
}

type client struct {
// Wrap this struct is to make sure the inner implementation
// won't be exposed and cloud be consistent during the copy.
inner *clientInner

callerID string
respHandler func(resp *http.Response, res interface{}) error

requestCounter *prometheus.CounterVec
Expand All @@ -106,15 +118,15 @@ type ClientOption func(c *client)
// WithHTTPClient configures the client with the given initialized HTTP client.
func WithHTTPClient(cli *http.Client) ClientOption {
return func(c *client) {
c.cli = cli
c.inner.cli = cli
}
}

// WithTLSConfig configures the client with the given TLS config.
// This option won't work if the client is configured with WithHTTPClient.
func WithTLSConfig(tlsConf *tls.Config) ClientOption {
return func(c *client) {
c.tlsConf = tlsConf
c.inner.tlsConf = tlsConf
}
}

Expand All @@ -134,7 +146,7 @@ func NewClient(
pdAddrs []string,
opts ...ClientOption,
) Client {
c := &client{}
c := &client{inner: &clientInner{}, callerID: defaultCallerID}
// Apply the options first.
for _, opt := range opts {
opt(c)
Expand All @@ -143,22 +155,22 @@ func NewClient(
for i, addr := range pdAddrs {
if !strings.HasPrefix(addr, httpScheme) {
var scheme string
if c.tlsConf != nil {
if c.inner.tlsConf != nil {
scheme = httpsScheme
} else {
scheme = httpScheme
}
pdAddrs[i] = fmt.Sprintf("%s://%s", scheme, addr)
}
}
c.pdAddrs = pdAddrs
c.inner.pdAddrs = pdAddrs
// Init the HTTP client if it's not configured.
if c.cli == nil {
c.cli = &http.Client{Timeout: defaultTimeout}
if c.tlsConf != nil {
if c.inner.cli == nil {
c.inner.cli = &http.Client{Timeout: defaultTimeout}
if c.inner.tlsConf != nil {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = c.tlsConf
c.cli.Transport = transport
transport.TLSClientConfig = c.inner.tlsConf
c.inner.cli.Transport = transport
}
}

Expand All @@ -167,12 +179,22 @@ func NewClient(

// Close closes the HTTP client.
func (c *client) Close() {
if c.cli != nil {
c.cli.CloseIdleConnections()
if c.inner == nil {
return
}
if c.inner.cli != nil {
c.inner.cli.CloseIdleConnections()
}
log.Info("[pd] http client closed")
}

// WithCallerID sets and returns a new client with the given caller ID.
func (c *client) WithCallerID(callerID string) Client {
newClient := *c
newClient.callerID = callerID
return &newClient
}

// WithRespHandler sets and returns a new client with the given HTTP response handler.
func (c *client) WithRespHandler(
handler func(resp *http.Response, res interface{}) error,
Expand All @@ -196,13 +218,19 @@ func (c *client) execDuration(name string, duration time.Duration) {
c.executionDuration.WithLabelValues(name).Observe(duration.Seconds())
}

// Header key definition constants.
const (
pdAllowFollowerHandleKey = "PD-Allow-Follower-Handle"
componentSignatureKey = "component"
)

// HeaderOption configures the HTTP header.
type HeaderOption func(header http.Header)

// WithAllowFollowerHandle sets the header field to allow a PD follower to handle this request.
func WithAllowFollowerHandle() HeaderOption {
return func(header http.Header) {
header.Set("PD-Allow-Follower-Handle", "true")
header.Set(pdAllowFollowerHandleKey, "true")
}
}

Expand All @@ -218,8 +246,8 @@ func (c *client) requestWithRetry(
err error
addr string
)
for idx := 0; idx < len(c.pdAddrs); idx++ {
addr = c.pdAddrs[idx]
for idx := 0; idx < len(c.inner.pdAddrs); idx++ {
addr = c.inner.pdAddrs[idx]
err = c.request(ctx, name, fmt.Sprintf("%s%s", addr, uri), method, body, res, headerOpts...)
if err == nil {
break
Expand All @@ -239,6 +267,8 @@ func (c *client) request(
logFields := []zap.Field{
zap.String("name", name),
zap.String("url", url),
zap.String("method", method),
zap.String("caller-id", c.callerID),
}
log.Debug("[pd] request the http url", logFields...)
req, err := http.NewRequestWithContext(ctx, method, url, body)
Expand All @@ -249,8 +279,10 @@ func (c *client) request(
for _, opt := range headerOpts {
opt(req.Header)
}
req.Header.Set(componentSignatureKey, c.callerID)

start := time.Now()
resp, err := c.cli.Do(req)
resp, err := c.inner.cli.Do(req)
if err != nil {
c.reqCounter(name, networkErrorStatus)
log.Error("[pd] do http request failed", append(logFields, zap.Error(err))...)
Expand Down
73 changes: 73 additions & 0 deletions client/http/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright 2023 TiKV Project Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package http

import (
"context"
"net/http"
"testing"

"github.com/stretchr/testify/require"
)

// requestChecker is used to check the HTTP request sent by the client.
type requestChecker struct {
checker func(req *http.Request) error
}

// RoundTrip implements the `http.RoundTripper` interface.
func (rc *requestChecker) RoundTrip(req *http.Request) (resp *http.Response, err error) {
return &http.Response{StatusCode: http.StatusOK}, rc.checker(req)
}

func newHTTPClientWithRequestChecker(checker func(req *http.Request) error) *http.Client {
return &http.Client{
Transport: &requestChecker{checker: checker},
}
}

func TestPDAllowFollowerHandleHeader(t *testing.T) {
re := require.New(t)
var expectedVal string
httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error {
val := req.Header.Get(pdAllowFollowerHandleKey)
if val != expectedVal {
re.Failf("PD allow follower handler header check failed",
"should be %s, but got %s", expectedVal, val)
}
return nil
})
c := NewClient([]string{"http://127.0.0.1"}, WithHTTPClient(httpClient))
c.GetRegions(context.Background())
expectedVal = "true"
c.GetHistoryHotRegions(context.Background(), &HistoryHotRegionsRequest{})
}

func TestCallerID(t *testing.T) {
re := require.New(t)
expectedVal := defaultCallerID
httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error {
val := req.Header.Get(componentSignatureKey)
if val != expectedVal {
re.Failf("Caller ID header check failed",
"should be %s, but got %s", expectedVal, val)
}
return nil
})
c := NewClient([]string{"http://127.0.0.1"}, WithHTTPClient(httpClient))
c.GetRegions(context.Background())
expectedVal = "test"
c.WithCallerID(expectedVal).GetRegions(context.Background())
}

0 comments on commit 080af97

Please sign in to comment.