diff --git a/config/http_config.go b/config/http_config.go index 803f6a1c..d3f5be4d 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -224,8 +224,6 @@ type OAuth2 struct { ProxyURL URL `yaml:"proxy_url,omitempty" json:"proxy_url,omitempty"` // TLSConfig is used to connect to the token URL. TLSConfig TLSConfig `yaml:"tls_config,omitempty"` - // UserAgent is used to set a custom User-Agent http header while making the oauth request. - UserAgent string `yaml:"user_agent,omitempty" json:"user_agent,omitempty"` } // SetDirectory joins any relative file paths with dir. @@ -374,6 +372,7 @@ type httpClientOptions struct { keepAlivesEnabled bool http2Enabled bool idleConnTimeout time.Duration + userAgent string } // HTTPClientOption defines an option that can be applied to the HTTP client. @@ -407,6 +406,13 @@ func WithIdleConnTimeout(timeout time.Duration) HTTPClientOption { } } +// WithIdleConnTimeout allows setting the user agent. +func WithUserAgent(ua string) HTTPClientOption { + return func(opts *httpClientOptions) { + opts.userAgent = ua + } +} + // NewClient returns a http.Client using the specified http.RoundTripper. func newClient(rt http.RoundTripper) *http.Client { return &http.Client{Transport: rt} @@ -499,8 +505,12 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT rt = NewBasicAuthRoundTripper(cfg.BasicAuth.Username, cfg.BasicAuth.Password, cfg.BasicAuth.PasswordFile, rt) } + if opts.userAgent != "" { + rt = NewUserAgentRoundTripper(opts.userAgent, rt) + } + if cfg.OAuth2 != nil { - rt = NewOAuth2RoundTripper(cfg.OAuth2, rt) + rt = NewOAuth2RoundTripper(cfg.OAuth2, rt, &opts) } // Return a new configured RoundTripper. return rt, nil @@ -621,12 +631,14 @@ type oauth2RoundTripper struct { next http.RoundTripper secret string mtx sync.RWMutex + opts *httpClientOptions } -func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper) http.RoundTripper { +func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper { return &oauth2RoundTripper{ config: config, next: next, + opts: opts, } } @@ -683,8 +695,8 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro } } - if rt.config.UserAgent != "" { - t = NewUserAgentRoundTripper(rt.config.UserAgent, t) + if rt.opts.userAgent != "" { + t = NewUserAgentRoundTripper(rt.opts.userAgent, t) } ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: t}) diff --git a/config/http_config_test.go b/config/http_config_test.go index 9c9c6237..eb8208a4 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -1183,12 +1183,6 @@ type oauth2TestServerResponse struct { func TestOAuth2(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/token" { - if r.Header.Get("User-Agent") != "myuseragent" { - t.Fatalf("Expected User-Agent header in oauth request to be 'myuseragent', got '%s'", r.Header.Get("User-Agent")) - } - } - res, _ := json.Marshal(oauth2TestServerResponse{ AccessToken: "12345", TokenType: "Bearer", @@ -1205,7 +1199,6 @@ scopes: - A - B token_url: %s/token -user_agent: myuseragent endpoint_params: hi: hello `, ts.URL) @@ -1215,7 +1208,6 @@ endpoint_params: Scopes: []string{"A", "B"}, EndpointParams: map[string]string{"hi": "hello"}, TokenURL: fmt.Sprintf("%s/token", ts.URL), - UserAgent: "myuseragent", } var unmarshalledConfig OAuth2 @@ -1227,7 +1219,7 @@ endpoint_params: t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig) } - rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport) + rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions) client := http.Client{ Transport: rt, @@ -1240,6 +1232,50 @@ endpoint_params: } } +func TestOAuth2UserAgent(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + if r.Header.Get("User-Agent") != "myuseragent" { + t.Fatalf("Expected User-Agent header in oauth request to be 'myuseragent', got '%s'", r.Header.Get("User-Agent")) + } + } + + res, _ := json.Marshal(oauth2TestServerResponse{ + AccessToken: "12345", + TokenType: "Bearer", + }) + w.Header().Add("Content-Type", "application/json") + _, _ = w.Write(res) + })) + defer ts.Close() + + config := &OAuth2{ + ClientID: "1", + ClientSecret: "2", + Scopes: []string{"A", "B"}, + EndpointParams: map[string]string{"hi": "hello"}, + TokenURL: fmt.Sprintf("%s/token", ts.URL), + } + + opts := defaultHTTPClientOptions + WithUserAgent("myuseragent")(&opts) + + rt := NewOAuth2RoundTripper(config, http.DefaultTransport, &opts) + + client := http.Client{ + Transport: rt, + } + resp, err := client.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + authorization := resp.Request.Header.Get("Authorization") + if authorization != "Bearer 12345" { + t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization) + } +} + func TestOAuth2WithFile(t *testing.T) { var expectedAuth *string var previousAuth string @@ -1302,7 +1338,7 @@ endpoint_params: t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig) } - rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport) + rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions) client := http.Client{ Transport: rt, @@ -1496,10 +1532,3 @@ func TestOAuth2Proxy(t *testing.T) { t.Errorf("Error loading OAuth2 client config: %v", err) } } - -func TestOAuth2UserAgent(t *testing.T) { - _, _, err := LoadHTTPConfigFile("testdata/http.conf.oauth2-user-agent.good.yml") - if err != nil { - t.Errorf("Error loading OAuth2 client config: %v", err) - } -} diff --git a/config/testdata/http.conf.oauth2-user-agent.good.yml b/config/testdata/http.conf.oauth2-user-agent.good.yml deleted file mode 100644 index a0a407f2..00000000 --- a/config/testdata/http.conf.oauth2-user-agent.good.yml +++ /dev/null @@ -1,5 +0,0 @@ -oauth2: - client_id: "myclient" - client_secret: "mysecret" - token_url: "http://auth" - user_agent: "myuseragent"