Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add methods for NotFound and MethodNotAllowed, add response.EmbeddedSetter #185

Merged
merged 2 commits into from
Jan 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions chirouter/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,15 @@ func (r *Wrapper) Trace(pattern string, handlerFn http.HandlerFunc) {
r.Method(http.MethodTrace, pattern, handlerFn)
}

// HandlerFunc prepares handler and returns its function.
//
// Can be used as input for NotFound, MethodNotAllowed.
func (r *Wrapper) HandlerFunc(h http.Handler) http.HandlerFunc {
h = nethttp.WrapHandler(h, r.wraps...)

return h.ServeHTTP
}

func (r *Wrapper) resolvePattern(pattern string) string {
return r.basePattern + strings.ReplaceAll(pattern, "/*/", "/")
}
Expand Down
7 changes: 5 additions & 2 deletions chirouter/wrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ func (h HandlerWithBar) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
}

func TestNewWrapper(t *testing.T) {
r := chirouter.NewWrapper(chi.NewRouter()).With(func(handler http.Handler) http.Handler {
w := chirouter.NewWrapper(chi.NewRouter())
r := w.With(func(handler http.Handler) http.Handler {
return http.HandlerFunc(handler.ServeHTTP)
})

Expand Down Expand Up @@ -79,6 +80,8 @@ func TestNewWrapper(t *testing.T) {

r.Use(mw)

r.NotFound(r.(*chirouter.Wrapper).HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})))

r.Group(func(r chi.Router) {
r.Method(http.MethodPost,
"/baz/{id}/",
Expand Down Expand Up @@ -125,7 +128,7 @@ func TestNewWrapper(t *testing.T) {
}

assert.Equal(t, 14, handlersCnt)
assert.Equal(t, 21, totalCnt)
assert.Equal(t, 22, totalCnt)
}

func TestWrapper_Use_precedence(t *testing.T) {
Expand Down
7 changes: 5 additions & 2 deletions request/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ func DecoderMiddleware(factory DecoderMaker) func(http.Handler) http.Handler {
)

if !nethttp.HandlerAs(handler, &setRequestDecoder) ||
!nethttp.HandlerAs(handler, &withRoute) ||
!nethttp.HandlerAs(handler, &withUseCase) ||
!usecase.As(withUseCase.UseCase(), &useCaseWithInput) {
return handler
Expand All @@ -45,7 +44,11 @@ func DecoderMiddleware(factory DecoderMaker) func(http.Handler) http.Handler {

input := useCaseWithInput.InputPort()
if input != nil {
method := withRoute.RouteMethod()
method := http.MethodPost // Default for handlers without method (for example NotFound handler).
if nethttp.HandlerAs(handler, &withRoute) {
method = withRoute.RouteMethod()
}

dec := factory.MakeDecoder(method, input, customMapping)
setRequestDecoder.SetRequestDecoder(dec)
}
Expand Down
35 changes: 35 additions & 0 deletions response/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ import (
"github.com/swaggest/usecase/status"
)

type (
// Setter captures original http.ResponseWriter.
//
// Implement this interface on a pointer to your output structure to get access to http.ResponseWriter.
Setter interface {
SetResponseWriter(rw http.ResponseWriter)
}
)

// Encoder prepares and writes http response.
type Encoder struct {
JSONWriter func(v interface{})
Expand All @@ -30,6 +39,7 @@ type Encoder struct {
unwrapInterface bool

dynamicWithHeadersSetup bool
dynamicSetter bool
dynamicETagged bool
dynamicNoContent bool
}
Expand Down Expand Up @@ -187,6 +197,10 @@ func (h *Encoder) SetupOutput(output interface{}, ht *rest.HandlerTrait) {
h.dynamicWithHeadersSetup = true
}

if _, ok := output.(Setter); ok || h.unwrapInterface {
h.dynamicSetter = true
}

if _, ok := output.(rest.ETagged); ok || h.unwrapInterface {
h.dynamicETagged = true
}
Expand Down Expand Up @@ -504,6 +518,12 @@ func (h *Encoder) MakeOutput(w http.ResponseWriter, ht rest.HandlerTrait) interf
}
}

if h.dynamicSetter {
if setter, ok := output.(Setter); ok {
setter.SetResponseWriter(w)
}
}

return output
}

Expand Down Expand Up @@ -550,3 +570,18 @@ func (w *writerWithHeaders) Write(data []byte) (int, error) {

return w.ResponseWriter.Write(data)
}

// EmbeddedSetter can capture http.ResponseWriter in your output structure.
type EmbeddedSetter struct {
rw http.ResponseWriter
}

// SetResponseWriter implements Setter.
func (e *EmbeddedSetter) SetResponseWriter(rw http.ResponseWriter) {
e.rw = rw
}

// ResponseWriter is an accessor.
func (e *EmbeddedSetter) ResponseWriter() http.ResponseWriter {
return e.rw
}
64 changes: 64 additions & 0 deletions response/encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,67 @@ func TestEncoder_Writer_httpStatus(t *testing.T) {
e.WriteSuccessfulResponse(w, r, output, rest.HandlerTrait{})
assert.Equal(t, http.StatusCreated, w.Code)
}

func TestEmbeddedSetter_SetResponseWriter(t *testing.T) {
e := response.Encoder{}

type EmbeddedHeader struct {
Foo int `header:"X-Foo" json:"-"`
Bar string `cookie:"bar" json:"-"`
}

type outputPort struct {
response.EmbeddedSetter
EmbeddedHeader
Name string `header:"X-Name" json:"-"`
Items []string `json:"items"`
Cookie int `cookie:"coo,httponly,path=/foo" json:"-"`
Cookie2 bool `cookie:"coo2,httponly,secure,samesite=lax,path=/foo,max-age=86400" json:"-"`
}

ht := rest.HandlerTrait{
SuccessContentType: "application/x-vnd-json",
}

validator := jsonschema.Validator{}
require.NoError(t, validator.AddSchema(
rest.ParamInHeader,
"X-Name",
[]byte(`{"type":"string","minLength":3}`),
false),
)

ht.RespValidator = &validator

e.SetupOutput(outputPort{}, &ht)

r, err := http.NewRequest(http.MethodGet, "/", nil)
require.NoError(t, err)

w := httptest.NewRecorder()
output := e.MakeOutput(w, ht)

out, ok := output.(*outputPort)
assert.True(t, ok)

out.Foo = 321
out.Bar = "baz"
out.Name = "Jane"
out.Items = []string{"one", "two", "three"}
out.Cookie = 123
out.Cookie2 = true

e.WriteSuccessfulResponse(w, r, output, ht)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "Jane", w.Header().Get("X-Name"))
assert.Equal(t, "321", w.Header().Get("X-Foo"))
assert.Equal(t, []string{
"bar=baz",
"coo=123; Path=/foo; HttpOnly",
"coo2=true; Path=/foo; Max-Age=86400; HttpOnly; Secure; SameSite=Lax",
}, w.Header()["Set-Cookie"])
assert.Equal(t, "application/x-vnd-json", w.Header().Get("Content-Type"))
assert.Equal(t, "32", w.Header().Get("Content-Length"))
assert.Equal(t, `{"items":["one","two","three"]}`+"\n", w.Body.String())
assert.Equal(t, w, out.ResponseWriter())
}
10 changes: 10 additions & 0 deletions web/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,16 @@ func (s *Service) Trace(pattern string, uc usecase.Interactor, options ...func(h
s.Method(http.MethodTrace, pattern, nethttp.NewHandler(uc, options...))
}

// OnNotFound registers usecase interactor as a handler for not found conditions.
func (s *Service) OnNotFound(uc usecase.Interactor, options ...func(h *nethttp.Handler)) {
s.NotFound(s.HandlerFunc(nethttp.NewHandler(uc, options...)))
}

// OnMethodNotAllowed registers usecase interactor as a handler for method not allowed conditions.
func (s *Service) OnMethodNotAllowed(uc usecase.Interactor, options ...func(h *nethttp.Handler)) {
s.MethodNotAllowed(s.HandlerFunc(nethttp.NewHandler(uc, options...)))
}

// Docs adds the route `pattern` that serves API documentation with Swagger UI.
//
// Swagger UI should be provided by `swgui` handler constructor, you can use one of these functions
Expand Down
20 changes: 20 additions & 0 deletions web/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,26 @@ func TestDefaultService(t *testing.T) {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {})
})

service.OnNotFound(usecase.NewIOI(
nil,
new(struct {
Foo string `json:"foo"`
}),
func(ctx context.Context, input, output interface{}) error {
return nil
}),
)

service.OnMethodNotAllowed(usecase.NewIOI(
nil,
new(struct {
Foo string `json:"foo"`
}),
func(ctx context.Context, input, output interface{}) error {
return nil
}),
)

rw := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodGet, "http://localhost/docs/openapi.json", nil)
require.NoError(t, err)
Expand Down
Loading