Skip to content

Commit

Permalink
Add methods for NotFound and MethodNotAllowed, add response.EmbeddedS…
Browse files Browse the repository at this point in the history
…etter (#185)
  • Loading branch information
vearutop authored Jan 1, 2024
1 parent d64d866 commit 2ccfa52
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 4 deletions.
9 changes: 9 additions & 0 deletions chirouter/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,15 @@ func (r *Wrapper) Trace(pattern string, handlerFn http.HandlerFunc) {
r.Method(http.MethodTrace, pattern, handlerFn)
}

// HandlerFunc prepares handler and returns its function.
//
// Can be used as input for NotFound, MethodNotAllowed.
func (r *Wrapper) HandlerFunc(h http.Handler) http.HandlerFunc {
h = nethttp.WrapHandler(h, r.wraps...)

return h.ServeHTTP
}

func (r *Wrapper) resolvePattern(pattern string) string {
return r.basePattern + strings.ReplaceAll(pattern, "/*/", "/")
}
Expand Down
7 changes: 5 additions & 2 deletions chirouter/wrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ func (h HandlerWithBar) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
}

func TestNewWrapper(t *testing.T) {
r := chirouter.NewWrapper(chi.NewRouter()).With(func(handler http.Handler) http.Handler {
w := chirouter.NewWrapper(chi.NewRouter())
r := w.With(func(handler http.Handler) http.Handler {
return http.HandlerFunc(handler.ServeHTTP)
})

Expand Down Expand Up @@ -79,6 +80,8 @@ func TestNewWrapper(t *testing.T) {

r.Use(mw)

r.NotFound(r.(*chirouter.Wrapper).HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})))

r.Group(func(r chi.Router) {
r.Method(http.MethodPost,
"/baz/{id}/",
Expand Down Expand Up @@ -125,7 +128,7 @@ func TestNewWrapper(t *testing.T) {
}

assert.Equal(t, 14, handlersCnt)
assert.Equal(t, 21, totalCnt)
assert.Equal(t, 22, totalCnt)
}

func TestWrapper_Use_precedence(t *testing.T) {
Expand Down
7 changes: 5 additions & 2 deletions request/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ func DecoderMiddleware(factory DecoderMaker) func(http.Handler) http.Handler {
)

if !nethttp.HandlerAs(handler, &setRequestDecoder) ||
!nethttp.HandlerAs(handler, &withRoute) ||
!nethttp.HandlerAs(handler, &withUseCase) ||
!usecase.As(withUseCase.UseCase(), &useCaseWithInput) {
return handler
Expand All @@ -45,7 +44,11 @@ func DecoderMiddleware(factory DecoderMaker) func(http.Handler) http.Handler {

input := useCaseWithInput.InputPort()
if input != nil {
method := withRoute.RouteMethod()
method := http.MethodPost // Default for handlers without method (for example NotFound handler).
if nethttp.HandlerAs(handler, &withRoute) {
method = withRoute.RouteMethod()
}

dec := factory.MakeDecoder(method, input, customMapping)
setRequestDecoder.SetRequestDecoder(dec)
}
Expand Down
35 changes: 35 additions & 0 deletions response/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ import (
"github.com/swaggest/usecase/status"
)

type (
// Setter captures original http.ResponseWriter.
//
// Implement this interface on a pointer to your output structure to get access to http.ResponseWriter.
Setter interface {
SetResponseWriter(rw http.ResponseWriter)
}
)

// Encoder prepares and writes http response.
type Encoder struct {
JSONWriter func(v interface{})
Expand All @@ -30,6 +39,7 @@ type Encoder struct {
unwrapInterface bool

dynamicWithHeadersSetup bool
dynamicSetter bool
dynamicETagged bool
dynamicNoContent bool
}
Expand Down Expand Up @@ -187,6 +197,10 @@ func (h *Encoder) SetupOutput(output interface{}, ht *rest.HandlerTrait) {
h.dynamicWithHeadersSetup = true
}

if _, ok := output.(Setter); ok || h.unwrapInterface {
h.dynamicSetter = true
}

if _, ok := output.(rest.ETagged); ok || h.unwrapInterface {
h.dynamicETagged = true
}
Expand Down Expand Up @@ -504,6 +518,12 @@ func (h *Encoder) MakeOutput(w http.ResponseWriter, ht rest.HandlerTrait) interf
}
}

if h.dynamicSetter {
if setter, ok := output.(Setter); ok {
setter.SetResponseWriter(w)
}
}

return output
}

Expand Down Expand Up @@ -550,3 +570,18 @@ func (w *writerWithHeaders) Write(data []byte) (int, error) {

return w.ResponseWriter.Write(data)
}

// EmbeddedSetter can capture http.ResponseWriter in your output structure.
type EmbeddedSetter struct {
rw http.ResponseWriter
}

// SetResponseWriter implements Setter.
func (e *EmbeddedSetter) SetResponseWriter(rw http.ResponseWriter) {
e.rw = rw
}

// ResponseWriter is an accessor.
func (e *EmbeddedSetter) ResponseWriter() http.ResponseWriter {
return e.rw
}
64 changes: 64 additions & 0 deletions response/encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,67 @@ func TestEncoder_Writer_httpStatus(t *testing.T) {
e.WriteSuccessfulResponse(w, r, output, rest.HandlerTrait{})
assert.Equal(t, http.StatusCreated, w.Code)
}

func TestEmbeddedSetter_SetResponseWriter(t *testing.T) {
e := response.Encoder{}

type EmbeddedHeader struct {
Foo int `header:"X-Foo" json:"-"`
Bar string `cookie:"bar" json:"-"`
}

type outputPort struct {
response.EmbeddedSetter
EmbeddedHeader
Name string `header:"X-Name" json:"-"`
Items []string `json:"items"`
Cookie int `cookie:"coo,httponly,path=/foo" json:"-"`
Cookie2 bool `cookie:"coo2,httponly,secure,samesite=lax,path=/foo,max-age=86400" json:"-"`
}

ht := rest.HandlerTrait{
SuccessContentType: "application/x-vnd-json",
}

validator := jsonschema.Validator{}
require.NoError(t, validator.AddSchema(
rest.ParamInHeader,
"X-Name",
[]byte(`{"type":"string","minLength":3}`),
false),
)

ht.RespValidator = &validator

e.SetupOutput(outputPort{}, &ht)

r, err := http.NewRequest(http.MethodGet, "/", nil)
require.NoError(t, err)

w := httptest.NewRecorder()
output := e.MakeOutput(w, ht)

out, ok := output.(*outputPort)
assert.True(t, ok)

out.Foo = 321
out.Bar = "baz"
out.Name = "Jane"
out.Items = []string{"one", "two", "three"}
out.Cookie = 123
out.Cookie2 = true

e.WriteSuccessfulResponse(w, r, output, ht)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "Jane", w.Header().Get("X-Name"))
assert.Equal(t, "321", w.Header().Get("X-Foo"))
assert.Equal(t, []string{
"bar=baz",
"coo=123; Path=/foo; HttpOnly",
"coo2=true; Path=/foo; Max-Age=86400; HttpOnly; Secure; SameSite=Lax",
}, w.Header()["Set-Cookie"])
assert.Equal(t, "application/x-vnd-json", w.Header().Get("Content-Type"))
assert.Equal(t, "32", w.Header().Get("Content-Length"))
assert.Equal(t, `{"items":["one","two","three"]}`+"\n", w.Body.String())
assert.Equal(t, w, out.ResponseWriter())
}
10 changes: 10 additions & 0 deletions web/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,16 @@ func (s *Service) Trace(pattern string, uc usecase.Interactor, options ...func(h
s.Method(http.MethodTrace, pattern, nethttp.NewHandler(uc, options...))
}

// OnNotFound registers usecase interactor as a handler for not found conditions.
func (s *Service) OnNotFound(uc usecase.Interactor, options ...func(h *nethttp.Handler)) {
s.NotFound(s.HandlerFunc(nethttp.NewHandler(uc, options...)))
}

// OnMethodNotAllowed registers usecase interactor as a handler for method not allowed conditions.
func (s *Service) OnMethodNotAllowed(uc usecase.Interactor, options ...func(h *nethttp.Handler)) {
s.MethodNotAllowed(s.HandlerFunc(nethttp.NewHandler(uc, options...)))
}

// Docs adds the route `pattern` that serves API documentation with Swagger UI.
//
// Swagger UI should be provided by `swgui` handler constructor, you can use one of these functions
Expand Down
20 changes: 20 additions & 0 deletions web/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,26 @@ func TestDefaultService(t *testing.T) {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {})
})

service.OnNotFound(usecase.NewIOI(
nil,
new(struct {
Foo string `json:"foo"`
}),
func(ctx context.Context, input, output interface{}) error {
return nil
}),
)

service.OnMethodNotAllowed(usecase.NewIOI(
nil,
new(struct {
Foo string `json:"foo"`
}),
func(ctx context.Context, input, output interface{}) error {
return nil
}),
)

rw := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodGet, "http://localhost/docs/openapi.json", nil)
require.NoError(t, err)
Expand Down

0 comments on commit 2ccfa52

Please sign in to comment.