Skip to content

Commit

Permalink
added custom handler type
Browse files Browse the repository at this point in the history
  • Loading branch information
evg4b committed Jun 18, 2023
1 parent 5841b92 commit e34f5e7
Show file tree
Hide file tree
Showing 18 changed files with 158 additions and 64 deletions.
31 changes: 31 additions & 0 deletions internal/contracts/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package contracts

import (
"errors"
"net/http"
)

type Request = http.Request

type Handler interface {
ServeHTTP(*ResponseWriter, *Request)
}

type HandlerFunc func(*ResponseWriter, *Request)

func (f HandlerFunc) ServeHTTP(w *ResponseWriter, r *Request) {
f(w, r)
}

var ErrResponceNotCasted = errors.New("received incorrect response writer type")

func CastToHTTPHandler(handler Handler) http.Handler {
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
writer, ok := response.(*ResponseWriter)
if !ok {
panic(ErrResponceNotCasted)
}

handler.ServeHTTP(writer, request)
})
}
42 changes: 42 additions & 0 deletions internal/contracts/http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package contracts_test

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/evg4b/uncors/internal/contracts"
"github.com/evg4b/uncors/internal/sfmt"
"github.com/evg4b/uncors/testing/testutils"
"github.com/stretchr/testify/assert"
)

func TestCastToHTTPHandler(t *testing.T) {
const expectedBody = `{ "OK": true }`
uncorsHandler := contracts.HandlerFunc(func(w *contracts.ResponseWriter, r *contracts.Request) {
w.WriteHeader(http.StatusOK)
sfmt.Fprint(w, expectedBody)
})

httpHandler := contracts.CastToHTTPHandler(uncorsHandler)

req := httptest.NewRequest(http.MethodGet, "/data", nil)

t.Run("cast correctly", func(t *testing.T) {
recirder := httptest.NewRecorder()
responceWriter := contracts.WrapResponseWriter(recirder)

assert.NotPanics(t, func() {
httpHandler.ServeHTTP(responceWriter, req)
assert.Equal(t, expectedBody, testutils.ReadBody(t, recirder))
})
})

t.Run("panit when request is not wrapped", func(t *testing.T) {
recirder := httptest.NewRecorder()

assert.PanicsWithValue(t, contracts.ErrResponceNotCasted, func() {
httpHandler.ServeHTTP(recirder, req)
})
})
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package static
package contracts

import (
"net/http"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package static_test
package contracts_test

import (
"net/http/httptest"
"testing"

"github.com/evg4b/uncors/internal/handler/static"
"github.com/evg4b/uncors/internal/contracts"
"github.com/evg4b/uncors/internal/sfmt"
"github.com/evg4b/uncors/testing/testutils"
"github.com/stretchr/testify/assert"
Expand All @@ -15,7 +15,7 @@ func TestResponseWriterWrapper(t *testing.T) {
const expectedCode = 201

recorder := httptest.NewRecorder()
writer := static.WrapResponseWriter(recorder)
writer := contracts.WrapResponseWriter(recorder)

writer.WriteHeader(expectedCode)
sfmt.Fprint(writer, expectedValye)
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion internal/handler/mock/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func NewMockMiddleware(options ...MiddlewareOption) *Middleware {
return middleware
}

func (m *Middleware) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
func (m *Middleware) ServeHTTP(writer *contracts.ResponseWriter, request *contracts.Request) {
response := m.response
header := writer.Header()

Expand Down
13 changes: 7 additions & 6 deletions internal/handler/mock/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/evg4b/uncors/internal/config"
"github.com/evg4b/uncors/internal/contracts"
"github.com/evg4b/uncors/internal/handler/mock"
"github.com/evg4b/uncors/testing/mocks"
"github.com/evg4b/uncors/testing/testconstants"
Expand Down Expand Up @@ -74,7 +75,7 @@ func TestHandler(t *testing.T) {

recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/", nil)
handler.ServeHTTP(recorder, request)
handler.ServeHTTP(contracts.WrapResponseWriter(recorder), request)

body := testutils.ReadBody(t, recorder)
assert.EqualValues(t, testCase.expected, body)
Expand Down Expand Up @@ -142,7 +143,7 @@ func TestHandler(t *testing.T) {

recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/", nil)
handler.ServeHTTP(recorder, request)
handler.ServeHTTP(contracts.WrapResponseWriter(recorder), request)

header := testutils.ReadHeader(t, recorder)
assert.EqualValues(t, testCase.expected, header.Get(headers.ContentType))
Expand Down Expand Up @@ -232,7 +233,7 @@ func TestHandler(t *testing.T) {
request := httptest.NewRequest(http.MethodGet, "/", nil)
recorder := httptest.NewRecorder()

handler.ServeHTTP(recorder, request)
handler.ServeHTTP(contracts.WrapResponseWriter(recorder), request)

assert.EqualValues(t, testCase.expected, testutils.ReadHeader(t, recorder))
assert.Equal(t, http.StatusOK, recorder.Code)
Expand Down Expand Up @@ -280,7 +281,7 @@ func TestHandler(t *testing.T) {
request := httptest.NewRequest(http.MethodGet, "/", nil)
recorder := httptest.NewRecorder()

handler.ServeHTTP(recorder, request)
handler.ServeHTTP(contracts.WrapResponseWriter(recorder), request)

assert.Equal(t, testCase.expected, recorder.Code)
})
Expand Down Expand Up @@ -355,7 +356,7 @@ func TestHandler(t *testing.T) {
request := httptest.NewRequest(http.MethodGet, "/", nil)
recorder := httptest.NewRecorder()

handler.ServeHTTP(recorder, request)
handler.ServeHTTP(contracts.WrapResponseWriter(recorder), request)

assert.Equal(t, called, testCase.shouldBeCalled)
})
Expand All @@ -382,7 +383,7 @@ func TestHandler(t *testing.T) {
waitGroup.Add(1)
go func() {
defer waitGroup.Done()
handler.ServeHTTP(recorder, request.WithContext(ctx))
handler.ServeHTTP(contracts.WrapResponseWriter(recorder), request.WithContext(ctx))
}()

cancel()
Expand Down
2 changes: 1 addition & 1 deletion internal/handler/proxy/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func NewProxyHandler(options ...HandlerOption) *Handler {
return middleware
}

func (m *Handler) ServeHTTP(response http.ResponseWriter, request *http.Request) {
func (m *Handler) ServeHTTP(response *contracts.ResponseWriter, request *contracts.Request) {
if err := m.handle(response, request); err != nil {
infra.HTTPError(response, err)
}
Expand Down
9 changes: 5 additions & 4 deletions internal/handler/proxy/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"

"github.com/evg4b/uncors/internal/config"
"github.com/evg4b/uncors/internal/contracts"
"github.com/evg4b/uncors/internal/handler/proxy"
"github.com/evg4b/uncors/internal/helpers"
"github.com/evg4b/uncors/internal/urlreplacer"
Expand Down Expand Up @@ -80,7 +81,7 @@ func TestProxyMiddleware(t *testing.T) {

req.Header.Add(testCase.headerKey, testCase.URL)

proc.ServeHTTP(httptest.NewRecorder(), req)
proc.ServeHTTP(contracts.WrapResponseWriter(httptest.NewRecorder()), req)
})
}
})
Expand Down Expand Up @@ -133,7 +134,7 @@ func TestProxyMiddleware(t *testing.T) {

recorder := httptest.NewRecorder()

proc.ServeHTTP(recorder, req)
proc.ServeHTTP(contracts.WrapResponseWriter(recorder), req)

assert.Equal(t, testCase.expectedURL, recorder.Header().Get(testCase.headerKey))
})
Expand Down Expand Up @@ -166,7 +167,7 @@ func TestProxyMiddleware(t *testing.T) {

recorder := httptest.NewRecorder()

proc.ServeHTTP(recorder, req)
proc.ServeHTTP(contracts.WrapResponseWriter(recorder), req)

header := recorder.Header()
assert.Equal(t, "*", header.Get(headers.AccessControlAllowOrigin))
Expand Down Expand Up @@ -240,7 +241,7 @@ func TestProxyMiddleware(t *testing.T) {
req, err := http.NewRequestWithContext(context.TODO(), http.MethodOptions, "/", nil)
testutils.CheckNoError(t, err)

middleware.ServeHTTP(recorder, req)
middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), req)

assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, testCase.expected, recorder.Header())
Expand Down
6 changes: 3 additions & 3 deletions internal/handler/static/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

type Middleware struct {
fs afero.Fs
next http.Handler
next contracts.Handler
index string
logger contracts.Logger
prefix string
Expand All @@ -30,8 +30,8 @@ func NewStaticMiddleware(options ...MiddlewareOption) *Middleware {
return middleware
}

func (m *Middleware) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
response := WrapResponseWriter(writer)
func (m *Middleware) ServeHTTP(writer *contracts.ResponseWriter, request *contracts.Request) {
response := contracts.WrapResponseWriter(writer)

filePath := m.extractFilePath(request)
file, stat, err := m.openFile(filePath)
Expand Down
4 changes: 1 addition & 3 deletions internal/handler/static/middleware_options.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package static

import (
"net/http"

"github.com/evg4b/uncors/internal/contracts"
"github.com/spf13/afero"
)
Expand All @@ -21,7 +19,7 @@ func WithIndex(index string) MiddlewareOption {
}
}

func WithNext(next http.Handler) MiddlewareOption {
func WithNext(next contracts.Handler) MiddlewareOption {
return func(m *Middleware) {
m.next = next
}
Expand Down
21 changes: 9 additions & 12 deletions internal/handler/static/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"
"testing"

"github.com/evg4b/uncors/internal/contracts"
"github.com/evg4b/uncors/internal/handler/static"
"github.com/evg4b/uncors/internal/sfmt"
"github.com/evg4b/uncors/testing/mocks"
Expand Down Expand Up @@ -101,7 +102,7 @@ func TestMiddleware(t *testing.T) {
middleware := static.NewStaticMiddleware(
static.WithFileSystem(fs),
static.WithLogger(loggerMock),
static.WithNext(http.HandlerFunc(func(writer http.ResponseWriter, _ *http.Request) {
static.WithNext(contracts.HandlerFunc(func(writer *contracts.ResponseWriter, _ *contracts.Request) {
writer.WriteHeader(testHTTPStatusCode)
sfmt.Fprint(writer, testHTTPBody)
})),
Expand All @@ -114,7 +115,7 @@ func TestMiddleware(t *testing.T) {
requestURI, err := url.Parse(testCase.path)
testutils.CheckNoError(t, err)

middleware.ServeHTTP(recorder, &http.Request{
middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), &http.Request{
Method: http.MethodGet,
URL: requestURI,
})
Expand All @@ -132,7 +133,7 @@ func TestMiddleware(t *testing.T) {
requestURI, err := url.Parse(testCase.path)
testutils.CheckNoError(t, err)

middleware.ServeHTTP(recorder, &http.Request{
middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), &http.Request{
Method: http.MethodGet,
URL: requestURI,
})
Expand All @@ -149,9 +150,7 @@ func TestMiddleware(t *testing.T) {
static.WithFileSystem(fs),
static.WithLogger(loggerMock),
static.WithIndex(indexHTML),
static.WithNext(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
panic("should not be called")
})),
static.WithNext(mocks.FailNowMock(t)),
)

t.Run("send index file", func(t *testing.T) {
Expand All @@ -161,7 +160,7 @@ func TestMiddleware(t *testing.T) {
requestURI, err := url.Parse(testCase.path)
testutils.CheckNoError(t, err)

middleware.ServeHTTP(recorder, &http.Request{
middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), &http.Request{
Method: http.MethodGet,
URL: requestURI,
})
Expand All @@ -179,7 +178,7 @@ func TestMiddleware(t *testing.T) {
requestURI, err := url.Parse(testCase.path)
testutils.CheckNoError(t, err)

middleware.ServeHTTP(recorder, &http.Request{
middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), &http.Request{
Method: http.MethodGet,
URL: requestURI,
})
Expand All @@ -195,16 +194,14 @@ func TestMiddleware(t *testing.T) {
static.WithFileSystem(fs),
static.WithLogger(loggerMock),
static.WithIndex("/not-exists.html"),
static.WithNext(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
panic("should not be called")
})),
static.WithNext(mocks.FailNowMock(t)),
)

recorder := httptest.NewRecorder()
requestURI, err := url.Parse("/options/")
testutils.CheckNoError(t, err)

middleware.ServeHTTP(recorder, &http.Request{
middleware.ServeHTTP(contracts.WrapResponseWriter(recorder), &http.Request{
Method: http.MethodGet,
URL: requestURI,
})
Expand Down
10 changes: 8 additions & 2 deletions internal/handler/static_routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@ import (
"strings"

"github.com/evg4b/uncors/internal/config"
"github.com/evg4b/uncors/internal/contracts"
"github.com/evg4b/uncors/internal/handler/static"
"github.com/evg4b/uncors/internal/ui"
"github.com/gorilla/mux"
"github.com/spf13/afero"
)

func (m *RequestHandler) makeStaticRoutes(router *mux.Router, statics config.StaticDirectories, next http.Handler) {
func (m *RequestHandler) makeStaticRoutes(
router *mux.Router,
statics config.StaticDirectories,
next contracts.Handler,
) {
for _, staticDir := range statics {
clearPath := strings.TrimSuffix(staticDir.Path, "/")
path := clearPath + "/"
Expand All @@ -29,6 +34,7 @@ func (m *RequestHandler) makeStaticRoutes(router *mux.Router, statics config.Sta
static.WithPrefix(path),
)

route.PathPrefix(path).Handler(handler)
httpHandler := contracts.CastToHTTPHandler(handler)
route.PathPrefix(path).Handler(httpHandler)
}
}
Loading

0 comments on commit e34f5e7

Please sign in to comment.