diff --git a/README.md b/README.md index 93ffc76..7daa979 100644 --- a/README.md +++ b/README.md @@ -47,8 +47,8 @@ type hello struct { func main() { mux := route.NewServeMux() - mux.GET("/", func (ctx route.Context) error { - return ctx.JSON(http.StatusOK, &hello{Title:"Hello, World!"}) + mux.GET("/", func (c route.Context) error { + return c.JSON(http.StatusOK, &hello{Title:"Hello, World!"}) }) log.Fatal(http.ListenAndServe(":9000", mux)) diff --git a/example/main.go b/example/main.go new file mode 100644 index 0000000..22f9cca --- /dev/null +++ b/example/main.go @@ -0,0 +1,21 @@ +package main + +import ( + "github.com/goroute/route" + "log" + "net/http" +) + +type hello struct { + Title string +} + +func main() { + mux := route.NewServeMux() + + mux.GET("/", func(c route.Context) error { + return c.JSON(http.StatusOK, &hello{Title: "Hello, World!"}) + }) + + log.Fatal(http.ListenAndServe(":9000", mux)) +} diff --git a/group.go b/group.go index a19df30..9e82985 100644 --- a/group.go +++ b/group.go @@ -5,16 +5,14 @@ import ( "path" ) -type ( - // Group is a set of sub-routes for a specified route. It can be used for inner - // routes that share a common middleware or functionality that should be separate - // from the parent mux instance while still inheriting from it. - Group struct { - prefix string - middleware []MiddlewareFunc - nio *Mux - } -) +// Group is a set of sub-routes for a specified route. It can be used for inner +// routes that share a common middleware or functionality that should be separate +// from the parent mux instance while still inheriting from it. +type Group struct { + prefix string + middleware []MiddlewareFunc + mux *Mux +} // Use implements `Mux#Use()` for sub-routes within the Group. func (g *Group) Use(middleware ...MiddlewareFunc) { @@ -22,7 +20,7 @@ func (g *Group) Use(middleware ...MiddlewareFunc) { // Allow all requests to reach the group as they might get dropped if router // doesn't find a match, making none of the group middleware process. for _, p := range []string{"", "/*"} { - g.nio.Any(path.Clean(g.prefix+p), func(c Context) error { + g.mux.Any(path.Clean(g.prefix+p), func(c Context) error { return NotFoundHandler(c) }, g.middleware...) } @@ -96,7 +94,7 @@ func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) *Group { m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) m = append(m, g.middleware...) m = append(m, middleware...) - return g.nio.Group(g.prefix+prefix, m...) + return g.mux.Group(g.prefix+prefix, m...) } // Static implements `Mux#Static()` for sub-routes within the Group. @@ -106,7 +104,7 @@ func (g *Group) Static(prefix, root string) { // File implements `Mux#File()` for sub-routes within the Group. func (g *Group) File(path, file string) { - g.nio.File(g.prefix+path, file) + g.mux.File(g.prefix+path, file) } // Add implements `Mux#Add()` for sub-routes within the Group. @@ -117,5 +115,5 @@ func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...Midd m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) m = append(m, g.middleware...) m = append(m, middleware...) - return g.nio.Add(method, g.prefix+path, handler, m...) + return g.mux.Add(method, g.prefix+path, handler, m...) } diff --git a/group_test.go b/group_test.go index 2d4ee12..faacba7 100644 --- a/group_test.go +++ b/group_test.go @@ -31,30 +31,20 @@ func TestGroupRouteMiddleware(t *testing.T) { e := NewServeMux() g := e.Group("/group") h := func(Context) error { return nil } - m1 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - return next(c) - } + m1 := func(c Context, next HandlerFunc) error { + return next(c) } - m2 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - return next(c) - } + m2 := func(c Context, next HandlerFunc) error { + return next(c) } - m3 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - return next(c) - } + m3 := func(c Context, next HandlerFunc) error { + return next(c) } - m4 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - return c.NoContent(404) - } + m4 := func(c Context, next HandlerFunc) error { + return c.NoContent(404) } - m5 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - return c.NoContent(405) - } + m5 := func(c Context, next HandlerFunc) error { + return c.NoContent(405) } g.Use(m1, m2, m3) g.GET("/404", h, m4) diff --git a/middleware.go b/middleware.go index 1c2d9a3..1b7ff6e 100644 --- a/middleware.go +++ b/middleware.go @@ -4,7 +4,7 @@ import "net/http" type ( // MiddlewareFunc defines a function to process middleware. - MiddlewareFunc func(HandlerFunc) HandlerFunc + MiddlewareFunc func(c Context, next HandlerFunc) error // Skipper defines a function to skip middleware. Returning true skips processing // the middleware. @@ -13,14 +13,12 @@ type ( // WrapMiddleware wraps `func(http.Handler) http.Handler` into `mux.MiddlewareFunc` func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc { - return func(next HandlerFunc) HandlerFunc { - return func(c Context) (err error) { - m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c.SetRequest(r) - err = next(c) - })).ServeHTTP(c.Response(), c.Request()) - return - } + return func(c Context, next HandlerFunc) (err error) { + m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c.SetRequest(r) + err = next(c) + })).ServeHTTP(c.Response(), c.Request()) + return } } @@ -28,3 +26,10 @@ func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc { func DefaultSkipper(Context) bool { return false } + +// compose chains given handler with next middleware. +func compose(h HandlerFunc, m MiddlewareFunc) HandlerFunc { + return func(c Context) error { + return m(c, h) + } +} diff --git a/middleware_test.go b/middleware_test.go index a54e913..f0dd6b4 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestNioWrapMiddleware(t *testing.T) { +func TestWrapMiddleware(t *testing.T) { e := NewServeMux() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() @@ -21,12 +21,21 @@ func TestNioWrapMiddleware(t *testing.T) { h.ServeHTTP(w, r) }) }) - h := mw(func(c Context) error { + + h := func(c Context) error { return c.String(http.StatusOK, "OK") - }) - if assert.NoError(t, h(c)) { + } + + err := mw(c, h) + if assert.NoError(t, err) { assert.Equal(t, "mw", buf.String()) assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "OK", rec.Body.String()) } } + +func TestDefaultSkipper(t *testing.T) { + skipper := DefaultSkipper(nil) + + assert.Equal(t, false, skipper) +} diff --git a/mux.go b/mux.go index b7562d7..e394754 100644 --- a/mux.go +++ b/mux.go @@ -380,7 +380,7 @@ func (mux *Mux) Add(method, path string, handler HandlerFunc, middleware ...Midd h := handler // Chain middleware for i := len(middleware) - 1; i >= 0; i-- { - h = middleware[i](h) + h = compose(h, middleware[i]) } return h(c) }) @@ -395,7 +395,7 @@ func (mux *Mux) Add(method, path string, handler HandlerFunc, middleware ...Midd // Group creates a new router group with prefix and optional group-level middleware. func (mux *Mux) Group(prefix string, m ...MiddlewareFunc) (g *Group) { - g = &Group{prefix: prefix, nio: mux} + g = &Group{prefix: prefix, mux: mux} g.Use(m...) return } @@ -421,19 +421,19 @@ func (mux *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { mux.router.find(r.Method, getPath(r), c) h = c.Handler() for i := len(mux.middleware) - 1; i >= 0; i-- { - h = mux.middleware[i](h) + h = compose(h, mux.middleware[i]) } } else { h = func(c Context) error { mux.router.find(r.Method, getPath(r), c) h := c.Handler() for i := len(mux.middleware) - 1; i >= 0; i-- { - h = mux.middleware[i](h) + h = compose(h, mux.middleware[i]) } return h(c) } for i := len(mux.premiddleware) - 1; i >= 0; i-- { - h = mux.premiddleware[i](h) + h = compose(h, mux.premiddleware[i]) } } diff --git a/mux_test.go b/mux_test.go index 0366b2f..a141620 100644 --- a/mux_test.go +++ b/mux_test.go @@ -34,44 +34,44 @@ const userJSONPretty = `{ }` func TestMux(t *testing.T) { - e := NewServeMux() + mux := NewServeMux() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := mux.NewContext(req, rec) - e.defaultHTTPErrorHandler(errors.New("error"), c) + mux.defaultHTTPErrorHandler(errors.New("error"), c) assert.Equal(t, http.StatusInternalServerError, rec.Code) } func TestMuxStatic(t *testing.T) { - e := NewServeMux() + mux := NewServeMux() assert := assert.New(t) // OK - e.Static("/images", "testdata/images") - c, b := request(http.MethodGet, "/images/walle.png", e) + mux.Static("/images", "testdata/images") + c, b := request(http.MethodGet, "/images/walle.png", mux) assert.Equal(http.StatusOK, c) assert.NotEmpty(b) // No file - e.Static("/images", "testdata/scripts") - c, _ = request(http.MethodGet, "/images/bolt.png", e) + mux.Static("/images", "testdata/scripts") + c, _ = request(http.MethodGet, "/images/bolt.png", mux) assert.Equal(http.StatusNotFound, c) // Directory - e.Static("/images", "testdata/images") - c, _ = request(http.MethodGet, "/images", e) + mux.Static("/images", "testdata/images") + c, _ = request(http.MethodGet, "/images", mux) assert.Equal(http.StatusNotFound, c) // Directory with index.html - e.Static("/", "testdata") - c, r := request(http.MethodGet, "/", e) + mux.Static("/", "testdata") + c, r := request(http.MethodGet, "/", mux) assert.Equal(http.StatusOK, c) assert.Equal(true, strings.HasPrefix(r, "")) // Sub-directory with index.html - c, r = request(http.MethodGet, "/folder", e) + c, r = request(http.MethodGet, "/folder", mux) assert.Equal(http.StatusOK, c) assert.Equal(true, strings.HasPrefix(r, "")) } @@ -82,99 +82,106 @@ func TestMuxWithOptions(t *testing.T) { mockHTTPErrorHandler := func(error, Context) { } - e := NewServeMux( + mux := NewServeMux( WithBinder(binder), WithRenderer(renderer), WithHTTPErrorHandler(mockHTTPErrorHandler), ) - assert.Equal(t, binder, e.binder) - assert.Equal(t, renderer, e.renderer) - assert.NotNil(t, e.httpErrorHandler) + assert.Equal(t, binder, mux.binder) + assert.Equal(t, renderer, mux.renderer) + assert.NotNil(t, mux.httpErrorHandler) } func TestMuxFile(t *testing.T) { - e := NewServeMux() - e.File("/walle", "testdata/images/walle.png") - c, b := request(http.MethodGet, "/walle", e) + mux := NewServeMux() + mux.File("/walle", "testdata/images/walle.png") + c, b := request(http.MethodGet, "/walle", mux) assert.Equal(t, http.StatusOK, c) assert.NotEmpty(t, b) } func TestMuxMiddleware(t *testing.T) { - e := NewServeMux() + mux := NewServeMux() buf := new(bytes.Buffer) - e.Pre(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - assert.Empty(t, c.Path()) - buf.WriteString("-1") - return next(c) - } + mux.Pre(func(c Context, next HandlerFunc) error { + assert.Empty(t, c.Path()) + buf.WriteString("-1") + return next(c) }) - e.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - buf.WriteString("1") - return next(c) - } + mux.Use(func(c Context, next HandlerFunc) error { + buf.WriteString("1") + return next(c) }) - e.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - buf.WriteString("2") - return next(c) - } + mux.Use(func(c Context, next HandlerFunc) error { + buf.WriteString("2") + return next(c) }) - e.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - buf.WriteString("3") - return next(c) - } + mux.Use(func(c Context, next HandlerFunc) error { + buf.WriteString("3") + return next(c) }) // Route - e.GET("/", func(c Context) error { + mux.GET("/", func(c Context) error { return c.String(http.StatusOK, "OK") }) - c, b := request(http.MethodGet, "/", e) + c, b := request(http.MethodGet, "/", mux) + assert.Equal(t, "-1123", buf.String()) assert.Equal(t, http.StatusOK, c) assert.Equal(t, "OK", b) } func TestMuxMiddlewareError(t *testing.T) { - e := NewServeMux() - e.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - return errors.New("error") - } + mux := NewServeMux() + buf := new(bytes.Buffer) + + mux.Use(func(c Context, next HandlerFunc) error { + buf.WriteString("1") + return next(c) }) - e.GET("/", NotFoundHandler) - c, _ := request(http.MethodGet, "/", e) + + mux.Use(func(c Context, next HandlerFunc) error { + buf.WriteString("2") + return errors.New("error") + }) + + mux.Use(func(c Context, next HandlerFunc) error { + buf.WriteString("3") + return next(c) + }) + + mux.GET("/", NotFoundHandler) + c, _ := request(http.MethodGet, "/", mux) + + assert.Equal(t, "12", buf.String()) assert.Equal(t, http.StatusInternalServerError, c) } func TestMuxHandler(t *testing.T) { - e := NewServeMux() + mux := NewServeMux() // HandlerFunc - e.GET("/ok", func(c Context) error { + mux.GET("/ok", func(c Context) error { return c.String(http.StatusOK, "OK") }) - c, b := request(http.MethodGet, "/ok", e) + c, b := request(http.MethodGet, "/ok", mux) assert.Equal(t, http.StatusOK, c) assert.Equal(t, "OK", b) } func TestMuxWrapHandler(t *testing.T) { - e := NewServeMux() + mux := NewServeMux() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := mux.NewContext(req, rec) h := WrapHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("test")) @@ -186,66 +193,66 @@ func TestMuxWrapHandler(t *testing.T) { } func TestMuxConnect(t *testing.T) { - e := NewServeMux() - testMethod(t, http.MethodConnect, "/", e) + mux := NewServeMux() + testMethod(t, http.MethodConnect, "/", mux) } func TestMuxDelete(t *testing.T) { - e := NewServeMux() - testMethod(t, http.MethodDelete, "/", e) + mux := NewServeMux() + testMethod(t, http.MethodDelete, "/", mux) } func TestMuxGet(t *testing.T) { - e := NewServeMux() - testMethod(t, http.MethodGet, "/", e) + mux := NewServeMux() + testMethod(t, http.MethodGet, "/", mux) } func TestMuxHead(t *testing.T) { - e := NewServeMux() - testMethod(t, http.MethodHead, "/", e) + mux := NewServeMux() + testMethod(t, http.MethodHead, "/", mux) } func TestMuxOptions(t *testing.T) { - e := NewServeMux() - testMethod(t, http.MethodOptions, "/", e) + mux := NewServeMux() + testMethod(t, http.MethodOptions, "/", mux) } func TestMuxPatch(t *testing.T) { - e := NewServeMux() - testMethod(t, http.MethodPatch, "/", e) + mux := NewServeMux() + testMethod(t, http.MethodPatch, "/", mux) } func TestMuxPost(t *testing.T) { - e := NewServeMux() - testMethod(t, http.MethodPost, "/", e) + mux := NewServeMux() + testMethod(t, http.MethodPost, "/", mux) } func TestMuxPut(t *testing.T) { - e := NewServeMux() - testMethod(t, http.MethodPut, "/", e) + mux := NewServeMux() + testMethod(t, http.MethodPut, "/", mux) } func TestMuxTrace(t *testing.T) { - e := NewServeMux() - testMethod(t, http.MethodTrace, "/", e) + mux := NewServeMux() + testMethod(t, http.MethodTrace, "/", mux) } func TestMuxAny(t *testing.T) { // JFC - e := NewServeMux() - e.Any("/", func(c Context) error { + mux := NewServeMux() + mux.Any("/", func(c Context) error { return c.String(http.StatusOK, "Any") }) } func TestMuxMatch(t *testing.T) { // JFC - e := NewServeMux() - e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c Context) error { + mux := NewServeMux() + mux.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c Context) error { return c.String(http.StatusOK, "Match") }) } func TestMuxRoutes(t *testing.T) { - e := NewServeMux() + mux := NewServeMux() routes := []*Route{ {http.MethodGet, "/users/:user/events", ""}, {http.MethodGet, "/users/:user/events/public", ""}, @@ -253,13 +260,13 @@ func TestMuxRoutes(t *testing.T) { {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, } for _, r := range routes { - e.Add(r.Method, r.Path, func(c Context) error { + mux.Add(r.Method, r.Path, func(c Context) error { return c.String(http.StatusOK, "OK") }) } - if assert.Equal(t, len(routes), len(e.Routes())) { - for _, r := range e.Routes() { + if assert.Equal(t, len(routes), len(mux.Routes())) { + for _, r := range mux.Routes() { found := false for _, rr := range routes { if r.Method == rr.Method && r.Path == rr.Path { @@ -275,24 +282,22 @@ func TestMuxRoutes(t *testing.T) { } func TestMuxEncodedPath(t *testing.T) { - e := NewServeMux() - e.GET("/:id", func(c Context) error { + mux := NewServeMux() + mux.GET("/:id", func(c Context) error { return c.NoContent(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/with%2Fslash", nil) rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) + mux.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) } func TestMuxGroup(t *testing.T) { - e := NewServeMux() + mux := NewServeMux() buf := new(bytes.Buffer) - e.Use(MiddlewareFunc(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - buf.WriteString("0") - return next(c) - } + mux.Use(MiddlewareFunc(func(c Context, next HandlerFunc) error { + buf.WriteString("0") + return next(c) })) h := func(c Context) error { return c.NoContent(http.StatusOK) @@ -302,86 +307,80 @@ func TestMuxGroup(t *testing.T) { // Routes //-------- - e.GET("/users", h) + mux.GET("/users", h) // Group - g1 := e.Group("/group1") - g1.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - buf.WriteString("1") - return next(c) - } + g1 := mux.Group("/group1") + g1.Use(func(c Context, next HandlerFunc) error { + buf.WriteString("1") + return next(c) }) g1.GET("", h) // Nested groups with middleware - g2 := e.Group("/group2") - g2.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - buf.WriteString("2") - return next(c) - } + g2 := mux.Group("/group2") + g2.Use(func(c Context, next HandlerFunc) error { + buf.WriteString("2") + return next(c) }) g3 := g2.Group("/group3") - g3.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - buf.WriteString("3") - return next(c) - } + g3.Use(func(c Context, next HandlerFunc) error { + buf.WriteString("3") + return next(c) }) g3.GET("", h) - request(http.MethodGet, "/users", e) + request(http.MethodGet, "/users", mux) assert.Equal(t, "0", buf.String()) buf.Reset() - request(http.MethodGet, "/group1", e) + request(http.MethodGet, "/group1", mux) assert.Equal(t, "01", buf.String()) buf.Reset() - request(http.MethodGet, "/group2/group3", e) + request(http.MethodGet, "/group2/group3", mux) assert.Equal(t, "023", buf.String()) } func TestMuxNotFound(t *testing.T) { - e := NewServeMux() + mux := NewServeMux() req := httptest.NewRequest(http.MethodGet, "/files", nil) rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) + mux.ServeHTTP(rec, req) assert.Equal(t, http.StatusNotFound, rec.Code) } func TestMuxMethodNotAllowed(t *testing.T) { - e := NewServeMux() - e.GET("/", func(c Context) error { + mux := NewServeMux() + mux.GET("/", func(c Context) error { return c.String(http.StatusOK, "Mux!") }) req := httptest.NewRequest(http.MethodPost, "/", nil) rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) + mux.ServeHTTP(rec, req) assert.Equal(t, http.StatusMethodNotAllowed, rec.Code) } func TestMuxContext(t *testing.T) { - e := NewServeMux() - c := e.pool.Get().(*context) + mux := NewServeMux() + c := mux.pool.Get().(*context) assert.IsType(t, new(context), c) - e.pool.Put(c) + mux.pool.Put(c) } func TestMuxStart(t *testing.T) { - e := NewServeMux() + mux := NewServeMux() go func() { - err := http.ListenAndServe(":0", e) + err := http.ListenAndServe(":0", mux) assert.NoError(t, err) }() time.Sleep(200 * time.Millisecond) } func TestMuxStartTLS(t *testing.T) { - e := NewServeMux() + mux := NewServeMux() go func() { - err := http.ListenAndServeTLS(":0", "testdata/certs/cert.pem", "testdata/certs/key.pem", e) + err := http.ListenAndServeTLS(":0", "testdata/certs/cert.pem", "testdata/certs/key.pem", mux) // Prevent the test to fail after closing the servers if err != http.ErrServerClosed { assert.NoError(t, err)