From 6d2b6944d8377ed9285d177d4ab9219c2d1c7a4f Mon Sep 17 00:00:00 2001 From: John Coleman Date: Mon, 29 Mar 2021 16:21:07 -0600 Subject: [PATCH 1/2] Make ContextData() available to middleware add test for retrieving context data in middleware --- context.go | 12 ------------ group.go | 15 +++++++++++++++ group_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 12 deletions(-) diff --git a/context.go b/context.go index 30ecf4f..92daf4d 100644 --- a/context.go +++ b/context.go @@ -58,13 +58,7 @@ func (cg *ContextGroup) NewGroup(path string) *ContextGroup { // Handle allows handling HTTP requests via an http.HandlerFunc, as opposed to an httptreemux.HandlerFunc. // Any parameters from the request URL are stored in a map[string]string in the request's context. func (cg *ContextGroup) Handle(method, path string, handler http.HandlerFunc) { - fullPath := cg.group.path + path cg.group.Handle(method, path, func(w http.ResponseWriter, r *http.Request, params map[string]string) { - routeData := &contextData{ - route: fullPath, - params: params, - } - r = r.WithContext(AddRouteDataToContext(r.Context(), routeData)) handler(w, r) }) } @@ -72,13 +66,7 @@ func (cg *ContextGroup) Handle(method, path string, handler http.HandlerFunc) { // Handler allows handling HTTP requests via an http.Handler interface, as opposed to an httptreemux.HandlerFunc. // Any parameters from the request URL are stored in a map[string]string in the request's context. func (cg *ContextGroup) Handler(method, path string, handler http.Handler) { - fullPath := cg.group.path + path cg.group.Handle(method, path, func(w http.ResponseWriter, r *http.Request, params map[string]string) { - routeData := &contextData{ - route: fullPath, - params: params, - } - r = r.WithContext(AddRouteDataToContext(r.Context(), routeData)) handler.ServeHTTP(w, r) }) } diff --git a/group.go b/group.go index 826f12b..07ce3d7 100644 --- a/group.go +++ b/group.go @@ -16,6 +16,17 @@ func handlerWithMiddlewares(handler HandlerFunc, stack []MiddlewareFunc) Handler return handler } +func handlerWithContextData(next HandlerFunc, fullPath string) HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request, m map[string]string) { + routeData := &contextData{ + route: fullPath, + params: m, + } + request = request.WithContext(AddRouteDataToContext(request.Context(), routeData)) + next(writer, request, m) + } +} + type Group struct { path string mux *TreeMux @@ -138,6 +149,10 @@ func (g *Group) Handle(method string, path string, handler HandlerFunc) { handler = handlerWithMiddlewares(handler, g.stack) } + //add the context data after adding all middleware + fullPath := g.path + path + handler = handlerWithContextData(handler, fullPath) + addSlash := false addOne := func(thePath string) { node := g.mux.root.addPath(thePath[1:], nil, false) diff --git a/group_test.go b/group_test.go index 9065d5c..c75fce3 100644 --- a/group_test.go +++ b/group_test.go @@ -3,6 +3,7 @@ package httptreemux import ( "net/http" "net/http/httptest" + "reflect" "testing" ) @@ -165,3 +166,47 @@ func TestSetGetAfterHead(t *testing.T) { testMethod("HEAD", "HEAD") testMethod("GET", "GET") } + +func TestContextDataWithMiddleware(t *testing.T) { + wantRoute := "/foo/:id/bar" + wantParams := map[string]string{ + "id": "15", + } + + validateRequestAndParams := func(request *http.Request, params map[string]string, location string) { + data := ContextData(request.Context()) + if data == nil { + t.Fatalf("ContextData returned nil in %s", location) + } + if data.Route() != wantRoute { + t.Errorf("Unexpected route in %s. Got %s", location, data.Route()) + } + if !reflect.DeepEqual(data.Params(), wantParams) { + t.Errorf("Unexpected context params in %s. Got %+v", location, data.Params()) + } + if !reflect.DeepEqual(params, wantParams) { + t.Errorf("Unexpected handler params in %s. Got %+v", location, params) + } + } + + router := New() + router.Use(func(next HandlerFunc) HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request, m map[string]string) { + validateRequestAndParams(request, m, "middleware") + next(writer, request, m) + } + }) + + router.GET(wantRoute, func(writer http.ResponseWriter, request *http.Request, m map[string]string) { + validateRequestAndParams(request, m, "handler") + writer.WriteHeader(http.StatusOK) + }) + + w := httptest.NewRecorder() + r, _ := http.NewRequest(http.MethodGet, "/foo/15/bar", nil) + router.ServeHTTP(w, r) + + if w.Code != http.StatusOK { + t.Fatalf("unexpected status code. got %d", w.Code) + } +} From 0b810765eaa4d48f397f44dd689d1f22a05f5f76 Mon Sep 17 00:00:00 2001 From: Daniel Imfeld Date: Tue, 30 Mar 2021 02:40:15 +0000 Subject: [PATCH 2/2] Move context-specific code to Go1.7+ files --- context.go | 31 ++++++++++++++++++++++-- context_test.go | 64 +++++++++++++++++++++++++++++++++++++++++++------ group.go | 18 ++++---------- group_test.go | 45 ---------------------------------- 4 files changed, 90 insertions(+), 68 deletions(-) diff --git a/context.go b/context.go index 92daf4d..66e1bba 100644 --- a/context.go +++ b/context.go @@ -55,20 +55,47 @@ func (cg *ContextGroup) NewGroup(path string) *ContextGroup { return cg.NewContextGroup(path) } +func (cg *ContextGroup) wrapHandler(path string, handler HandlerFunc) HandlerFunc { + if len(cg.group.stack) > 0 { + handler = handlerWithMiddlewares(handler, cg.group.stack) + } + + //add the context data after adding all middleware + fullPath := cg.group.path + path + return func(writer http.ResponseWriter, request *http.Request, m map[string]string) { + routeData := &contextData{ + route: fullPath, + params: m, + } + request = request.WithContext(AddRouteDataToContext(request.Context(), routeData)) + handler(writer, request, m) + } +} + // Handle allows handling HTTP requests via an http.HandlerFunc, as opposed to an httptreemux.HandlerFunc. // Any parameters from the request URL are stored in a map[string]string in the request's context. func (cg *ContextGroup) Handle(method, path string, handler http.HandlerFunc) { - cg.group.Handle(method, path, func(w http.ResponseWriter, r *http.Request, params map[string]string) { + cg.group.mux.mutex.Lock() + defer cg.group.mux.mutex.Unlock() + + wrapped := cg.wrapHandler(path, func(w http.ResponseWriter, r *http.Request, params map[string]string) { handler(w, r) }) + + cg.group.addFullStackHandler(method, path, wrapped) } // Handler allows handling HTTP requests via an http.Handler interface, as opposed to an httptreemux.HandlerFunc. // Any parameters from the request URL are stored in a map[string]string in the request's context. func (cg *ContextGroup) Handler(method, path string, handler http.Handler) { - cg.group.Handle(method, path, func(w http.ResponseWriter, r *http.Request, params map[string]string) { + cg.group.mux.mutex.Lock() + defer cg.group.mux.mutex.Unlock() + + wrapped := cg.wrapHandler(path, func(w http.ResponseWriter, r *http.Request, params map[string]string) { handler.ServeHTTP(w, r) }) + + cg.group.addFullStackHandler(method, path, wrapped) } // GET is convenience method for handling GET requests on a context group. diff --git a/context_test.go b/context_test.go index fa4b230..08e8384 100644 --- a/context_test.go +++ b/context_test.go @@ -43,24 +43,24 @@ func TestContextParams(t *testing.T) { } func TestContextRoute(t *testing.T) { - tests := []struct{ + tests := []struct { name, expectedRoute string - } { + }{ { - name: "basic", + name: "basic", expectedRoute: "/base/path", }, { - name: "params", + name: "params", expectedRoute: "/base/path/:id/items/:itemid", }, { - name: "catch-all", + name: "catch-all", expectedRoute: "/base/*path", }, { - name: "empty", + name: "empty", expectedRoute: "", }, } @@ -140,6 +140,10 @@ func testContextGroupMethods(t *testing.T, reqGen RequestCreator, headCanUseGet } ctxData := ContextData(r.Context()) + if ctxData == nil { + t.Fatal("context did not contain ContextData") + } + v, ok = ctxData.Params()["param"] if hasParam && !ok { t.Error("missing key 'param' in context from ContextData") @@ -371,7 +375,7 @@ func TestAddDataToContext(t *testing.T) { } ctx := AddRouteDataToContext(context.Background(), &contextData{ - route: expectedRoute, + route: expectedRoute, params: expectedParams, }) @@ -416,3 +420,49 @@ func TestAddRouteToContext(t *testing.T) { t.Error("failed to retrieve context data") } } + +func TestContextDataWithMiddleware(t *testing.T) { + wantRoute := "/foo/:id/bar" + wantParams := map[string]string{ + "id": "15", + } + + validateRequestAndParams := func(request *http.Request, params map[string]string, location string) { + data := ContextData(request.Context()) + if data == nil { + t.Fatalf("ContextData returned nil in %s", location) + } + if data.Route() != wantRoute { + t.Errorf("Unexpected route in %s. Got %s", location, data.Route()) + } + if !reflect.DeepEqual(data.Params(), wantParams) { + t.Errorf("Unexpected context params in %s. Got %+v", location, data.Params()) + } + if !reflect.DeepEqual(params, wantParams) { + t.Errorf("Unexpected handler params in %s. Got %+v", location, params) + } + } + + router := NewContextMux() + router.Use(func(next HandlerFunc) HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request, m map[string]string) { + t.Log("Testing Middleware") + validateRequestAndParams(request, m, "middleware") + next(writer, request, m) + } + }) + + router.GET(wantRoute, func(writer http.ResponseWriter, request *http.Request) { + t.Log("Testing handler") + validateRequestAndParams(request, ContextParams(request.Context()), "handler") + writer.WriteHeader(http.StatusOK) + }) + + w := httptest.NewRecorder() + r, _ := http.NewRequest(http.MethodGet, "/foo/15/bar", nil) + router.ServeHTTP(w, r) + + if w.Code != http.StatusOK { + t.Fatalf("unexpected status code. got %d", w.Code) + } +} diff --git a/group.go b/group.go index 07ce3d7..b10bc64 100644 --- a/group.go +++ b/group.go @@ -16,17 +16,6 @@ func handlerWithMiddlewares(handler HandlerFunc, stack []MiddlewareFunc) Handler return handler } -func handlerWithContextData(next HandlerFunc, fullPath string) HandlerFunc { - return func(writer http.ResponseWriter, request *http.Request, m map[string]string) { - routeData := &contextData{ - route: fullPath, - params: m, - } - request = request.WithContext(AddRouteDataToContext(request.Context(), routeData)) - next(writer, request, m) - } -} - type Group struct { path string mux *TreeMux @@ -149,10 +138,10 @@ func (g *Group) Handle(method string, path string, handler HandlerFunc) { handler = handlerWithMiddlewares(handler, g.stack) } - //add the context data after adding all middleware - fullPath := g.path + path - handler = handlerWithContextData(handler, fullPath) + g.addFullStackHandler(method, path, handler) +} +func (g *Group) addFullStackHandler(method string, path string, handler HandlerFunc) { addSlash := false addOne := func(thePath string) { node := g.mux.root.addPath(thePath[1:], nil, false) @@ -190,6 +179,7 @@ func (g *Group) Handle(method string, path string, handler HandlerFunc) { } addOne(path) + } // Syntactic sugar for Handle("GET", path, handler) diff --git a/group_test.go b/group_test.go index c75fce3..9065d5c 100644 --- a/group_test.go +++ b/group_test.go @@ -3,7 +3,6 @@ package httptreemux import ( "net/http" "net/http/httptest" - "reflect" "testing" ) @@ -166,47 +165,3 @@ func TestSetGetAfterHead(t *testing.T) { testMethod("HEAD", "HEAD") testMethod("GET", "GET") } - -func TestContextDataWithMiddleware(t *testing.T) { - wantRoute := "/foo/:id/bar" - wantParams := map[string]string{ - "id": "15", - } - - validateRequestAndParams := func(request *http.Request, params map[string]string, location string) { - data := ContextData(request.Context()) - if data == nil { - t.Fatalf("ContextData returned nil in %s", location) - } - if data.Route() != wantRoute { - t.Errorf("Unexpected route in %s. Got %s", location, data.Route()) - } - if !reflect.DeepEqual(data.Params(), wantParams) { - t.Errorf("Unexpected context params in %s. Got %+v", location, data.Params()) - } - if !reflect.DeepEqual(params, wantParams) { - t.Errorf("Unexpected handler params in %s. Got %+v", location, params) - } - } - - router := New() - router.Use(func(next HandlerFunc) HandlerFunc { - return func(writer http.ResponseWriter, request *http.Request, m map[string]string) { - validateRequestAndParams(request, m, "middleware") - next(writer, request, m) - } - }) - - router.GET(wantRoute, func(writer http.ResponseWriter, request *http.Request, m map[string]string) { - validateRequestAndParams(request, m, "handler") - writer.WriteHeader(http.StatusOK) - }) - - w := httptest.NewRecorder() - r, _ := http.NewRequest(http.MethodGet, "/foo/15/bar", nil) - router.ServeHTTP(w, r) - - if w.Code != http.StatusOK { - t.Fatalf("unexpected status code. got %d", w.Code) - } -}