diff --git a/go.mod b/go.mod index b9ff644dc81..db593ed3ef3 100644 --- a/go.mod +++ b/go.mod @@ -39,6 +39,7 @@ require ( github.com/sirupsen/logrus v1.2.0 github.com/spf13/cobra v1.0.0 github.com/spf13/pflag v1.0.5 + github.com/stretchr/testify v1.7.0 github.com/swaggo/http-swagger v0.0.0-20200308142732-58ac5e232fba github.com/swaggo/swag v1.6.6-0.20200529100950-7c765ddd0476 github.com/syndtr/goleveldb v1.0.1-0.20190318030020-c3a204f8e965 diff --git a/pkg/dashboard/adapter/redirector.go b/pkg/dashboard/adapter/redirector.go index 1b8f0bafd61..57e10170a07 100644 --- a/pkg/dashboard/adapter/redirector.go +++ b/pkg/dashboard/adapter/redirector.go @@ -75,7 +75,7 @@ func (h *Redirector) SetAddress(addr string) { defaultDirector := h.proxy.Director h.proxy.Director = func(r *http.Request) { defaultDirector(r) - r.Header.Set(proxyHeader, h.name) + r.Header.Add(proxyHeader, h.name) } if h.tlsConfig != nil { @@ -117,9 +117,12 @@ func (h *Redirector) ReverseProxy(w http.ResponseWriter, r *http.Request) { return } - if len(r.Header.Get(proxyHeader)) > 0 { - w.WriteHeader(http.StatusLoopDetected) - return + proxySources := r.Header.Values(proxyHeader) + for _, proxySource := range proxySources { + if proxySource == h.name { + w.WriteHeader(http.StatusLoopDetected) + return + } } proxy.ServeHTTP(w, r) diff --git a/pkg/dashboard/adapter/redirector_test.go b/pkg/dashboard/adapter/redirector_test.go new file mode 100644 index 00000000000..f192fbeb1f2 --- /dev/null +++ b/pkg/dashboard/adapter/redirector_test.go @@ -0,0 +1,112 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type redirectorTestSuite struct { + suite.Suite + + tempText string + tempServer *httptest.Server + + testName string + redirector *Redirector + + noRedirectHTTPClient *http.Client +} + +func TestRedirectorTestSuite(t *testing.T) { + suite.Run(t, new(redirectorTestSuite)) +} + +func (suite *redirectorTestSuite) SetupSuite() { + suite.tempText = "temp1" + suite.tempServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, suite.tempText) + })) + + suite.testName = "test1" + suite.redirector = NewRedirector(suite.testName, nil) + suite.noRedirectHTTPClient = &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // ErrUseLastResponse can be returned by Client.CheckRedirect hooks to + // control how redirects are processed. If returned, the next request + // is not sent and the most recent response is returned with its body + // unclosed. + return http.ErrUseLastResponse + }, + } +} + +func (suite *redirectorTestSuite) TearDownSuite() { + suite.tempServer.Close() + suite.noRedirectHTTPClient.CloseIdleConnections() +} + +func (suite *redirectorTestSuite) TestReverseProxy() { + redirectorServer := httptest.NewServer(http.HandlerFunc(suite.redirector.ReverseProxy)) + defer redirectorServer.Close() + + suite.redirector.SetAddress(suite.tempServer.URL) + // Test normal forwarding + req, err := http.NewRequest(http.MethodGet, redirectorServer.URL, nil) + suite.NoError(err) + checkHTTPRequest(suite.Require(), suite.noRedirectHTTPClient, req, http.StatusOK, suite.tempText) + // Test the requests that are forwarded by others + req, err = http.NewRequest(http.MethodGet, redirectorServer.URL, nil) + suite.NoError(err) + req.Header.Set(proxyHeader, "other") + checkHTTPRequest(suite.Require(), suite.noRedirectHTTPClient, req, http.StatusOK, suite.tempText) + // Test LoopDetected + suite.redirector.SetAddress(redirectorServer.URL) + req, err = http.NewRequest(http.MethodGet, redirectorServer.URL, nil) + suite.NoError(err) + checkHTTPRequest(suite.Require(), suite.noRedirectHTTPClient, req, http.StatusLoopDetected, "") +} + +func (suite *redirectorTestSuite) TestTemporaryRedirect() { + redirectorServer := httptest.NewServer(http.HandlerFunc(suite.redirector.TemporaryRedirect)) + defer redirectorServer.Close() + suite.redirector.SetAddress(suite.tempServer.URL) + // Test TemporaryRedirect + req, err := http.NewRequest(http.MethodGet, redirectorServer.URL, nil) + suite.NoError(err) + checkHTTPRequest(suite.Require(), suite.noRedirectHTTPClient, req, http.StatusTemporaryRedirect, "") + // Test Response + req, err = http.NewRequest(http.MethodGet, redirectorServer.URL, nil) + suite.NoError(err) + checkHTTPRequest(suite.Require(), http.DefaultClient, req, http.StatusOK, suite.tempText) +} + +func checkHTTPRequest(re *require.Assertions, client *http.Client, req *http.Request, expectedCode int, expectedText string) { + resp, err := client.Do(req) + re.NoError(err) + defer resp.Body.Close() + re.Equal(expectedCode, resp.StatusCode) + if expectedCode >= http.StatusOK && expectedCode <= http.StatusAlreadyReported { + text, err := io.ReadAll(resp.Body) + re.NoError(err) + re.Equal(expectedText, string(text)) + } +}