diff --git a/go.mod b/go.mod index 8f005498..f4a6d021 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( atomicgo.dev/keyboard v0.2.8 // indirect github.com/containerd/console v1.0.3 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-http-utils/headers v0.0.0-20181008091004-fed159eddc2a // indirect github.com/gookit/color v1.5.2 // indirect github.com/lithammer/fuzzysearch v1.1.5 // indirect github.com/mattn/go-runewidth v0.0.14 // indirect diff --git a/go.sum b/go.sum index c37d60df..c574318f 100644 --- a/go.sum +++ b/go.sum @@ -29,6 +29,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/go-http-utils/headers v0.0.0-20181008091004-fed159eddc2a h1:v6zMvHuY9yue4+QkG/HQ/W67wvtQmWJ4SDo9aK/GIno= +github.com/go-http-utils/headers v0.0.0-20181008091004-fed159eddc2a/go.mod h1:I79BieaU4fxrw4LMXby6q5OS9XnoR9UIKLOzDFjUmuw= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= diff --git a/internal/infrastructure/cors.go b/internal/infrastructure/cors.go new file mode 100644 index 00000000..6db23cab --- /dev/null +++ b/internal/infrastructure/cors.go @@ -0,0 +1,13 @@ +package infrastructure + +import ( + "net/http" + + "github.com/go-http-utils/headers" +) + +func WriteCorsHeaders(header http.Header) { + header.Set(headers.AccessControlAllowOrigin, "*") + header.Set(headers.AccessControlAllowCredentials, "true") + header.Set(headers.AccessControlAllowMethods, "GET, PUT, POST, HEAD, TRACE, DELETE, PATCH, COPY, HEAD, LINK, OPTIONS") +} diff --git a/internal/infrastructure/cors_test.go b/internal/infrastructure/cors_test.go new file mode 100644 index 00000000..6d3533b4 --- /dev/null +++ b/internal/infrastructure/cors_test.go @@ -0,0 +1,80 @@ +package infrastructure_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/evg4b/uncors/internal/infrastructure" + "github.com/evg4b/uncors/testing/testutils" + "github.com/go-http-utils/headers" + "github.com/stretchr/testify/assert" +) + +func TestWriteCorsHeaders(t *testing.T) { + tests := []struct { + name string + recorderFactory func() *httptest.ResponseRecorder + expected http.Header + }{ + { + name: "should append data in empty writer", + recorderFactory: httptest.NewRecorder, + expected: map[string][]string{ + headers.AccessControlAllowOrigin: {"*"}, + headers.AccessControlAllowCredentials: {"true"}, + headers.AccessControlAllowMethods: { + "GET, PUT, POST, HEAD, TRACE, DELETE, PATCH, COPY, HEAD, LINK, OPTIONS", + }, + }, + }, + { + name: "should append data in filled writer", + recorderFactory: func() *httptest.ResponseRecorder { + writer := httptest.NewRecorder() + writer.Header().Set("Test-Header", "true") + writer.Header().Set("X-Hey-Header", "123") + + return writer + }, + expected: map[string][]string{ + "Test-Header": {"true"}, + "X-Hey-Header": {"123"}, + headers.AccessControlAllowOrigin: {"*"}, + headers.AccessControlAllowCredentials: {"true"}, + headers.AccessControlAllowMethods: { + "GET, PUT, POST, HEAD, TRACE, DELETE, PATCH, COPY, HEAD, LINK, OPTIONS", + }, + }, + }, + { + name: "should override same headers", + recorderFactory: func() *httptest.ResponseRecorder { + writer := httptest.NewRecorder() + writer.Header().Set("Test-Header", "true") + writer.Header().Set(headers.AccessControlAllowOrigin, "localhost:3000") + + return writer + }, + expected: map[string][]string{ + "Test-Header": {"true"}, + headers.AccessControlAllowOrigin: {"*"}, + headers.AccessControlAllowCredentials: {"true"}, + headers.AccessControlAllowMethods: { + "GET, PUT, POST, HEAD, TRACE, DELETE, PATCH, COPY, HEAD, LINK, OPTIONS", + }, + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + resp := testCase.recorderFactory() + infrastructure.WriteCorsHeaders(resp.Header()) + + response := resp.Result() + defer testutils.CheckNoError(t, response.Body.Close()) + + assert.Equal(t, response.Header, testCase.expected) + }) + } +} diff --git a/internal/mock/handler.go b/internal/mock/handler.go index a4912c49..640d8eb2 100644 --- a/internal/mock/handler.go +++ b/internal/mock/handler.go @@ -5,11 +5,13 @@ import ( "net/http" "github.com/evg4b/uncors/internal/contracts" + "github.com/evg4b/uncors/internal/infrastructure" + "github.com/go-http-utils/headers" ) type Handler struct { - mock Mock - logger contracts.Logger + response Response + logger contracts.Logger } func NewMockHandler(options ...HandlerOption) *Handler { @@ -23,21 +25,32 @@ func NewMockHandler(options ...HandlerOption) *Handler { } func (handler *Handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - updateRequest(request) - writer.WriteHeader(handler.mock.Response.Code) - fmt.Fprint(writer, handler.mock.Response.RawContent) + response := handler.response + header := writer.Header() + infrastructure.WriteCorsHeaders(header) + for key, value := range response.Headers { + header.Set(key, value) + } + if len(header.Get(headers.ContentType)) == 0 { + contentType := http.DetectContentType([]byte(response.RawContent)) + header.Set(headers.ContentType, contentType) + } + + writer.WriteHeader(normaliseCode(response.Code)) + if _, err := fmt.Fprint(writer, response.RawContent); err != nil { + return // TODO: add error handler + } + handler.logger.PrintResponse(&http.Response{ Request: request, - StatusCode: handler.mock.Response.Code, + StatusCode: response.Code, }) } -func updateRequest(request *http.Request) { - request.URL.Host = request.Host - - if request.TLS != nil { - request.URL.Scheme = "https" - } else { - request.URL.Scheme = "http" +func normaliseCode(code int) int { + if code == 0 { + return http.StatusOK } + + return code } diff --git a/internal/mock/handler_test.go b/internal/mock/handler_test.go new file mode 100644 index 00000000..e493f414 --- /dev/null +++ b/internal/mock/handler_test.go @@ -0,0 +1,205 @@ +package mock_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/evg4b/uncors/internal/mock" + "github.com/evg4b/uncors/testing/mocks" + "github.com/evg4b/uncors/testing/testutils" + "github.com/go-http-utils/headers" + "github.com/stretchr/testify/assert" +) + +const textPlain = "text/plain; charset=utf-8" + +func TestHandler(t *testing.T) { + logger := mocks.NewNoopLogger(t) + + t.Run("content type setting", func(t *testing.T) { + tests := []struct { + name string + body string + expected string + }{ + { + name: "plain text", + body: `status: ok`, + expected: textPlain, + }, + { + name: "json", + body: `{ "status": "ok" }`, + expected: textPlain, + }, + { + name: "html", + body: ``, + expected: "text/html; charset=utf-8", + }, + { + name: "xml", + body: ``, + expected: "text/xml; charset=utf-8", + }, + { + name: "png", + body: "\x89PNG\x0D\x0A\x1A\x0A", + expected: "image/png", + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + handler := mock.NewMockHandler(mock.WithLogger(logger), mock.WithResponse(mock.Response{ + Code: 200, + RawContent: testCase.body, + })) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/", nil) + handler.ServeHTTP(recorder, request) + + header := testutils.ReadHeader(t, recorder) + assert.EqualValues(t, testCase.expected, header.Get(headers.ContentType)) + }) + } + }) + + t.Run("headers settings", func(t *testing.T) { + tests := []struct { + name string + response mock.Response + expected http.Header + }{ + { + name: "should put default CORS headers", + response: mock.Response{ + Code: 200, + RawContent: "test content", + }, + expected: map[string][]string{ + headers.AccessControlAllowOrigin: {"*"}, + headers.AccessControlAllowCredentials: {"true"}, + headers.ContentType: {textPlain}, + headers.AccessControlAllowMethods: { + "GET, PUT, POST, HEAD, TRACE, DELETE, PATCH, COPY, HEAD, LINK, OPTIONS", + }, + }, + }, + { + name: "should set response code", + response: mock.Response{ + Code: 200, + RawContent: "test content", + }, + expected: map[string][]string{ + headers.AccessControlAllowOrigin: {"*"}, + headers.AccessControlAllowCredentials: {"true"}, + headers.ContentType: {textPlain}, + headers.AccessControlAllowMethods: { + "GET, PUT, POST, HEAD, TRACE, DELETE, PATCH, COPY, HEAD, LINK, OPTIONS", + }, + }, + }, + { + name: "should set custom headers", + response: mock.Response{ + Code: 200, + Headers: map[string]string{ + "X-Key": "X-Key-Value", + }, + RawContent: "test content", + }, + expected: map[string][]string{ + headers.AccessControlAllowOrigin: {"*"}, + headers.AccessControlAllowCredentials: {"true"}, + headers.ContentType: {textPlain}, + "X-Key": {"X-Key-Value"}, + headers.AccessControlAllowMethods: { + "GET, PUT, POST, HEAD, TRACE, DELETE, PATCH, COPY, HEAD, LINK, OPTIONS", + }, + }, + }, + { + name: "should override default headers", + response: mock.Response{ + Code: 200, + Headers: map[string]string{ + headers.AccessControlAllowOrigin: "localhost", + headers.AccessControlAllowCredentials: "false", + headers.ContentType: "none", + }, + RawContent: "test content", + }, + expected: map[string][]string{ + headers.AccessControlAllowOrigin: {"localhost"}, + headers.AccessControlAllowCredentials: {"false"}, + headers.ContentType: {"none"}, + headers.AccessControlAllowMethods: { + "GET, PUT, POST, HEAD, TRACE, DELETE, PATCH, COPY, HEAD, LINK, OPTIONS", + }, + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + handler := mock.NewMockHandler( + mock.WithResponse(testCase.response), + mock.WithLogger(logger), + ) + + request := httptest.NewRequest(http.MethodGet, "/", nil) + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, request) + + assert.EqualValues(t, testCase.expected, testutils.ReadHeader(t, recorder)) + assert.Equal(t, 200, recorder.Code) + }) + } + }) + + t.Run("status code", func(t *testing.T) { + tests := []struct { + name string + response mock.Response + expected int + }{ + { + name: "provide 201 code", + response: mock.Response{ + Code: 201, + }, + expected: 201, + }, + { + name: "provide 503 code", + response: mock.Response{ + Code: 503, + }, + expected: 503, + }, + { + name: "automatically provide 200 code", + response: mock.Response{}, + expected: 200, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + handler := mock.NewMockHandler( + mock.WithResponse(testCase.response), + mock.WithLogger(logger), + ) + + request := httptest.NewRequest(http.MethodGet, "/", nil) + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, testCase.expected, recorder.Code) + }) + } + }) +} diff --git a/internal/mock/options.go b/internal/mock/options.go index e1088114..5c959885 100644 --- a/internal/mock/options.go +++ b/internal/mock/options.go @@ -4,9 +4,9 @@ import "github.com/evg4b/uncors/internal/contracts" type HandlerOption = func(*Handler) -func WithMock(mock Mock) HandlerOption { +func WithResponse(response Response) HandlerOption { return func(handler *Handler) { - handler.mock = mock + handler.response = response } } diff --git a/internal/mock/routes.go b/internal/mock/routes.go index 95210556..05eb8a73 100644 --- a/internal/mock/routes.go +++ b/internal/mock/routes.go @@ -17,7 +17,7 @@ func MakeMockedRoutes(router *mux.Router, logger contracts.Logger, mocks []Mock) setQueries(route, mock.Queries) setHeaders(route, mock.Headers) - handler := NewMockHandler(WithMock(mock), WithLogger(logger)) + handler := NewMockHandler(WithResponse(mock.Response), WithLogger(logger)) route.Handler(handler) } else { defaultMocks = append(defaultMocks, mock) @@ -29,7 +29,7 @@ func MakeMockedRoutes(router *mux.Router, logger contracts.Logger, mocks []Mock) setPath(route, mock.Path) - handler := NewMockHandler(WithMock(mock), WithLogger(logger)) + handler := NewMockHandler(WithResponse(mock.Response), WithLogger(logger)) route.Handler(handler) } } diff --git a/internal/proxy/helpers.go b/internal/proxy/helpers.go index 8a04fc69..d7a7b7a1 100644 --- a/internal/proxy/helpers.go +++ b/internal/proxy/helpers.go @@ -3,6 +3,8 @@ package proxy import ( "net/http" "strings" + + "github.com/go-http-utils/headers" ) type modificationsMap = map[string]func(string) (string, error) @@ -11,8 +13,8 @@ func noop(s string) (string, error) { return s, nil } func copyHeaders(source, dest http.Header, modifications modificationsMap) error { for key, values := range source { - if !strings.EqualFold(key, "cookie") && !strings.EqualFold(key, "set-cookie") { - modificationFunc, ok := modifications[strings.ToLower(key)] + if !strings.EqualFold(key, headers.Cookie) && !strings.EqualFold(key, headers.SetCookie) { + modificationFunc, ok := modifications[headers.Normalize(key)] if !ok { modificationFunc = noop } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index b8d89acf..5a1b11e3 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -91,11 +91,7 @@ func copyCookiesToTarget(source *http.Request, replacer *urlreplacer.Replacer, t return nil } -func copyResponseData(header http.Header, resp http.ResponseWriter, targetResp *http.Response) error { - header.Set("Access-Control-Allow-Origin", "*") - header.Set("Access-Control-Allow-Credentials", "true") - header.Set("Access-Control-Allow-Methods", "GET, PUT, POST, HEAD, TRACE, DELETE, PATCH, COPY, HEAD, LINK, OPTIONS") - +func copyResponseData(resp http.ResponseWriter, targetResp *http.Response) error { resp.WriteHeader(targetResp.StatusCode) if _, err := io.Copy(resp, targetResp.Body); err != nil { diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index eeab43d4..5a36df87 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -13,6 +13,7 @@ import ( "github.com/evg4b/uncors/pkg/urlx" "github.com/evg4b/uncors/testing/mocks" "github.com/evg4b/uncors/testing/testutils" + "github.com/go-http-utils/headers" "github.com/stretchr/testify/assert" ) @@ -33,13 +34,13 @@ func TestProxyHandler(t *testing.T) { name: "transform Origin", URL: "http://premium.local.com/app", expectedURL: "https://premium.api.com/app", - headerKey: "Origin", + headerKey: headers.Origin, }, { name: "transform Referer", URL: "http://premium.local.com/info", expectedURL: "https://premium.api.com/info", - headerKey: "Referer", + headerKey: headers.Referer, }, } @@ -92,7 +93,7 @@ func TestProxyHandler(t *testing.T) { name: "transform Location", URL: "https://premium.api.com/app", expectedURL: "http://premium.local.com/app", - headerKey: "Location", + headerKey: headers.Location, }, } @@ -164,18 +165,18 @@ func TestProxyHandler(t *testing.T) { proc.ServeHTTP(recorder, req) - headers := recorder.Header() - - assert.Equal(t, "*", headers.Get("Access-Control-Allow-Origin")) - assert.Equal(t, "true", headers.Get("Access-Control-Allow-Credentials")) + header := recorder.Header() + assert.Equal(t, "*", header.Get(headers.AccessControlAllowOrigin)) + assert.Equal(t, "true", header.Get(headers.AccessControlAllowCredentials)) assert.Equal( t, "GET, PUT, POST, HEAD, TRACE, DELETE, PATCH, COPY, HEAD, LINK, OPTIONS", - headers.Get("Access-Control-Allow-Methods"), + header.Get(headers.AccessControlAllowMethods), ) }) t.Run("OPTIONS request handling", func(t *testing.T) { + t.Skip() handler := proxy.NewProxyHandler( proxy.WithLogger(mocks.NewNoopLogger(t)), ) @@ -194,21 +195,21 @@ func TestProxyHandler(t *testing.T) { { name: "should do not skip not access-control-request-* headers", headers: http.Header{ - "Host": {"www.host.com"}, - "Content-Type": {"application/json"}, - "Authorization": {"Bearer Token"}, + "Host": {"www.host.com"}, + headers.ContentType: {"application/json"}, + headers.Authorization: {"Bearer Token"}, }, expected: http.Header{}, }, { name: "should allow all access-control-request-* headers", headers: http.Header{ - "Access-Control-Request-Headers": {"X-PINGOTHER, Content-Type"}, - "Access-Control-Request-Method": {http.MethodPost, http.MethodDelete}, + headers.AccessControlRequestHeaders: {"X-PINGOTHER, Content-Type"}, + headers.AccessControlRequestMethod: {http.MethodPost, http.MethodDelete}, }, expected: http.Header{ - "Access-Control-Allow-Headers": {"X-PINGOTHER, Content-Type"}, - "Access-Control-Allow-Method": {http.MethodPost, http.MethodDelete}, + headers.AccessControlAllowHeaders: {"X-PINGOTHER, Content-Type"}, + headers.AccessControlAllowMethods: {http.MethodPost, http.MethodDelete}, }, }, } diff --git a/internal/proxy/request.go b/internal/proxy/request.go index 9e097fea..b50f4f57 100644 --- a/internal/proxy/request.go +++ b/internal/proxy/request.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/evg4b/uncors/internal/urlreplacer" + "github.com/go-http-utils/headers" ) func (handler *Handler) makeOriginalRequest( @@ -18,8 +19,8 @@ func (handler *Handler) makeOriginalRequest( } err = copyHeaders(req.Header, originalReq.Header, modificationsMap{ - "origin": replacer.Replace, - "referer": replacer.Replace, + headers.Origin: replacer.Replace, + headers.Referer: replacer.Replace, }) if err != nil { diff --git a/internal/proxy/responce.go b/internal/proxy/responce.go index 9136d451..29a49502 100644 --- a/internal/proxy/responce.go +++ b/internal/proxy/responce.go @@ -4,7 +4,9 @@ import ( "fmt" "net/http" + "github.com/evg4b/uncors/internal/infrastructure" "github.com/evg4b/uncors/internal/urlreplacer" + "github.com/go-http-utils/headers" ) func (handler *Handler) makeUncorsResponse( @@ -16,16 +18,16 @@ func (handler *Handler) makeUncorsResponse( return fmt.Errorf("failed to copy cookies in request: %w", err) } - header := resp.Header() - err := copyHeaders(originalResp.Header, header, modificationsMap{ - "location": replacer.Replace, + err := copyHeaders(originalResp.Header, resp.Header(), modificationsMap{ + headers.Location: replacer.Replace, }) - if err != nil { return err } - if err = copyResponseData(header, resp, originalResp); err != nil { + infrastructure.WriteCorsHeaders(resp.Header()) + + if err = copyResponseData(resp, originalResp); err != nil { return err } diff --git a/testing/mocks/logger_mock.go b/testing/mocks/logger_mock.go index 98ca184d..20ec3633 100644 --- a/testing/mocks/logger_mock.go +++ b/testing/mocks/logger_mock.go @@ -175,7 +175,7 @@ func (mmDebug *mLoggerMockDebug) Return() *LoggerMock { return mmDebug.mock } -//Set uses given function f to mock the Logger.Debug method +// Set uses given function f to mock the Logger.Debug method func (mmDebug *mLoggerMockDebug) Set(f func(a ...interface{})) *LoggerMock { if mmDebug.defaultExpectation != nil { mmDebug.mock.t.Fatalf("Default expectation is already set for the Logger.Debug method") @@ -363,7 +363,7 @@ func (mmDebugf *mLoggerMockDebugf) Return() *LoggerMock { return mmDebugf.mock } -//Set uses given function f to mock the Logger.Debugf method +// Set uses given function f to mock the Logger.Debugf method func (mmDebugf *mLoggerMockDebugf) Set(f func(template string, a ...interface{})) *LoggerMock { if mmDebugf.defaultExpectation != nil { mmDebugf.mock.t.Fatalf("Default expectation is already set for the Logger.Debugf method") @@ -550,7 +550,7 @@ func (mmError *mLoggerMockError) Return() *LoggerMock { return mmError.mock } -//Set uses given function f to mock the Logger.Error method +// Set uses given function f to mock the Logger.Error method func (mmError *mLoggerMockError) Set(f func(a ...interface{})) *LoggerMock { if mmError.defaultExpectation != nil { mmError.mock.t.Fatalf("Default expectation is already set for the Logger.Error method") @@ -738,7 +738,7 @@ func (mmErrorf *mLoggerMockErrorf) Return() *LoggerMock { return mmErrorf.mock } -//Set uses given function f to mock the Logger.Errorf method +// Set uses given function f to mock the Logger.Errorf method func (mmErrorf *mLoggerMockErrorf) Set(f func(template string, a ...interface{})) *LoggerMock { if mmErrorf.defaultExpectation != nil { mmErrorf.mock.t.Fatalf("Default expectation is already set for the Logger.Errorf method") @@ -925,7 +925,7 @@ func (mmInfo *mLoggerMockInfo) Return() *LoggerMock { return mmInfo.mock } -//Set uses given function f to mock the Logger.Info method +// Set uses given function f to mock the Logger.Info method func (mmInfo *mLoggerMockInfo) Set(f func(a ...interface{})) *LoggerMock { if mmInfo.defaultExpectation != nil { mmInfo.mock.t.Fatalf("Default expectation is already set for the Logger.Info method") @@ -1113,7 +1113,7 @@ func (mmInfof *mLoggerMockInfof) Return() *LoggerMock { return mmInfof.mock } -//Set uses given function f to mock the Logger.Infof method +// Set uses given function f to mock the Logger.Infof method func (mmInfof *mLoggerMockInfof) Set(f func(template string, a ...interface{})) *LoggerMock { if mmInfof.defaultExpectation != nil { mmInfof.mock.t.Fatalf("Default expectation is already set for the Logger.Infof method") @@ -1300,7 +1300,7 @@ func (mmPrintResponse *mLoggerMockPrintResponse) Return() *LoggerMock { return mmPrintResponse.mock } -//Set uses given function f to mock the Logger.PrintResponse method +// Set uses given function f to mock the Logger.PrintResponse method func (mmPrintResponse *mLoggerMockPrintResponse) Set(f func(response *http.Response)) *LoggerMock { if mmPrintResponse.defaultExpectation != nil { mmPrintResponse.mock.t.Fatalf("Default expectation is already set for the Logger.PrintResponse method") @@ -1487,7 +1487,7 @@ func (mmWarning *mLoggerMockWarning) Return() *LoggerMock { return mmWarning.mock } -//Set uses given function f to mock the Logger.Warning method +// Set uses given function f to mock the Logger.Warning method func (mmWarning *mLoggerMockWarning) Set(f func(a ...interface{})) *LoggerMock { if mmWarning.defaultExpectation != nil { mmWarning.mock.t.Fatalf("Default expectation is already set for the Logger.Warning method") @@ -1675,7 +1675,7 @@ func (mmWarningf *mLoggerMockWarningf) Return() *LoggerMock { return mmWarningf.mock } -//Set uses given function f to mock the Logger.Warningf method +// Set uses given function f to mock the Logger.Warningf method func (mmWarningf *mLoggerMockWarningf) Set(f func(template string, a ...interface{})) *LoggerMock { if mmWarningf.defaultExpectation != nil { mmWarningf.mock.t.Fatalf("Default expectation is already set for the Logger.Warningf method") diff --git a/testing/mocks/urlreplacer_factory_mock.go b/testing/mocks/urlreplacer_factory_mock.go index a6a3ea74..c9cdc873 100644 --- a/testing/mocks/urlreplacer_factory_mock.go +++ b/testing/mocks/urlreplacer_factory_mock.go @@ -111,7 +111,7 @@ func (mmMake *mURLReplacerFactoryMockMake) Return(rp1 *urlreplacer.Replacer, rp2 return mmMake.mock } -//Set uses given function f to mock the URLReplacerFactory.Make method +// Set uses given function f to mock the URLReplacerFactory.Make method func (mmMake *mURLReplacerFactoryMockMake) Set(f func(requestURL *url.URL) (rp1 *urlreplacer.Replacer, rp2 *urlreplacer.Replacer, err error)) *URLReplacerFactoryMock { if mmMake.defaultExpectation != nil { mmMake.mock.t.Fatalf("Default expectation is already set for the URLReplacerFactory.Make method") diff --git a/testing/testutils/http_body..go b/testing/testutils/http_body..go index 4ddc420d..e6393bba 100644 --- a/testing/testutils/http_body..go +++ b/testing/testutils/http_body..go @@ -2,6 +2,7 @@ package testutils import ( "io" + "net/http" "net/http/httptest" "testing" ) @@ -17,3 +18,12 @@ func ReadBody(t *testing.T, recorder *httptest.ResponseRecorder) string { return string(body) } + +func ReadHeader(t *testing.T, recorder *httptest.ResponseRecorder) http.Header { + t.Helper() + + response := recorder.Result() + defer CheckNoError(t, response.Body.Close()) + + return response.Header +}