From 99ad97ff6f4861f89329a898edced62e2262f1ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Hamr=C3=A9n?= Date: Thu, 16 Apr 2020 14:10:37 +0200 Subject: [PATCH] Route recursive-search regexp patterns (#506) --- mux_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ tree.go | 29 ++++++++++++++++++++++------- tree_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 114 insertions(+), 7 deletions(-) diff --git a/mux_test.go b/mux_test.go index add3857c..4cf96a2f 100644 --- a/mux_test.go +++ b/mux_test.go @@ -1466,6 +1466,47 @@ func TestMuxRegexp2(t *testing.T) { } } +func TestMuxRegexp3(t *testing.T) { + r := NewRouter() + r.Get("/one/{firstId:[a-z0-9-]+}/{secondId:[a-z]+}/first", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("first")) + }) + r.Get("/one/{firstId:[a-z0-9-_]+}/{secondId:[0-9]+}/second", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("second")) + }) + r.Delete("/one/{firstId:[a-z0-9-_]+}/{secondId:[0-9]+}/second", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("third")) + }) + + r.Route("/one", func(r Router) { + r.Get("/{dns:[a-z-0-9_]+}", func(writer http.ResponseWriter, request *http.Request) { + writer.Write([]byte("_")) + }) + r.Get("/{dns:[a-z-0-9_]+}/info", func(writer http.ResponseWriter, request *http.Request) { + writer.Write([]byte("_")) + }) + r.Delete("/{id:[0-9]+}", func(writer http.ResponseWriter, request *http.Request) { + writer.Write([]byte("forth")) + }) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/one/hello/peter/first", nil); body != "first" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/one/hithere/123/second", nil); body != "second" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "DELETE", "/one/hithere/123/second", nil); body != "third" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "DELETE", "/one/123", nil); body != "forth" { + t.Fatalf(body) + } +} + func TestMuxContextIsThreadSafe(t *testing.T) { router := NewRouter() router.Get("/{id}", func(w http.ResponseWriter, r *http.Request) { diff --git a/tree.go b/tree.go index fa9bd4bd..a7e29c5b 100644 --- a/tree.go +++ b/tree.go @@ -417,8 +417,6 @@ func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node { continue } - found := false - // serially loop through each node grouped by the tail delimiter for idx := 0; idx < len(nds); idx++ { xn = nds[idx] @@ -443,16 +441,33 @@ func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node { continue } + prevlen := len(rctx.routeParams.Values) rctx.routeParams.Values = append(rctx.routeParams.Values, xsearch[:p]) xsearch = xsearch[p:] - found = true - break - } - if !found { - rctx.routeParams.Values = append(rctx.routeParams.Values, "") + if len(xsearch) == 0 { + if xn.isLeaf() { + h := xn.endpoints[method] + if h != nil && h.handler != nil { + rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...) + return xn + } + } + } + + // recursively find the next node on this branch + fin := xn.findRoute(rctx, method, xsearch) + if fin != nil { + return fin + } + + // not found on this branch, reset vars + rctx.routeParams.Values = rctx.routeParams.Values[:prevlen] + xsearch = search } + rctx.routeParams.Values = append(rctx.routeParams.Values, "") + default: // catch-all nodes rctx.routeParams.Values = append(rctx.routeParams.Values, search) diff --git a/tree_test.go b/tree_test.go index 70d642a6..1017efbf 100644 --- a/tree_test.go +++ b/tree_test.go @@ -333,6 +333,57 @@ func TestTreeRegexp(t *testing.T) { } } +func TestTreeRegexpRecursive(t *testing.T) { + hStub1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + tr := &node{} + tr.InsertRoute(mGET, "/one/{firstId:[a-z0-9-]+}/{secondId:[a-z0-9-]+}/first", hStub1) + tr.InsertRoute(mGET, "/one/{firstId:[a-z0-9-_]+}/{secondId:[a-z0-9-_]+}/second", hStub2) + + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + // debugPrintTree(0, 0, tr, 0) + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + + tests := []struct { + r string // input request path + h http.Handler // output matched handler + k []string // output param keys + v []string // output param values + }{ + {r: "/one/hello/world/first", h: hStub1, k: []string{"firstId", "secondId"}, v: []string{"hello", "world"}}, + {r: "/one/hi_there/ok/second", h: hStub2, k: []string{"firstId", "secondId"}, v: []string{"hi_there", "ok"}}, + {r: "/one///first", h: nil, k: []string{}, v: []string{}}, + {r: "/one/hi/123/second", h: hStub2, k: []string{"firstId", "secondId"}, v: []string{"hi", "123"}}, + } + + for i, tt := range tests { + rctx := NewRouteContext() + + _, handlers, _ := tr.FindRoute(rctx, mGET, tt.r) + + var handler http.Handler + if methodHandler, ok := handlers[mGET]; ok { + handler = methodHandler.handler + } + + paramKeys := rctx.routeParams.Keys + paramValues := rctx.routeParams.Values + + if fmt.Sprintf("%v", tt.h) != fmt.Sprintf("%v", handler) { + t.Errorf("input [%d]: find '%s' expecting handler:%v , got:%v", i, tt.r, tt.h, handler) + } + if !stringSliceEqual(tt.k, paramKeys) { + t.Errorf("input [%d]: find '%s' expecting paramKeys:(%d)%v , got:(%d)%v", i, tt.r, len(tt.k), tt.k, len(paramKeys), paramKeys) + } + if !stringSliceEqual(tt.v, paramValues) { + t.Errorf("input [%d]: find '%s' expecting paramValues:(%d)%v , got:(%d)%v", i, tt.r, len(tt.v), tt.v, len(paramValues), paramValues) + } + } +} + func TestTreeRegexMatchWholeParam(t *testing.T) { hStub1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})