From e34f5e714cb93fe70e03e5d7d115a826f0dd938b Mon Sep 17 00:00:00 2001 From: Evgeny Abramovich Date: Sun, 18 Jun 2023 13:49:30 -0300 Subject: [PATCH] added custom handler type --- internal/contracts/http.go | 31 ++++++++++++++ internal/contracts/http_test.go | 42 +++++++++++++++++++ .../static => contracts}/response_writer.go | 2 +- .../response_writer_test.go | 6 +-- .../{handler.go => url_replacer_factory.go} | 0 internal/handler/mock/middleware.go | 2 +- internal/handler/mock/middleware_test.go | 13 +++--- internal/handler/proxy/middleware.go | 2 +- internal/handler/proxy/middleware_test.go | 9 ++-- internal/handler/static/middleware.go | 6 +-- internal/handler/static/middleware_options.go | 4 +- internal/handler/static/middleware_test.go | 21 ++++------ internal/handler/static_routes.go | 10 ++++- internal/handler/uncors_handler.go | 24 ++++++----- internal/handler/uncors_handler_test.go | 29 ++++++------- internal/server/server.go | 5 ++- internal/server/server_test.go | 3 +- testing/mocks/handler.go | 13 ++++++ 18 files changed, 158 insertions(+), 64 deletions(-) create mode 100644 internal/contracts/http.go create mode 100644 internal/contracts/http_test.go rename internal/{handler/static => contracts}/response_writer.go (94%) rename internal/{handler/static => contracts}/response_writer_test.go (83%) rename internal/contracts/{handler.go => url_replacer_factory.go} (100%) create mode 100644 testing/mocks/handler.go diff --git a/internal/contracts/http.go b/internal/contracts/http.go new file mode 100644 index 00000000..c4019ab0 --- /dev/null +++ b/internal/contracts/http.go @@ -0,0 +1,31 @@ +package contracts + +import ( + "errors" + "net/http" +) + +type Request = http.Request + +type Handler interface { + ServeHTTP(*ResponseWriter, *Request) +} + +type HandlerFunc func(*ResponseWriter, *Request) + +func (f HandlerFunc) ServeHTTP(w *ResponseWriter, r *Request) { + f(w, r) +} + +var ErrResponceNotCasted = errors.New("received incorrect response writer type") + +func CastToHTTPHandler(handler Handler) http.Handler { + return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + writer, ok := response.(*ResponseWriter) + if !ok { + panic(ErrResponceNotCasted) + } + + handler.ServeHTTP(writer, request) + }) +} diff --git a/internal/contracts/http_test.go b/internal/contracts/http_test.go new file mode 100644 index 00000000..3efa9c55 --- /dev/null +++ b/internal/contracts/http_test.go @@ -0,0 +1,42 @@ +package contracts_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/evg4b/uncors/internal/contracts" + "github.com/evg4b/uncors/internal/sfmt" + "github.com/evg4b/uncors/testing/testutils" + "github.com/stretchr/testify/assert" +) + +func TestCastToHTTPHandler(t *testing.T) { + const expectedBody = `{ "OK": true }` + uncorsHandler := contracts.HandlerFunc(func(w *contracts.ResponseWriter, r *contracts.Request) { + w.WriteHeader(http.StatusOK) + sfmt.Fprint(w, expectedBody) + }) + + httpHandler := contracts.CastToHTTPHandler(uncorsHandler) + + req := httptest.NewRequest(http.MethodGet, "/data", nil) + + t.Run("cast correctly", func(t *testing.T) { + recirder := httptest.NewRecorder() + responceWriter := contracts.WrapResponseWriter(recirder) + + assert.NotPanics(t, func() { + httpHandler.ServeHTTP(responceWriter, req) + assert.Equal(t, expectedBody, testutils.ReadBody(t, recirder)) + }) + }) + + t.Run("panit when request is not wrapped", func(t *testing.T) { + recirder := httptest.NewRecorder() + + assert.PanicsWithValue(t, contracts.ErrResponceNotCasted, func() { + httpHandler.ServeHTTP(recirder, req) + }) + }) +} diff --git a/internal/handler/static/response_writer.go b/internal/contracts/response_writer.go similarity index 94% rename from internal/handler/static/response_writer.go rename to internal/contracts/response_writer.go index bd2ae83b..34258bc6 100644 --- a/internal/handler/static/response_writer.go +++ b/internal/contracts/response_writer.go @@ -1,4 +1,4 @@ -package static +package contracts import ( "net/http" diff --git a/internal/handler/static/response_writer_test.go b/internal/contracts/response_writer_test.go similarity index 83% rename from internal/handler/static/response_writer_test.go rename to internal/contracts/response_writer_test.go index e7bd1f9b..b6c99fff 100644 --- a/internal/handler/static/response_writer_test.go +++ b/internal/contracts/response_writer_test.go @@ -1,10 +1,10 @@ -package static_test +package contracts_test import ( "net/http/httptest" "testing" - "github.com/evg4b/uncors/internal/handler/static" + "github.com/evg4b/uncors/internal/contracts" "github.com/evg4b/uncors/internal/sfmt" "github.com/evg4b/uncors/testing/testutils" "github.com/stretchr/testify/assert" @@ -15,7 +15,7 @@ func TestResponseWriterWrapper(t *testing.T) { const expectedCode = 201 recorder := httptest.NewRecorder() - writer := static.WrapResponseWriter(recorder) + writer := contracts.WrapResponseWriter(recorder) writer.WriteHeader(expectedCode) sfmt.Fprint(writer, expectedValye) diff --git a/internal/contracts/handler.go b/internal/contracts/url_replacer_factory.go similarity index 100% rename from internal/contracts/handler.go rename to internal/contracts/url_replacer_factory.go diff --git a/internal/handler/mock/middleware.go b/internal/handler/mock/middleware.go index fdd909b4..56eac0bc 100644 --- a/internal/handler/mock/middleware.go +++ b/internal/handler/mock/middleware.go @@ -27,7 +27,7 @@ func NewMockMiddleware(options ...MiddlewareOption) *Middleware { return middleware } -func (m *Middleware) ServeHTTP(writer http.ResponseWriter, request *http.Request) { +func (m *Middleware) ServeHTTP(writer *contracts.ResponseWriter, request *contracts.Request) { response := m.response header := writer.Header() diff --git a/internal/handler/mock/middleware_test.go b/internal/handler/mock/middleware_test.go index 59a1c433..17a7b63d 100644 --- a/internal/handler/mock/middleware_test.go +++ b/internal/handler/mock/middleware_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/evg4b/uncors/internal/config" + "github.com/evg4b/uncors/internal/contracts" "github.com/evg4b/uncors/internal/handler/mock" "github.com/evg4b/uncors/testing/mocks" "github.com/evg4b/uncors/testing/testconstants" @@ -74,7 +75,7 @@ func TestHandler(t *testing.T) { recorder := httptest.NewRecorder() request := httptest.NewRequest(http.MethodGet, "/", nil) - handler.ServeHTTP(recorder, request) + handler.ServeHTTP(contracts.WrapResponseWriter(recorder), request) body := testutils.ReadBody(t, recorder) assert.EqualValues(t, testCase.expected, body) @@ -142,7 +143,7 @@ func TestHandler(t *testing.T) { recorder := httptest.NewRecorder() request := httptest.NewRequest(http.MethodGet, "/", nil) - handler.ServeHTTP(recorder, request) + handler.ServeHTTP(contracts.WrapResponseWriter(recorder), request) header := testutils.ReadHeader(t, recorder) assert.EqualValues(t, testCase.expected, header.Get(headers.ContentType)) @@ -232,7 +233,7 @@ func TestHandler(t *testing.T) { request := httptest.NewRequest(http.MethodGet, "/", nil) recorder := httptest.NewRecorder() - handler.ServeHTTP(recorder, request) + handler.ServeHTTP(contracts.WrapResponseWriter(recorder), request) assert.EqualValues(t, testCase.expected, testutils.ReadHeader(t, recorder)) assert.Equal(t, http.StatusOK, recorder.Code) @@ -280,7 +281,7 @@ func TestHandler(t *testing.T) { request := httptest.NewRequest(http.MethodGet, "/", nil) recorder := httptest.NewRecorder() - handler.ServeHTTP(recorder, request) + handler.ServeHTTP(contracts.WrapResponseWriter(recorder), request) assert.Equal(t, testCase.expected, recorder.Code) }) @@ -355,7 +356,7 @@ func TestHandler(t *testing.T) { request := httptest.NewRequest(http.MethodGet, "/", nil) recorder := httptest.NewRecorder() - handler.ServeHTTP(recorder, request) + handler.ServeHTTP(contracts.WrapResponseWriter(recorder), request) assert.Equal(t, called, testCase.shouldBeCalled) }) @@ -382,7 +383,7 @@ func TestHandler(t *testing.T) { waitGroup.Add(1) go func() { defer waitGroup.Done() - handler.ServeHTTP(recorder, request.WithContext(ctx)) + handler.ServeHTTP(contracts.WrapResponseWriter(recorder), request.WithContext(ctx)) }() cancel() diff --git a/internal/handler/proxy/middleware.go b/internal/handler/proxy/middleware.go index 8211526c..f4afdcd6 100644 --- a/internal/handler/proxy/middleware.go +++ b/internal/handler/proxy/middleware.go @@ -31,7 +31,7 @@ func NewProxyHandler(options ...HandlerOption) *Handler { return middleware } -func (m *Handler) ServeHTTP(response http.ResponseWriter, request *http.Request) { +func (m *Handler) ServeHTTP(response *contracts.ResponseWriter, request *contracts.Request) { if err := m.handle(response, request); err != nil { infra.HTTPError(response, err) } diff --git a/internal/handler/proxy/middleware_test.go b/internal/handler/proxy/middleware_test.go index 077c5628..def0bf7c 100644 --- a/internal/handler/proxy/middleware_test.go +++ b/internal/handler/proxy/middleware_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/evg4b/uncors/internal/config" + "github.com/evg4b/uncors/internal/contracts" "github.com/evg4b/uncors/internal/handler/proxy" "github.com/evg4b/uncors/internal/helpers" "github.com/evg4b/uncors/internal/urlreplacer" @@ -80,7 +81,7 @@ func TestProxyMiddleware(t *testing.T) { req.Header.Add(testCase.headerKey, testCase.URL) - proc.ServeHTTP(httptest.NewRecorder(), req) + proc.ServeHTTP(contracts.WrapResponseWriter(httptest.NewRecorder()), req) }) } }) @@ -133,7 +134,7 @@ func TestProxyMiddleware(t *testing.T) { recorder := httptest.NewRecorder() - proc.ServeHTTP(recorder, req) + proc.ServeHTTP(contracts.WrapResponseWriter(recorder), req) assert.Equal(t, testCase.expectedURL, recorder.Header().Get(testCase.headerKey)) }) @@ -166,7 +167,7 @@ func TestProxyMiddleware(t *testing.T) { recorder := httptest.NewRecorder() - proc.ServeHTTP(recorder, req) + proc.ServeHTTP(contracts.WrapResponseWriter(recorder), req) header := recorder.Header() assert.Equal(t, "*", header.Get(headers.AccessControlAllowOrigin)) @@ -240,7 +241,7 @@ func TestProxyMiddleware(t *testing.T) { req, err := http.NewRequestWithContext(context.TODO(), http.MethodOptions, "/", nil) testutils.CheckNoError(t, err) - middleware.ServeHTTP(recorder, req) + middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), req) assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, testCase.expected, recorder.Header()) diff --git a/internal/handler/static/middleware.go b/internal/handler/static/middleware.go index c01a302b..b7fcfdf6 100644 --- a/internal/handler/static/middleware.go +++ b/internal/handler/static/middleware.go @@ -14,7 +14,7 @@ import ( type Middleware struct { fs afero.Fs - next http.Handler + next contracts.Handler index string logger contracts.Logger prefix string @@ -30,8 +30,8 @@ func NewStaticMiddleware(options ...MiddlewareOption) *Middleware { return middleware } -func (m *Middleware) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - response := WrapResponseWriter(writer) +func (m *Middleware) ServeHTTP(writer *contracts.ResponseWriter, request *contracts.Request) { + response := contracts.WrapResponseWriter(writer) filePath := m.extractFilePath(request) file, stat, err := m.openFile(filePath) diff --git a/internal/handler/static/middleware_options.go b/internal/handler/static/middleware_options.go index ac9e4083..99914b27 100644 --- a/internal/handler/static/middleware_options.go +++ b/internal/handler/static/middleware_options.go @@ -1,8 +1,6 @@ package static import ( - "net/http" - "github.com/evg4b/uncors/internal/contracts" "github.com/spf13/afero" ) @@ -21,7 +19,7 @@ func WithIndex(index string) MiddlewareOption { } } -func WithNext(next http.Handler) MiddlewareOption { +func WithNext(next contracts.Handler) MiddlewareOption { return func(m *Middleware) { m.next = next } diff --git a/internal/handler/static/middleware_test.go b/internal/handler/static/middleware_test.go index 2249d983..daa551dc 100644 --- a/internal/handler/static/middleware_test.go +++ b/internal/handler/static/middleware_test.go @@ -7,6 +7,7 @@ import ( "strings" "testing" + "github.com/evg4b/uncors/internal/contracts" "github.com/evg4b/uncors/internal/handler/static" "github.com/evg4b/uncors/internal/sfmt" "github.com/evg4b/uncors/testing/mocks" @@ -101,7 +102,7 @@ func TestMiddleware(t *testing.T) { middleware := static.NewStaticMiddleware( static.WithFileSystem(fs), static.WithLogger(loggerMock), - static.WithNext(http.HandlerFunc(func(writer http.ResponseWriter, _ *http.Request) { + static.WithNext(contracts.HandlerFunc(func(writer *contracts.ResponseWriter, _ *contracts.Request) { writer.WriteHeader(testHTTPStatusCode) sfmt.Fprint(writer, testHTTPBody) })), @@ -114,7 +115,7 @@ func TestMiddleware(t *testing.T) { requestURI, err := url.Parse(testCase.path) testutils.CheckNoError(t, err) - middleware.ServeHTTP(recorder, &http.Request{ + middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), &http.Request{ Method: http.MethodGet, URL: requestURI, }) @@ -132,7 +133,7 @@ func TestMiddleware(t *testing.T) { requestURI, err := url.Parse(testCase.path) testutils.CheckNoError(t, err) - middleware.ServeHTTP(recorder, &http.Request{ + middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), &http.Request{ Method: http.MethodGet, URL: requestURI, }) @@ -149,9 +150,7 @@ func TestMiddleware(t *testing.T) { static.WithFileSystem(fs), static.WithLogger(loggerMock), static.WithIndex(indexHTML), - static.WithNext(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - panic("should not be called") - })), + static.WithNext(mocks.FailNowMock(t)), ) t.Run("send index file", func(t *testing.T) { @@ -161,7 +160,7 @@ func TestMiddleware(t *testing.T) { requestURI, err := url.Parse(testCase.path) testutils.CheckNoError(t, err) - middleware.ServeHTTP(recorder, &http.Request{ + middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), &http.Request{ Method: http.MethodGet, URL: requestURI, }) @@ -179,7 +178,7 @@ func TestMiddleware(t *testing.T) { requestURI, err := url.Parse(testCase.path) testutils.CheckNoError(t, err) - middleware.ServeHTTP(recorder, &http.Request{ + middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), &http.Request{ Method: http.MethodGet, URL: requestURI, }) @@ -195,16 +194,14 @@ func TestMiddleware(t *testing.T) { static.WithFileSystem(fs), static.WithLogger(loggerMock), static.WithIndex("/not-exists.html"), - static.WithNext(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - panic("should not be called") - })), + static.WithNext(mocks.FailNowMock(t)), ) recorder := httptest.NewRecorder() requestURI, err := url.Parse("/options/") testutils.CheckNoError(t, err) - middleware.ServeHTTP(recorder, &http.Request{ + middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), &http.Request{ Method: http.MethodGet, URL: requestURI, }) diff --git a/internal/handler/static_routes.go b/internal/handler/static_routes.go index 90837076..e2b98957 100644 --- a/internal/handler/static_routes.go +++ b/internal/handler/static_routes.go @@ -5,13 +5,18 @@ import ( "strings" "github.com/evg4b/uncors/internal/config" + "github.com/evg4b/uncors/internal/contracts" "github.com/evg4b/uncors/internal/handler/static" "github.com/evg4b/uncors/internal/ui" "github.com/gorilla/mux" "github.com/spf13/afero" ) -func (m *RequestHandler) makeStaticRoutes(router *mux.Router, statics config.StaticDirectories, next http.Handler) { +func (m *RequestHandler) makeStaticRoutes( + router *mux.Router, + statics config.StaticDirectories, + next contracts.Handler, +) { for _, staticDir := range statics { clearPath := strings.TrimSuffix(staticDir.Path, "/") path := clearPath + "/" @@ -29,6 +34,7 @@ func (m *RequestHandler) makeStaticRoutes(router *mux.Router, statics config.Sta static.WithPrefix(path), ) - route.PathPrefix(path).Handler(handler) + httpHandler := contracts.CastToHTTPHandler(handler) + route.PathPrefix(path).Handler(httpHandler) } } diff --git a/internal/handler/uncors_handler.go b/internal/handler/uncors_handler.go index ac8b5225..699b0d59 100644 --- a/internal/handler/uncors_handler.go +++ b/internal/handler/uncors_handler.go @@ -63,29 +63,31 @@ func NewUncorsRequestHandler(options ...UncorsRequestHandlerOption) *RequestHand setDefaultHandler(router, proxyHandler) } - setDefaultHandler(handler.router, http.HandlerFunc(func(writer http.ResponseWriter, _ *http.Request) { + setDefaultHandler(handler.router, contracts.HandlerFunc(func(writer *contracts.ResponseWriter, _ *http.Request) { infra.HTTPError(writer, errHostNotMapped) })) return handler } -func (m *RequestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { +func (m *RequestHandler) ServeHTTP(writer *contracts.ResponseWriter, request *contracts.Request) { m.router.ServeHTTP(writer, request) } -func (m *RequestHandler) createHandler(response config.Response) *mock.Middleware { - return mock.NewMockMiddleware( - mock.WithLogger(ui.MockLogger), - mock.WithResponse(response), - mock.WithFileSystem(m.fs), - mock.WithAfter(time.After), +func (m *RequestHandler) createHandler(response config.Response) http.Handler { + return contracts.CastToHTTPHandler( + mock.NewMockMiddleware( + mock.WithLogger(ui.MockLogger), + mock.WithResponse(response), + mock.WithFileSystem(m.fs), + mock.WithAfter(time.After), + ), ) } -func setDefaultHandler(router *mux.Router, handler http.Handler) { - router.NotFoundHandler = handler - router.MethodNotAllowedHandler = handler +func setDefaultHandler(router *mux.Router, handler contracts.Handler) { + router.NotFoundHandler = contracts.CastToHTTPHandler(handler) + router.MethodNotAllowedHandler = contracts.CastToHTTPHandler(handler) } const wildcard = "*" diff --git a/internal/handler/uncors_handler_test.go b/internal/handler/uncors_handler_test.go index 96a9da70..60d692d3 100644 --- a/internal/handler/uncors_handler_test.go +++ b/internal/handler/uncors_handler_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/evg4b/uncors/internal/config" + "github.com/evg4b/uncors/internal/contracts" "github.com/evg4b/uncors/internal/handler" "github.com/evg4b/uncors/internal/helpers" "github.com/evg4b/uncors/internal/log" @@ -153,7 +154,7 @@ func TestUncorsRequestHandler(t *testing.T) { request := httptest.NewRequest(http.MethodGet, testCase.url, nil) helpers.NormaliseRequest(request) - hand.ServeHTTP(recorder, request) + hand.ServeHTTP(contracts.WrapResponseWriter(recorder), request) assert.Equal(t, 200, recorder.Code) assert.Equal(t, testCase.expected, testutils.ReadBody(t, recorder)) @@ -166,7 +167,7 @@ func TestUncorsRequestHandler(t *testing.T) { request := httptest.NewRequest(http.MethodGet, "http://localhost/cc/unknown.html", nil) helpers.NormaliseRequest(request) - hand.ServeHTTP(recorder, request) + hand.ServeHTTP(contracts.WrapResponseWriter(recorder), request) assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, indexHTML, testutils.ReadBody(t, recorder)) @@ -177,7 +178,7 @@ func TestUncorsRequestHandler(t *testing.T) { request := httptest.NewRequest(http.MethodGet, "http://localhost/pnp/unknown.html", nil) helpers.NormaliseRequest(request) - hand.ServeHTTP(recorder, request) + hand.ServeHTTP(contracts.WrapResponseWriter(recorder), request) assert.Equal(t, http.StatusInternalServerError, recorder.Code) expectedMessage := "filed to opend index file: open /assets/index.php: file does not exist" @@ -209,7 +210,7 @@ func TestUncorsRequestHandler(t *testing.T) { request := httptest.NewRequest(http.MethodGet, testCase.url, nil) helpers.NormaliseRequest(request) - hand.ServeHTTP(recorder, request) + hand.ServeHTTP(contracts.WrapResponseWriter(recorder), request) assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, testCase.expected, testutils.ReadBody(t, recorder)) @@ -222,7 +223,7 @@ func TestUncorsRequestHandler(t *testing.T) { request := httptest.NewRequest(http.MethodGet, "http://localhost/img/original.png", nil) helpers.NormaliseRequest(request) - hand.ServeHTTP(recorder, request) + hand.ServeHTTP(contracts.WrapResponseWriter(recorder), request) assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, "original.png", testutils.ReadBody(t, recorder)) @@ -263,7 +264,7 @@ func TestUncorsRequestHandler(t *testing.T) { request := httptest.NewRequest(http.MethodGet, testCase.url, nil) helpers.NormaliseRequest(request) - hand.ServeHTTP(recorder, request) + hand.ServeHTTP(contracts.WrapResponseWriter(recorder), request) assert.Equal(t, testCase.expectedCode, recorder.Code) assert.Equal(t, testCase.expected, testutils.ReadBody(t, recorder)) @@ -276,7 +277,7 @@ func TestUncorsRequestHandler(t *testing.T) { request := httptest.NewRequest(http.MethodGet, "http://localhost/api/mocks/4", nil) helpers.NormaliseRequest(request) - hand.ServeHTTP(recorder, request) + hand.ServeHTTP(contracts.WrapResponseWriter(recorder), request) assert.Equal(t, http.StatusInternalServerError, recorder.Code) expectedMessage := "filed to opent file /unknown.json: open /unknown.json: file does not exist" @@ -327,7 +328,7 @@ func TestMockMiddleware(t *testing.T) { request := httptest.NewRequest(method, api, nil) recorder := httptest.NewRecorder() - middleware.ServeHTTP(recorder, request) + middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), request) body := testutils.ReadBody(t, recorder) assert.Equal(t, http.StatusOK, recorder.Code) @@ -381,7 +382,7 @@ func TestMockMiddleware(t *testing.T) { request := httptest.NewRequest(method, api, nil) recorder := httptest.NewRecorder() - middleware.ServeHTTP(recorder, request) + middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), request) assert.Equal(t, expectedCode, recorder.Code) assert.Equal(t, expectedBody, testutils.ReadBody(t, recorder)) @@ -392,7 +393,7 @@ func TestMockMiddleware(t *testing.T) { request := httptest.NewRequest(http.MethodOptions, api, nil) recorder := httptest.NewRecorder() - middleware.ServeHTTP(recorder, request) + middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), request) assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, "", testutils.ReadBody(t, recorder)) @@ -403,7 +404,7 @@ func TestMockMiddleware(t *testing.T) { request := httptest.NewRequest(http.MethodPut, api, nil) recorder := httptest.NewRecorder() - middleware.ServeHTTP(recorder, request) + middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), request) body := testutils.ReadBody(t, recorder) assert.Equal(t, http.StatusOK, recorder.Code) @@ -512,7 +513,7 @@ func TestMockMiddleware(t *testing.T) { request := httptest.NewRequest(http.MethodGet, testCase.url, nil) recorder := httptest.NewRecorder() - middleware.ServeHTTP(recorder, request) + middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), request) body := testutils.ReadBody(t, recorder) assert.Equal(t, testCase.statusCode, recorder.Code) @@ -608,7 +609,7 @@ func TestMockMiddleware(t *testing.T) { request := httptest.NewRequest(http.MethodGet, testCase.url, nil) recorder := httptest.NewRecorder() - middleware.ServeHTTP(recorder, request) + middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), request) body := testutils.ReadBody(t, recorder) assert.Equal(t, testCase.statusCode, recorder.Code) @@ -726,7 +727,7 @@ func TestMockMiddleware(t *testing.T) { } recorder := httptest.NewRecorder() - middleware.ServeHTTP(recorder, request) + middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), request) body := testutils.ReadBody(t, recorder) assert.Equal(t, testCase.statusCode, recorder.Code) diff --git a/internal/server/server.go b/internal/server/server.go index 5000091a..73f9b037 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -7,6 +7,7 @@ import ( "net/http" "time" + "github.com/evg4b/uncors/internal/contracts" "github.com/evg4b/uncors/internal/helpers" "github.com/evg4b/uncors/internal/log" "golang.org/x/net/context" @@ -22,7 +23,7 @@ const ( shutdownTimeout = 15 * time.Second ) -func NewUncorsServer(ctx context.Context, handler http.Handler) *UncorsServer { +func NewUncorsServer(ctx context.Context, handler contracts.Handler) *UncorsServer { globalCtx, globalCtxCancel := context.WithCancel(ctx) server := &http.Server{ BaseContext: func(listener net.Listener) context.Context { @@ -31,7 +32,7 @@ func NewUncorsServer(ctx context.Context, handler http.Handler) *UncorsServer { ReadHeaderTimeout: readHeaderTimeout, Handler: http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { helpers.NormaliseRequest(request) - handler.ServeHTTP(writer, request) + handler.ServeHTTP(contracts.WrapResponseWriter(writer), request) }), ErrorLog: log.StandardErrorLogAdapter(), } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index befe0808..f8a58728 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/evg4b/uncors/internal/contracts" "github.com/evg4b/uncors/internal/helpers" "github.com/evg4b/uncors/internal/server" "github.com/evg4b/uncors/internal/sfmt" @@ -19,7 +20,7 @@ func TestNewUncorsServer(t *testing.T) { ctx := context.Background() expectedResponse := "UNCORS OK!" - var handler http.HandlerFunc = func(w http.ResponseWriter, _r *http.Request) { + var handler contracts.HandlerFunc = func(w *contracts.ResponseWriter, _r *contracts.Request) { w.WriteHeader(http.StatusOK) sfmt.Fprint(w, expectedResponse) } diff --git a/testing/mocks/handler.go b/testing/mocks/handler.go new file mode 100644 index 00000000..66dd37cb --- /dev/null +++ b/testing/mocks/handler.go @@ -0,0 +1,13 @@ +package mocks + +import ( + "testing" + + "github.com/evg4b/uncors/internal/contracts" +) + +func FailNowMock(t *testing.T) contracts.Handler { + return contracts.HandlerFunc(func(_ *contracts.ResponseWriter, _ *contracts.Request) { + t.Fatal("should not be called") + }) +}