From 2ccfa5239519290bbf3ac9997bc3d05047ed3de8 Mon Sep 17 00:00:00 2001 From: Viacheslav Poturaev Date: Mon, 1 Jan 2024 14:39:54 +0100 Subject: [PATCH] Add methods for NotFound and MethodNotAllowed, add response.EmbeddedSetter (#185) --- chirouter/wrapper.go | 9 ++++++ chirouter/wrapper_test.go | 7 +++-- request/middleware.go | 7 +++-- response/encoder.go | 35 +++++++++++++++++++++ response/encoder_test.go | 64 +++++++++++++++++++++++++++++++++++++++ web/service.go | 10 ++++++ web/service_test.go | 20 ++++++++++++ 7 files changed, 148 insertions(+), 4 deletions(-) diff --git a/chirouter/wrapper.go b/chirouter/wrapper.go index 47e0361..5fe11ce 100644 --- a/chirouter/wrapper.go +++ b/chirouter/wrapper.go @@ -199,6 +199,15 @@ func (r *Wrapper) Trace(pattern string, handlerFn http.HandlerFunc) { r.Method(http.MethodTrace, pattern, handlerFn) } +// HandlerFunc prepares handler and returns its function. +// +// Can be used as input for NotFound, MethodNotAllowed. +func (r *Wrapper) HandlerFunc(h http.Handler) http.HandlerFunc { + h = nethttp.WrapHandler(h, r.wraps...) + + return h.ServeHTTP +} + func (r *Wrapper) resolvePattern(pattern string) string { return r.basePattern + strings.ReplaceAll(pattern, "/*/", "/") } diff --git a/chirouter/wrapper_test.go b/chirouter/wrapper_test.go index c23a545..0ff0d39 100644 --- a/chirouter/wrapper_test.go +++ b/chirouter/wrapper_test.go @@ -51,7 +51,8 @@ func (h HandlerWithBar) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } func TestNewWrapper(t *testing.T) { - r := chirouter.NewWrapper(chi.NewRouter()).With(func(handler http.Handler) http.Handler { + w := chirouter.NewWrapper(chi.NewRouter()) + r := w.With(func(handler http.Handler) http.Handler { return http.HandlerFunc(handler.ServeHTTP) }) @@ -79,6 +80,8 @@ func TestNewWrapper(t *testing.T) { r.Use(mw) + r.NotFound(r.(*chirouter.Wrapper).HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))) + r.Group(func(r chi.Router) { r.Method(http.MethodPost, "/baz/{id}/", @@ -125,7 +128,7 @@ func TestNewWrapper(t *testing.T) { } assert.Equal(t, 14, handlersCnt) - assert.Equal(t, 21, totalCnt) + assert.Equal(t, 22, totalCnt) } func TestWrapper_Use_precedence(t *testing.T) { diff --git a/request/middleware.go b/request/middleware.go index 96533bd..4d53d15 100644 --- a/request/middleware.go +++ b/request/middleware.go @@ -32,7 +32,6 @@ func DecoderMiddleware(factory DecoderMaker) func(http.Handler) http.Handler { ) if !nethttp.HandlerAs(handler, &setRequestDecoder) || - !nethttp.HandlerAs(handler, &withRoute) || !nethttp.HandlerAs(handler, &withUseCase) || !usecase.As(withUseCase.UseCase(), &useCaseWithInput) { return handler @@ -45,7 +44,11 @@ func DecoderMiddleware(factory DecoderMaker) func(http.Handler) http.Handler { input := useCaseWithInput.InputPort() if input != nil { - method := withRoute.RouteMethod() + method := http.MethodPost // Default for handlers without method (for example NotFound handler). + if nethttp.HandlerAs(handler, &withRoute) { + method = withRoute.RouteMethod() + } + dec := factory.MakeDecoder(method, input, customMapping) setRequestDecoder.SetRequestDecoder(dec) } diff --git a/response/encoder.go b/response/encoder.go index 875fde7..f34a8bb 100644 --- a/response/encoder.go +++ b/response/encoder.go @@ -17,6 +17,15 @@ import ( "github.com/swaggest/usecase/status" ) +type ( + // Setter captures original http.ResponseWriter. + // + // Implement this interface on a pointer to your output structure to get access to http.ResponseWriter. + Setter interface { + SetResponseWriter(rw http.ResponseWriter) + } +) + // Encoder prepares and writes http response. type Encoder struct { JSONWriter func(v interface{}) @@ -30,6 +39,7 @@ type Encoder struct { unwrapInterface bool dynamicWithHeadersSetup bool + dynamicSetter bool dynamicETagged bool dynamicNoContent bool } @@ -187,6 +197,10 @@ func (h *Encoder) SetupOutput(output interface{}, ht *rest.HandlerTrait) { h.dynamicWithHeadersSetup = true } + if _, ok := output.(Setter); ok || h.unwrapInterface { + h.dynamicSetter = true + } + if _, ok := output.(rest.ETagged); ok || h.unwrapInterface { h.dynamicETagged = true } @@ -504,6 +518,12 @@ func (h *Encoder) MakeOutput(w http.ResponseWriter, ht rest.HandlerTrait) interf } } + if h.dynamicSetter { + if setter, ok := output.(Setter); ok { + setter.SetResponseWriter(w) + } + } + return output } @@ -550,3 +570,18 @@ func (w *writerWithHeaders) Write(data []byte) (int, error) { return w.ResponseWriter.Write(data) } + +// EmbeddedSetter can capture http.ResponseWriter in your output structure. +type EmbeddedSetter struct { + rw http.ResponseWriter +} + +// SetResponseWriter implements Setter. +func (e *EmbeddedSetter) SetResponseWriter(rw http.ResponseWriter) { + e.rw = rw +} + +// ResponseWriter is an accessor. +func (e *EmbeddedSetter) ResponseWriter() http.ResponseWriter { + return e.rw +} diff --git a/response/encoder_test.go b/response/encoder_test.go index 6f16960..93d530f 100644 --- a/response/encoder_test.go +++ b/response/encoder_test.go @@ -236,3 +236,67 @@ func TestEncoder_Writer_httpStatus(t *testing.T) { e.WriteSuccessfulResponse(w, r, output, rest.HandlerTrait{}) assert.Equal(t, http.StatusCreated, w.Code) } + +func TestEmbeddedSetter_SetResponseWriter(t *testing.T) { + e := response.Encoder{} + + type EmbeddedHeader struct { + Foo int `header:"X-Foo" json:"-"` + Bar string `cookie:"bar" json:"-"` + } + + type outputPort struct { + response.EmbeddedSetter + EmbeddedHeader + Name string `header:"X-Name" json:"-"` + Items []string `json:"items"` + Cookie int `cookie:"coo,httponly,path=/foo" json:"-"` + Cookie2 bool `cookie:"coo2,httponly,secure,samesite=lax,path=/foo,max-age=86400" json:"-"` + } + + ht := rest.HandlerTrait{ + SuccessContentType: "application/x-vnd-json", + } + + validator := jsonschema.Validator{} + require.NoError(t, validator.AddSchema( + rest.ParamInHeader, + "X-Name", + []byte(`{"type":"string","minLength":3}`), + false), + ) + + ht.RespValidator = &validator + + e.SetupOutput(outputPort{}, &ht) + + r, err := http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + + w := httptest.NewRecorder() + output := e.MakeOutput(w, ht) + + out, ok := output.(*outputPort) + assert.True(t, ok) + + out.Foo = 321 + out.Bar = "baz" + out.Name = "Jane" + out.Items = []string{"one", "two", "three"} + out.Cookie = 123 + out.Cookie2 = true + + e.WriteSuccessfulResponse(w, r, output, ht) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "Jane", w.Header().Get("X-Name")) + assert.Equal(t, "321", w.Header().Get("X-Foo")) + assert.Equal(t, []string{ + "bar=baz", + "coo=123; Path=/foo; HttpOnly", + "coo2=true; Path=/foo; Max-Age=86400; HttpOnly; Secure; SameSite=Lax", + }, w.Header()["Set-Cookie"]) + assert.Equal(t, "application/x-vnd-json", w.Header().Get("Content-Type")) + assert.Equal(t, "32", w.Header().Get("Content-Length")) + assert.Equal(t, `{"items":["one","two","three"]}`+"\n", w.Body.String()) + assert.Equal(t, w, out.ResponseWriter()) +} diff --git a/web/service.go b/web/service.go index fed0621..d907caa 100644 --- a/web/service.go +++ b/web/service.go @@ -162,6 +162,16 @@ func (s *Service) Trace(pattern string, uc usecase.Interactor, options ...func(h s.Method(http.MethodTrace, pattern, nethttp.NewHandler(uc, options...)) } +// OnNotFound registers usecase interactor as a handler for not found conditions. +func (s *Service) OnNotFound(uc usecase.Interactor, options ...func(h *nethttp.Handler)) { + s.NotFound(s.HandlerFunc(nethttp.NewHandler(uc, options...))) +} + +// OnMethodNotAllowed registers usecase interactor as a handler for method not allowed conditions. +func (s *Service) OnMethodNotAllowed(uc usecase.Interactor, options ...func(h *nethttp.Handler)) { + s.MethodNotAllowed(s.HandlerFunc(nethttp.NewHandler(uc, options...))) +} + // Docs adds the route `pattern` that serves API documentation with Swagger UI. // // Swagger UI should be provided by `swgui` handler constructor, you can use one of these functions diff --git a/web/service_test.go b/web/service_test.go index 6b57a9c..1d822cf 100644 --- a/web/service_test.go +++ b/web/service_test.go @@ -56,6 +56,26 @@ func TestDefaultService(t *testing.T) { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {}) }) + service.OnNotFound(usecase.NewIOI( + nil, + new(struct { + Foo string `json:"foo"` + }), + func(ctx context.Context, input, output interface{}) error { + return nil + }), + ) + + service.OnMethodNotAllowed(usecase.NewIOI( + nil, + new(struct { + Foo string `json:"foo"` + }), + func(ctx context.Context, input, output interface{}) error { + return nil + }), + ) + rw := httptest.NewRecorder() r, err := http.NewRequest(http.MethodGet, "http://localhost/docs/openapi.json", nil) require.NoError(t, err)