Skip to content

Commit

Permalink
Merge pull request #7 from louieliu97/ll-add-path-template
Browse files Browse the repository at this point in the history
Add the ability to store and retrieve the path template
  • Loading branch information
shellfu authored Sep 12, 2024
2 parents 426617d + f97fd11 commit c54b427
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 10 deletions.
23 changes: 19 additions & 4 deletions route.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package muxer

import (
"errors"
"net/http"
"regexp"
)
Expand All @@ -11,10 +12,11 @@ It contains the regular expression that matches the request path, the HTTP metho
the handler to be executed for that request, and the parameter names extracted from the path.
*/
type Route struct {
path *regexp.Regexp
method string
handler http.Handler
params []string
path *regexp.Regexp
method string
handler http.Handler
params []string
template string
}

func (r *Route) match(path string) map[string]string {
Expand All @@ -30,3 +32,16 @@ func (r *Route) match(path string) map[string]string {

return params
}

// PathTemplate retrieves the path template of the current route
func (r *Route) PathTemplate() (string, error) {
if r == nil {
return "", errors.New("route is nil, no template")
}

if r.template == "" {
return r.template, errors.New("template is empty")
}

return r.template, nil
}
27 changes: 21 additions & 6 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ type contextKey string
const (
// ParamsKey is the key used to store the extracted parameters in the request context.
ParamsKey contextKey = "params"
// RouteContextKey is the key used to store the matched route in the request context
RouteContextKey contextKey = "matched_route"
)

/*
Expand Down Expand Up @@ -131,19 +133,20 @@ func (r *Router) HandleRoute(method, path string, handler http.HandlerFunc) {
// Parse path to extract parameter names
paramNames := make([]string, 0)
re := regexp.MustCompile(`:([\w-]+)`)
path = re.ReplaceAllStringFunc(path, func(m string) string {
pathRegex := re.ReplaceAllStringFunc(path, func(m string) string {
paramName := m[1:]
paramNames = append(paramNames, paramName)
return `([-\w.]+)`
})

exactPath := regexp.MustCompile("^" + path + "$")
exactPath := regexp.MustCompile("^" + pathRegex + "$")

r.routes = append(r.routes, Route{
method: method,
path: exactPath,
handler: handler,
params: paramNames,
method: method,
path: exactPath,
handler: handler,
params: paramNames,
template: path,
})
}

Expand Down Expand Up @@ -208,6 +211,7 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {

ctx := req.Context()
ctx = context.WithValue(ctx, ParamsKey, params)
ctx = context.WithValue(ctx, RouteContextKey, &route)

handler := route.handler
for i := len(r.middleware) - 1; i >= 0; i-- {
Expand Down Expand Up @@ -256,3 +260,14 @@ the given order before executing the main handler.
func (r *Router) Use(middleware ...func(http.Handler) http.Handler) {
r.middleware = append(r.middleware, middleware...)
}

// CurrentRoute returns the matched route for the current request, if any.
// This only works when called inside the handler of the matched route
// because the matched route is stored inside the request's context,
// which is wiped after the handler returns.
func CurrentRoute(r *http.Request) *Route {
if rv := r.Context().Value(RouteContextKey); rv != nil {
return rv.(*Route)
}
return nil
}
100 changes: 100 additions & 0 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package muxer

import (
"context"
"errors"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -519,3 +520,102 @@ func TestEnableCORSOption(t *testing.T) {
})
}
}

func TestPathTemplate(t *testing.T) {
tests := []struct {
name string
route *Route
expectedOutput string
expectedError error
}{
{
name: "Error with nil Route",
route: nil,
expectedOutput: "",
expectedError: errors.New("route is nil, no template"),
},
{
name: "Error with empty template",
route: &Route{template: ""},
expectedOutput: "",
expectedError: errors.New("template is empty"),
},
{
name: "Valid Route with Template and path param",
route: &Route{template: "/users/:id"},
expectedOutput: "/users/:id",
expectedError: nil,
},
{
name: "Valid Route with simple Template",
route: &Route{template: "/metrics"},
expectedOutput: "/metrics",
expectedError: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
output, err := tt.route.PathTemplate()

if tt.expectedOutput != output {
t.Errorf("expected output %v, got %v", tt.expectedOutput, output)
}
if tt.expectedError != nil {
if tt.expectedError.Error() != err.Error() {
t.Errorf("expected error %v, got %v", tt.expectedError, err)
}
} else {
if err != nil {
t.Errorf("expected error to be nil, got %v", err)
}
}
})
}
}

func TestCurrentRoute(t *testing.T) {
route := &Route{template: "/users/:id"}

tests := []struct {
name string
contextKey interface{}
contextValue interface{}
expectedRoute *Route
}{
{
name: "Route in context",
contextKey: RouteContextKey,
contextValue: route,
expectedRoute: route,
},
{
name: "No route in context",
contextKey: "some_other_key",
contextValue: "some_value",
expectedRoute: nil,
},
{
name: "Empty context",
contextKey: nil,
contextValue: nil,
expectedRoute: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodGet, "/users/123", nil)

if tt.contextKey != nil {
req = req.WithContext(context.WithValue(req.Context(), tt.contextKey, tt.contextValue))
}

result := CurrentRoute(req)

if tt.expectedRoute != result {
t.Errorf("expected route %v got %v", tt.expectedRoute, result)
}
})
}
}

0 comments on commit c54b427

Please sign in to comment.