diff --git a/pkg/apiutil/apiutil.go b/pkg/apiutil/apiutil.go index a5ed3bc0256..83c2c3b75ec 100644 --- a/pkg/apiutil/apiutil.go +++ b/pkg/apiutil/apiutil.go @@ -28,11 +28,11 @@ import ( ) var ( - // ComponentSignatureKey is used for http request header key + // componentSignatureKey is used for http request header key // to identify component signature - ComponentSignatureKey = "component" - // ComponentAnonymousValue identifies anonymous request source - ComponentAnonymousValue = "anonymous" + componentSignatureKey = "component" + // componentAnonymousValue identifies anonymous request source + componentAnonymousValue = "anonymous" ) // DeferClose captures the error returned from closing (if an error occurs). @@ -138,23 +138,31 @@ func ErrorResp(rd *render.Render, w http.ResponseWriter, err error) { // GetComponentNameOnHTTP returns component name from Request Header func GetComponentNameOnHTTP(r *http.Request) string { - componentName := r.Header.Get(ComponentSignatureKey) + componentName := r.Header.Get(componentSignatureKey) if componentName == "" { - componentName = ComponentAnonymousValue + componentName = componentAnonymousValue } return componentName } // ComponentSignatureRoundTripper is used to add component signature in HTTP header type ComponentSignatureRoundTripper struct { - Proxied http.RoundTripper - Component string + proxied http.RoundTripper + component *string +} + +// NewComponentSignatureRoundTripper returns a new ComponentSignatureRoundTripper. +func NewComponentSignatureRoundTripper(roundTripper http.RoundTripper, componentName *string) *ComponentSignatureRoundTripper { + return &ComponentSignatureRoundTripper{ + proxied: roundTripper, + component: componentName, + } } // RoundTrip is used to implement RoundTripper func (rt *ComponentSignatureRoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) { - req.Header.Add(ComponentSignatureKey, rt.Component) + req.Header.Add(componentSignatureKey, *rt.component) // Send the request, get the response and the error - resp, err = rt.Proxied.RoundTrip(req) + resp, err = rt.proxied.RoundTrip(req) return } diff --git a/tools/pd-ctl/pdctl/command/global.go b/tools/pd-ctl/pdctl/command/global.go index a60a221a59b..21a0f3d4ec8 100644 --- a/tools/pd-ctl/pdctl/command/global.go +++ b/tools/pd-ctl/pdctl/command/global.go @@ -30,11 +30,9 @@ import ( ) var ( - dialClient = &http.Client{ - Transport: &apiutil.ComponentSignatureRoundTripper{ - Component: "pdctl", - Proxied: http.DefaultTransport, - }, + pdControllerComponentName = "pdctl" + dialClient = &http.Client{ + Transport: apiutil.NewComponentSignatureRoundTripper(http.DefaultTransport, &pdControllerComponentName), } pingPrefix = "pd/api/v1/ping" ) @@ -52,11 +50,8 @@ func InitHTTPSClient(caPath, certPath, keyPath string) error { } dialClient = &http.Client{ - Transport: &apiutil.ComponentSignatureRoundTripper{ - Component: "pdctl", - Proxied: &http.Transport{ - TLSClientConfig: tlsConfig}, - }, + Transport: apiutil.NewComponentSignatureRoundTripper( + &http.Transport{TLSClientConfig: tlsConfig}, &pdControllerComponentName), } return nil