diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 0000000..2446707 --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,31 @@ +# This workflow will build a golang project +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go + +name: Go Test + +on: + push: + branches: + - master + pull_request: + branches: + - master + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: 1.21 + + - name: Run Test + run: go test -race -coverprofile=coverage.txt -covermode=atomic ./... + + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v3 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/README.md b/README.md index 5b75c5b..a337519 100644 --- a/README.md +++ b/README.md @@ -1 +1,356 @@ -# web \ No newline at end of file +# web +[![GoDoc][1]][2] [![Build Status][7]][8] [![Codecov][9]][10] [![Release][5]][6] [![license-Apache 2][3]][4] + +[1]: https://godoc.org/go-spring.dev/web?status.svg +[2]: https://godoc.org/go-spring.dev/web +[3]: https://img.shields.io/badge/license-Apache%202-blue.svg +[4]: LICENSE +[5]: https://img.shields.io/github/v/release/go-spring-projects/web?color=orange +[6]: https://github.com/go-spring-projects/web/releases/latest +[7]: https://github.com/go-spring-projects/web/workflows/Go%20Test/badge.svg?branch=master +[8]: https://github.com/go-spring-projects/web/actions?query=branch%3Amaster +[9]: https://codecov.io/gh/go-spring-projects/web/graph/badge.svg?token=BQ6OKWWOF0 +[10]: https://codecov.io/gh/go-spring-projects/web + +The `web` package aims to provide a simpler and more user-friendly development experience. + +*Note: This package does not depend on the go-spring* + +## Install + +`go get go-spring.dev/web@latest` + +## Features: + +* Automatically bind models based on `ContentType`. +* Automatically output based on function return type. +* binding from `path/header/cookie/form/body`. +* Support binding files for easier file uploads handling. +* Support customizing global output formats and route-level custom output. +* Support custom parameter validators. +* Support handler converter, adding the above capabilities with just one line of code for all http servers based on the standard library solution. +* Support for middlewares based on chain of responsibility. + +## Router + +web router is based on a kind of [Patricia Radix trie](https://en.wikipedia.org/wiki/Radix_tree). The router is compatible with net/http. + +Router interface: + +```go +// Router registers routes to be matched and dispatches a handler. +// +type Router interface { + // Handler dispatches the handler registered in the matched route. + http.Handler + + // Use appends a MiddlewareFunc to the chain. + Use(mwf ...MiddlewareFunc) Router + + // Renderer to be used Response renderer in default. + Renderer(renderer Renderer) Router + + // Group creates a new router group. + Group(pattern string, fn ...func(r Router)) Router + + // Handle registers a new route with a matcher for the URL pattern. + Handle(pattern string, handler http.Handler) + + // HandleFunc registers a new route with a matcher for the URL pattern. + HandleFunc(pattern string, handler http.HandlerFunc) + + // Any registers a route that matches all the HTTP methods. + // GET, POST, PUT, PATCH, HEAD, OPTIONS, DELETE, CONNECT, TRACE. + Any(pattern string, handler interface{}) + + // Get registers a new GET route with a matcher for the URL path of the get method. + Get(pattern string, handler interface{}) + + // Head registers a new HEAD route with a matcher for the URL path of the head method. + Head(pattern string, handler interface{}) + + // Post registers a new POST route with a matcher for the URL path of the post method. + Post(pattern string, handler interface{}) + + // Put registers a new PUT route with a matcher for the URL path of the put method. + Put(pattern string, handler interface{}) + + // Patch registers a new PATCH route with a matcher for the URL path of the patch method. + Patch(pattern string, handler interface{}) + + // Delete registers a new DELETE route with a matcher for the URL path of the delete method. + Delete(pattern string, handler interface{}) + + // Connect registers a new CONNECT route with a matcher for the URL path of the connect method. + Connect(pattern string, handler interface{}) + + // Options registers a new OPTIONS route with a matcher for the URL path of the options method. + Options(pattern string, handler interface{}) + + // Trace registers a new TRACE route with a matcher for the URL path of the trace method. + Trace(pattern string, handler interface{}) + + // NotFound to be used when no route matches. + NotFound(handler http.HandlerFunc) + + // MethodNotAllowed to be used when the request method does not match the route. + MethodNotAllowed(handler http.HandlerFunc) +} +``` + + +## Quick start + +### HelloWorld + +```go +package main + +import ( + "context" + "net/http" + + "go-spring.dev/web" +) + +func main() { + var router = web.NewRouter() + + router.Get("/greeting", func(ctx context.Context) string { + return "greeting!!!" + }) + + http.ListenAndServe(":8080", router) +} +``` + +### Adaptation standard library + +Supported function forms to be converted to `http.HandlerFunc`: + +```go +// Bind convert fn to HandlerFunc. +// +// func(ctx context.Context) +// +// func(ctx context.Context) R +// +// func(ctx context.Context) error +// +// func(ctx context.Context, req T) R +// +// func(ctx context.Context, req T) error +// +// func(ctx context.Context, req T) (R, error) +// +func Bind(fn interface{}, render Renderer) http.HandlerFunc +``` + +An example based std http server: + +```go +package main + +import ( + "context" + "log/slog" + "mime/multipart" + "net/http" + + "go-spring.dev/web" +) + +func main() { + http.Handle("/user/register", web.Bind(UserRegister, web.JsonRender())) + + http.ListenAndServe(":8080", nil) +} + +type UserRegisterModel struct { + Username string `form:"username"` // username + Password string `form:"password"` // password + Avatar *multipart.FileHeader `form:"avatar"` // avatar + Captcha string `form:"captcha"` // captcha + UserAgent string `header:"User-Agent"` // user agent + Ad string `query:"ad"` // advertising ID + Token string `cookie:"token"` // token +} + +func UserRegister(ctx context.Context, req UserRegisterModel) string { + slog.Info("user register", + slog.String("username", req.Username), + slog.String("password", req.Password), + slog.String("captcha", req.Captcha), + slog.String("userAgent", req.UserAgent), + slog.String("ad", req.Ad), + slog.String("token", req.Token), + ) + return "success" +} +``` + +### Custom validator + +Allows you to register a custom value validator. If the value verification fails, request processing aborts. + +In this example, we will use [go-validator/validator](https://github.com/go-validator/validator), you can refer to this example to register your custom validator. + +```go +package main + +import ( + "context" + "log/slog" + "mime/multipart" + "net/http" + + "go-spring.dev/web" + "go-spring.dev/web/binding" + "gopkg.in/validator.v2" +) + +var validatorInst = validator.NewValidator().WithTag("validate") + +func main() { + binding.RegisterValidator(func(i interface{}) error { + return validatorInst.Validate(i) + }) + + var router = web.NewRouter() + router.Post("/user/register", UserRegister) + + http.ListenAndServe(":8080", router) +} + +type UserRegisterModel struct { + Username string `form:"username" validate:"min=6,max=20"` // username + Password string `form:"password" validate:"min=10,max=20"` // password + Avatar *multipart.FileHeader `form:"avatar" validate:"nonzero"` // avatar + Captcha string `form:"captcha" validate:"min=4,max=4"` // captcha + UserAgent string `header:"User-Agent"` // user agent + Ad string `query:"ad"` // advertising ID + Token string `cookie:"token"` // token +} + +func UserRegister(ctx context.Context, req UserRegisterModel) string { + slog.Info("user register", + slog.String("username", req.Username), + slog.String("password", req.Password), + slog.String("captcha", req.Captcha), + slog.String("userAgent", req.UserAgent), + slog.String("ad", req.Ad), + slog.String("token", req.Token), + ) + return "success" +} + +``` + +### Middlewares + +Compatible with middlewares based on standard library solutions. + +```go +package main + +import ( + "context" + "log/slog" + "net/http" + "time" + + "go-spring.dev/web" +) + +func main() { + var router = web.NewRouter() + + // access log + router.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + t1 := time.Now() + next.ServeHTTP(writer, request) + slog.Info("access log", slog.String("path", request.URL.Path), slog.String("method", request.Method), slog.Duration("cost", time.Since(t1))) + }) + }) + + // cors + router.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + writer.Header().Set("Access-Control-Allow-Origin", "*") + writer.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS") + writer.Header().Set("Access-Control-Allow-Headers", "Origin, Content-Type") + + // preflight request + if request.Method == http.MethodOptions { + writer.WriteHeader(http.StatusNoContent) + return + } + + next.ServeHTTP(writer, request) + }) + }) + + router.Group("/public", func(r web.Router) { + r.Post("/register", func(ctx context.Context) string { return "register: do something" }) + r.Post("/forgot", func(ctx context.Context) string { return "forgot: do something" }) + r.Post("/login", func(ctx context.Context, req struct { + Username string `form:"username"` + Password string `form:"password"` + }) error { + if "admin" == req.Username && "admin123" == req.Password { + web.FromContext(ctx).SetCookie("token", req.Username, 600, "/", "", false, false) + return nil + } + return web.Error(400, "login failed") + }) + }) + + router.Group("/user", func(r web.Router) { + + // user login check + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + // check login state in cookies + // + if _, err := request.Cookie("token"); nil != err { + writer.WriteHeader(http.StatusForbidden) + return + } + + // login check success + next.ServeHTTP(writer, request) + }) + }) + + r.Get("/userInfo", func(ctx context.Context) interface{} { + // TODO: load user from database + // + return map[string]interface{}{ + "username": "admin", + "time": time.Now().String(), + } + }) + + r.Get("/logout", func(ctx context.Context) string { + // delete cookie + web.FromContext(ctx).SetCookie("token", "", -1, "/", "", false, false) + return "success" + }) + + }) + + http.ListenAndServe(":8080", router) +} + +``` + +## Acknowledgments + +* https://github.com/go-chi/chi +* https://github.com/gin-gonic/gin +* https://github.com/lvan100 + +### License + +The repository released under version 2.0 of the Apache License. \ No newline at end of file diff --git a/bind.go b/bind.go new file mode 100644 index 0000000..d3f9487 --- /dev/null +++ b/bind.go @@ -0,0 +1,260 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package web + +import ( + "context" + "errors" + "fmt" + "net/http" + "reflect" + + "go-spring.dev/web/binding" +) + +type Renderer interface { + Render(ctx *Context, err error, result interface{}) +} + +type RendererFunc func(ctx *Context, err error, result interface{}) + +func (fn RendererFunc) Render(ctx *Context, err error, result interface{}) { + fn(ctx, err, result) +} + +// Bind convert fn to HandlerFunc. +// +// func(ctx context.Context) +// +// func(ctx context.Context) R +// +// func(ctx context.Context) error +// +// func(ctx context.Context, req T) R +// +// func(ctx context.Context, req T) error +// +// func(ctx context.Context, req T) (R, error) +// +// func(writer http.ResponseWriter, request *http.Request) +func Bind(fn interface{}, render Renderer) http.HandlerFunc { + + fnValue := reflect.ValueOf(fn) + fnType := fnValue.Type() + + switch h := fn.(type) { + case http.HandlerFunc: + return warpContext(h) + case http.Handler: + return warpContext(h.ServeHTTP) + case func(http.ResponseWriter, *http.Request): + return warpContext(h) + default: + // valid func + if err := validMappingFunc(fnType); nil != err { + panic(err) + } + } + + firstOutIsErrorType := 1 == fnType.NumOut() && isErrorType(fnType.Out(0)) + + return func(writer http.ResponseWriter, request *http.Request) { + + // param of context + webCtx := &Context{Writer: writer, Request: request} + ctx := WithContext(request.Context(), webCtx) + + defer func() { + if nil != request.MultipartForm { + _ = request.MultipartForm.RemoveAll() + } + _ = request.Body.Close() + }() + + var returnValues []reflect.Value + var err error + + defer func() { + if r := recover(); nil != r { + if e, ok := r.(error); ok { + err = fmt.Errorf("%s: %w", request.URL.Path, e) + } else { + err = fmt.Errorf("%s: %v", request.URL.Path, r) + } + + // render error response + render.Render(webCtx, err, nil) + } + }() + + ctxValue := reflect.ValueOf(ctx) + + switch fnType.NumIn() { + case 1: + returnValues = fnValue.Call([]reflect.Value{ctxValue}) + case 2: + paramType := fnType.In(1) + pointer := false + if reflect.Ptr == paramType.Kind() { + paramType = paramType.Elem() + pointer = true + } + + // new param instance with paramType. + paramValue := reflect.New(paramType) + // bind paramValue with request + if err = binding.Bind(paramValue.Interface(), webCtx); nil != err { + break + } + if !pointer { + paramValue = paramValue.Elem() + } + returnValues = fnValue.Call([]reflect.Value{ctxValue, paramValue}) + default: + panic("unreachable here") + } + + var result interface{} + + if nil == err { + switch len(returnValues) { + case 0: + // nothing + return + case 1: + if firstOutIsErrorType { + err, _ = returnValues[0].Interface().(error) + } else { + result = returnValues[0].Interface() + } + case 2: + // check error + result = returnValues[0].Interface() + err, _ = returnValues[1].Interface().(error) + default: + panic("unreachable here") + } + } + + // render response + render.Render(webCtx, err, result) + } +} + +func validMappingFunc(fnType reflect.Type) error { + // func(ctx context.Context) + // func(ctx context.Context) R + // func(ctx context.Context) error + // func(ctx context.Context, req T) R + // func(ctx context.Context, req T) error + // func(ctx context.Context, req T) (R, error) + if !isFuncType(fnType) { + return fmt.Errorf("%s: not a func", fnType.String()) + } + + if fnType.NumIn() < 1 || fnType.NumIn() > 2 { + return fmt.Errorf("%s: expect func(ctx context.Context, [T]) [R, error]", fnType.String()) + } + + if fnType.NumOut() > 2 { + return fmt.Errorf("%s: expect func(ctx context.Context, [T]) [(R, error)]", fnType.String()) + } + + if !isContextType(fnType.In(0)) { + return fmt.Errorf("%s: expect func(ctx context.Context, [T]) [(R, error)", fnType.String()) + } + + if fnType.NumIn() > 1 { + argType := fnType.In(1) + if !(reflect.Struct == argType.Kind() || (reflect.Ptr == argType.Kind() && reflect.Struct == argType.Elem().Kind())) { + return fmt.Errorf("%s: input param type (%s) must be struct/*struct", fnType.String(), argType.String()) + } + } + + switch fnType.NumOut() { + case 0: // nothing + case 1: // R | error + case 2: // (R, error) + if isErrorType(fnType.Out(0)) { + return fmt.Errorf("%s: expect func(...) (R, error)", fnType.String()) + } + + if !isErrorType(fnType.Out(1)) { + return fmt.Errorf("%s: expect func(...) (R, error)", fnType.String()) + } + } + + return nil +} + +func warpContext(handler http.HandlerFunc) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + webCtx := &Context{Writer: writer, Request: request} + handler.ServeHTTP(writer, request.WithContext(WithContext(request.Context(), webCtx))) + } +} + +// errorType the reflection type of error. +var errorType = reflect.TypeOf((*error)(nil)).Elem() + +// contextType the reflection type of context.Context. +var contextType = reflect.TypeOf((*context.Context)(nil)).Elem() + +// IsFuncType returns whether `t` is func type. +func isFuncType(t reflect.Type) bool { + return t.Kind() == reflect.Func +} + +// IsErrorType returns whether `t` is error type. +func isErrorType(t reflect.Type) bool { + return t == errorType || t.Implements(errorType) +} + +// IsContextType returns whether `t` is context.Context type. +func isContextType(t reflect.Type) bool { + return t == contextType || t.Implements(contextType) +} + +// JsonRender is default Render +func JsonRender() RendererFunc { + return func(ctx *Context, err error, result interface{}) { + var code = 0 + var message = "" + if nil != err { + var e HttpError + if errors.As(err, &e) { + code = e.Code + message = e.Message + } else { + code = http.StatusInternalServerError + message = err.Error() + + if errors.Is(err, binding.ErrBinding) || errors.Is(err, binding.ErrValidate) { + code = http.StatusBadRequest + } + } + } + + type response struct { + Code int `json:"code"` + Message string `json:"message,omitempty"` + Data interface{} `json:"data"` + } + + ctx.JSON(http.StatusOK, response{Code: code, Message: message, Data: result}) + } +} diff --git a/bind_test.go b/bind_test.go new file mode 100644 index 0000000..2c20d7d --- /dev/null +++ b/bind_test.go @@ -0,0 +1,80 @@ +package web + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "os" + "reflect" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsErrorType(t *testing.T) { + err := fmt.Errorf("error") + assert.True(t, isErrorType(reflect.TypeOf(err))) + err = os.ErrClosed + assert.True(t, isErrorType(reflect.TypeOf(err))) +} + +func TestIsContextType(t *testing.T) { + ctx := context.TODO() + assert.True(t, isContextType(reflect.TypeOf(ctx))) + ctx = context.WithValue(context.TODO(), "a", "3") + assert.True(t, isContextType(reflect.TypeOf(ctx))) +} + +func TestBindWithoutParams(t *testing.T) { + + var handler = func(ctx context.Context) string { + webCtx := FromContext(ctx) + assert.NotNil(t, webCtx) + return "0987654321" + } + + request := httptest.NewRequest(http.MethodGet, "/get", strings.NewReader("{}")) + response := httptest.NewRecorder() + Bind(handler, JsonRender())(response, request) + assert.Equal(t, response.Body.String(), "{\"code\":0,\"data\":\"0987654321\"}\n") +} + +func TestBindWithParams(t *testing.T) { + var handler = func(ctx context.Context, req struct { + Username string `json:"username"` + Password string `json:"password"` + }) string { + webCtx := FromContext(ctx) + assert.NotNil(t, webCtx) + assert.Equal(t, req.Username, "aaa") + assert.Equal(t, req.Password, "88888888") + return "success" + } + + request := httptest.NewRequest(http.MethodPost, "/post", strings.NewReader(`{"username": "aaa", "password": "88888888"}`)) + request.Header.Add("Content-Type", "application/json") + response := httptest.NewRecorder() + Bind(handler, JsonRender())(response, request) + assert.Equal(t, response.Body.String(), "{\"code\":0,\"data\":\"success\"}\n") +} + +func TestBindWithParamsAndError(t *testing.T) { + var handler = func(ctx context.Context, req struct { + Username string `json:"username"` + Password string `json:"password"` + }) (string, error) { + webCtx := FromContext(ctx) + assert.NotNil(t, webCtx) + assert.Equal(t, req.Username, "aaa") + assert.Equal(t, req.Password, "88888888") + return "requestid: 9999999", Error(403, "user locked") + } + + request := httptest.NewRequest(http.MethodPost, "/post", strings.NewReader(`{"username": "aaa", "password": "88888888"}`)) + request.Header.Add("Content-Type", "application/json") + response := httptest.NewRecorder() + Bind(handler, JsonRender())(response, request) + assert.Equal(t, response.Body.String(), "{\"code\":403,\"message\":\"user locked\",\"data\":\"requestid: 9999999\"}\n") +} diff --git a/binding/binding.go b/binding/binding.go new file mode 100644 index 0000000..933e199 --- /dev/null +++ b/binding/binding.go @@ -0,0 +1,210 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package binding ... +package binding + +import ( + "errors" + "fmt" + "io" + "mime" + "mime/multipart" + "net/url" + "reflect" + "strconv" + "strings" +) + +var ErrBinding = errors.New("binding failed") +var ErrValidate = errors.New("validate failed") + +const ( + MIMEApplicationJSON = "application/json" + MIMEApplicationXML = "application/xml" + MIMETextXML = "text/xml" + MIMEApplicationForm = "application/x-www-form-urlencoded" + MIMEMultipartForm = "multipart/form-data" +) + +type Request interface { + ContentType() string + Header(key string) (string, bool) + Cookie(name string) (string, bool) + PathParam(name string) (string, bool) + QueryParam(name string) (string, bool) + FormParams() (url.Values, error) + MultipartParams(maxMemory int64) (*multipart.Form, error) + RequestBody() io.Reader +} + +type BindScope int + +const ( + BindScopeURI BindScope = iota + BindScopeQuery + BindScopeHeader + BindScopeCookie + BindScopeBody +) + +var scopeTags = map[BindScope]string{ + BindScopeURI: "path", + BindScopeQuery: "query", + BindScopeHeader: "header", + BindScopeCookie: "cookie", +} + +var scopeGetters = map[BindScope]func(r Request, name string) (string, bool){ + BindScopeURI: Request.PathParam, + BindScopeQuery: Request.QueryParam, + BindScopeHeader: Request.Header, + BindScopeCookie: Request.Cookie, +} + +// ValidateStruct validates a single struct. +var validateStruct func(i interface{}) error + +type BodyBinder func(i interface{}, r Request) error + +var bodyBinders = map[string]BodyBinder{ + MIMEApplicationForm: BindForm, + MIMEMultipartForm: BindMultipartForm, + MIMEApplicationJSON: BindJSON, + MIMEApplicationXML: BindXML, + MIMETextXML: BindXML, +} + +func RegisterScopeTag(scope BindScope, tag string) { + scopeTags[scope] = tag +} + +func RegisterBodyBinder(mime string, binder BodyBinder) { + bodyBinders[mime] = binder +} + +func RegisterValidator(validator func(i interface{}) error) { + validateStruct = validator +} + +// Bind checks the Method and Content-Type to select a binding engine automatically, +// Depending on the "Content-Type" header different bindings are used, for example: +// +// "application/json" --> JSON binding +// "application/xml" --> XML binding +func Bind(i interface{}, r Request) error { + if err := bindScope(i, r); err != nil { + return fmt.Errorf("%w: %v", ErrBinding, err) + } + + if err := bindBody(i, r); err != nil { + return fmt.Errorf("%w: %v", ErrBinding, err) + } + + if nil != validateStruct { + if err := validateStruct(i); nil != err { + return fmt.Errorf("%w: %v", ErrValidate, err) + } + } + return nil +} + +func bindBody(i interface{}, r Request) error { + mediaType, _, err := mime.ParseMediaType(r.ContentType()) + if nil != err && !strings.Contains(err.Error(), "mime: no media type") { + return err + } + binder, ok := bodyBinders[mediaType] + if !ok { + binder = bodyBinders[MIMEApplicationForm] + } + return binder(i, r) +} + +func bindScope(i interface{}, r Request) error { + t := reflect.TypeOf(i) + if t.Kind() != reflect.Ptr { + return fmt.Errorf("%s: is not pointer", t.String()) + } + + et := t.Elem() + if et.Kind() != reflect.Struct { + return fmt.Errorf("%s: is not a struct pointer", t.String()) + } + + ev := reflect.ValueOf(i).Elem() + for j := 0; j < ev.NumField(); j++ { + fv := ev.Field(j) + ft := et.Field(j) + for scope := BindScopeURI; scope < BindScopeBody; scope++ { + if err := bindScopeField(scope, fv, ft, r); err != nil { + return err + } + } + } + return nil +} + +func bindScopeField(scope BindScope, v reflect.Value, field reflect.StructField, r Request) error { + if tag, loaded := scopeTags[scope]; loaded { + if name, ok := field.Tag.Lookup(tag); ok && name != "-" { + if val, exists := scopeGetters[scope](r, name); exists { + if err := bindData(v, val); err != nil { + return err + } + } + } + } + return nil +} + +func bindData(v reflect.Value, val string) error { + switch v.Kind() { + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + u, err := strconv.ParseUint(val, 0, 0) + if err != nil { + return err + } + v.SetUint(u) + return nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i, err := strconv.ParseInt(val, 0, 0) + if err != nil { + return err + } + v.SetInt(i) + return nil + case reflect.Float32, reflect.Float64: + f, err := strconv.ParseFloat(val, 64) + if err != nil { + return err + } + v.SetFloat(f) + return nil + case reflect.Bool: + b, err := strconv.ParseBool(val) + if err != nil { + return err + } + v.SetBool(b) + return nil + case reflect.String: + v.SetString(val) + return nil + default: + return fmt.Errorf("unsupported binding type %q", v.Type().String()) + } +} diff --git a/binding/binding_test.go b/binding/binding_test.go new file mode 100644 index 0000000..aad1d31 --- /dev/null +++ b/binding/binding_test.go @@ -0,0 +1,121 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package binding_test + +import ( + "fmt" + "io" + "mime/multipart" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "go-spring.dev/web/binding" +) + +type MockRequest struct { + contentType string + headers map[string]string + queryParams map[string]string + pathParams map[string]string + cookies map[string]string + formParams url.Values + requestBody string +} + +var _ binding.Request = &MockRequest{} + +func (r *MockRequest) ContentType() string { + return r.contentType +} + +func (r *MockRequest) Header(key string) (string, bool) { + value, ok := r.headers[key] + return value, ok +} + +func (r *MockRequest) Cookie(name string) (string, bool) { + value, ok := r.cookies[name] + return value, ok +} + +func (r *MockRequest) QueryParam(name string) (string, bool) { + value, ok := r.queryParams[name] + return value, ok +} + +func (r *MockRequest) PathParam(name string) (string, bool) { + value, ok := r.pathParams[name] + return value, ok +} + +func (r *MockRequest) FormParams() (url.Values, error) { + return r.formParams, nil +} + +func (r *MockRequest) MultipartParams(maxMemory int64) (*multipart.Form, error) { + return nil, fmt.Errorf("not impl") +} + +func (r *MockRequest) RequestBody() io.Reader { + return strings.NewReader(r.requestBody) +} + +type ScopeBindParam struct { + A string `path:"a"` + B string `path:"b"` + C string `path:"c" query:"c"` + D string `query:"d"` + E string `query:"e" header:"e"` + F string `cookie:"f"` +} + +func TestScopeBind(t *testing.T) { + + ctx := &MockRequest{ + headers: map[string]string{ + "e": "6", + }, + queryParams: map[string]string{ + "c": "3", + "d": "4", + "e": "5", + }, + pathParams: map[string]string{ + "a": "1", + "b": "2", + }, + cookies: map[string]string{ + "f": "7", + }, + } + + expect := ScopeBindParam{ + A: "1", + B: "2", + C: "3", + D: "4", + E: "6", + F: "7", + } + + var p ScopeBindParam + err := binding.Bind(&p, ctx) + assert.Nil(t, err) + assert.Equal(t, expect, p) +} diff --git a/binding/form.go b/binding/form.go new file mode 100644 index 0000000..ead7be2 --- /dev/null +++ b/binding/form.go @@ -0,0 +1,161 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package binding + +import ( + "mime/multipart" + "net/url" + "reflect" +) + +var fileHeaderType = reflect.TypeOf((*multipart.FileHeader)(nil)) + +func BindForm(i interface{}, r Request) error { + params, err := r.FormParams() + if err != nil { + return err + } + t := reflect.TypeOf(i) + if t.Kind() != reflect.Ptr { + return nil + } + et := t.Elem() + if et.Kind() != reflect.Struct { + return nil + } + ev := reflect.ValueOf(i).Elem() + return bindFormStruct(ev, et, params) +} + +func bindFormStruct(v reflect.Value, t reflect.Type, params url.Values) error { + for j := 0; j < t.NumField(); j++ { + ft := t.Field(j) + fv := v.Field(j) + if ft.Anonymous { + if ft.Type.Kind() != reflect.Struct { + continue + } + if err := bindFormStruct(fv, ft.Type, params); nil != err { + return err + } + continue + } + name, ok := ft.Tag.Lookup("form") + if !ok || !fv.CanInterface() { + continue + } + values := params[name] + if len(values) == 0 { + continue + } + err := bindFormField(fv, ft.Type, values) + if err != nil { + return err + } + } + return nil +} + +func bindFormField(v reflect.Value, t reflect.Type, values []string) error { + if v.Kind() == reflect.Slice { + slice := reflect.MakeSlice(t, 0, len(values)) + defer func() { v.Set(slice) }() + et := t.Elem() + for _, value := range values { + ev := reflect.New(et).Elem() + if err := bindData(ev, value); nil != err { + return err + } + slice = reflect.Append(slice, ev) + } + return nil + } + return bindData(v, values[0]) +} + +func BindMultipartForm(i interface{}, r Request) error { + const defaultMaxMemory = 32 << 20 // 32 MB + form, err := r.MultipartParams(defaultMaxMemory) + if nil != err { + return err + } + + t := reflect.TypeOf(i) + if t.Kind() != reflect.Ptr { + return nil + } + et := t.Elem() + if et.Kind() != reflect.Struct { + return nil + } + ev := reflect.ValueOf(i).Elem() + return bindMultipartFormStruct(ev, et, form) +} + +func bindMultipartFormStruct(v reflect.Value, t reflect.Type, form *multipart.Form) error { + for j := 0; j < t.NumField(); j++ { + ft := t.Field(j) + fv := v.Field(j) + if ft.Anonymous { + if ft.Type.Kind() != reflect.Struct { + continue + } + if err := bindMultipartFormStruct(fv, ft.Type, form); nil != err { + return err + } + continue + } + name, ok := ft.Tag.Lookup("form") + if !ok || !fv.CanInterface() { + continue + } + + if ft.Type == fileHeaderType || (reflect.Slice == ft.Type.Kind() && ft.Type.Elem() == fileHeaderType) { + files := form.File[name] + if len(files) == 0 { + continue + } + if err := bindMultipartFormFiles(fv, ft.Type, files); nil != err { + return err + } + } else { + values := form.Value[name] + if len(values) == 0 { + continue + } + if err := bindFormField(fv, ft.Type, values); nil != err { + return err + } + } + + } + return nil +} + +func bindMultipartFormFiles(v reflect.Value, t reflect.Type, files []*multipart.FileHeader) error { + if v.Kind() == reflect.Slice { + slice := reflect.MakeSlice(t, 0, len(files)) + defer func() { v.Set(slice) }() + for _, file := range files { + slice = reflect.Append(slice, reflect.ValueOf(file)) + } + return nil + } + + v.Set(reflect.ValueOf(files[0])) + return nil +} diff --git a/binding/form_test.go b/binding/form_test.go new file mode 100644 index 0000000..d13f4a0 --- /dev/null +++ b/binding/form_test.go @@ -0,0 +1,62 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package binding_test + +import ( + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "go-spring.dev/web/binding" +) + +type FormBindParamCommon struct { + A string `form:"a"` + B []string `form:"b"` +} + +type FormBindParam struct { + FormBindParamCommon + C int `form:"c"` + D []int `form:"d"` +} + +func TestBindForm(t *testing.T) { + + ctx := &MockRequest{ + formParams: url.Values{ + "a": {"1"}, + "b": {"2", "3"}, + "c": {"4"}, + "d": {"5", "6"}, + }, + } + + expect := FormBindParam{ + FormBindParamCommon: FormBindParamCommon{ + A: "1", + B: []string{"2", "3"}, + }, + C: 4, + D: []int{5, 6}, + } + + var p FormBindParam + err := binding.Bind(&p, ctx) + assert.Nil(t, err) + assert.Equal(t, expect, p) +} diff --git a/binding/json.go b/binding/json.go new file mode 100644 index 0000000..1d13f02 --- /dev/null +++ b/binding/json.go @@ -0,0 +1,26 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package binding + +import ( + "encoding/json" +) + +func BindJSON(i interface{}, r Request) error { + decoder := json.NewDecoder(r.RequestBody()) + return decoder.Decode(i) +} diff --git a/binding/json_test.go b/binding/json_test.go new file mode 100644 index 0000000..db01d93 --- /dev/null +++ b/binding/json_test.go @@ -0,0 +1,68 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package binding_test + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "go-spring.dev/web/binding" +) + +type JSONBindParamCommon struct { + A string `json:"a"` + B []string `json:"b"` +} + +type JSONBindParam struct { + JSONBindParamCommon + C int `json:"c"` + D []int `json:"d"` +} + +func TestBindJSON(t *testing.T) { + + data, err := json.Marshal(map[string]interface{}{ + "a": "1", + "b": []string{"2", "3"}, + "c": 4, + "d": []int64{5, 6}, + }) + if err != nil { + t.Fatal(err) + } + + ctx := &MockRequest{ + contentType: binding.MIMEApplicationJSON, + requestBody: string(data), + } + + expect := JSONBindParam{ + JSONBindParamCommon: JSONBindParamCommon{ + A: "1", + B: []string{"2", "3"}, + }, + C: 4, + D: []int{5, 6}, + } + + var p JSONBindParam + err = binding.Bind(&p, ctx) + assert.Nil(t, err) + assert.Equal(t, expect, p) +} diff --git a/binding/xml.go b/binding/xml.go new file mode 100644 index 0000000..0100d3d --- /dev/null +++ b/binding/xml.go @@ -0,0 +1,26 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package binding + +import ( + "encoding/xml" +) + +func BindXML(i interface{}, r Request) error { + decoder := xml.NewDecoder(r.RequestBody()) + return decoder.Decode(i) +} diff --git a/binding/xml_test.go b/binding/xml_test.go new file mode 100644 index 0000000..2690bfb --- /dev/null +++ b/binding/xml_test.go @@ -0,0 +1,68 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package binding_test + +import ( + "encoding/xml" + "testing" + + "github.com/stretchr/testify/assert" + "go-spring.dev/web/binding" +) + +type XMLBindParamCommon struct { + A string `xml:"a"` + B []string `xml:"b"` +} + +type XMLBindParam struct { + XMLBindParamCommon + C int `xml:"c"` + D []int `xml:"d"` +} + +func TestBindXML(t *testing.T) { + + data, err := xml.Marshal(&XMLBindParam{ + XMLBindParamCommon: XMLBindParamCommon{ + A: "1", + B: []string{"2", "3"}, + }, + C: 4, + D: []int{5, 6}, + }) + assert.Nil(t, err) + + r := &MockRequest{ + contentType: binding.MIMEApplicationXML, + requestBody: string(data), + } + + expect := XMLBindParam{ + XMLBindParamCommon: XMLBindParamCommon{ + A: "1", + B: []string{"2", "3"}, + }, + C: 4, + D: []int{5, 6}, + } + + var p XMLBindParam + err = binding.Bind(&p, r) + assert.Nil(t, err) + assert.Equal(t, expect, p) +} diff --git a/context.go b/context.go new file mode 100644 index 0000000..967805c --- /dev/null +++ b/context.go @@ -0,0 +1,412 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package web + +import ( + "context" + "fmt" + "io" + "mime/multipart" + "net" + "net/http" + "net/textproto" + "net/url" + "strings" + "unicode" + + "go-spring.dev/web/binding" + "go-spring.dev/web/render" +) + +type contextKey struct{} + +func WithContext(parent context.Context, ctx *Context) context.Context { + return context.WithValue(parent, contextKey{}, ctx) +} + +func FromContext(ctx context.Context) *Context { + if v := ctx.Value(contextKey{}); v != nil { + return v.(*Context) + } + return nil +} + +type Context struct { + // A ResponseWriter interface is used by an HTTP handler to + // construct an HTTP response. + Writer http.ResponseWriter + + // A Request represents an HTTP request received by a server + // or to be sent by a client. + Request *http.Request + + // SameSite allows a server to define a cookie attribute making it impossible for + // the browser to send this cookie along with cross-site requests. + sameSite http.SameSite +} + +// Context returns the request's context. +func (c *Context) Context() context.Context { + return c.Request.Context() +} + +// ContentType returns the request header `Content-Type`. +func (c *Context) ContentType() string { + contentType := c.Request.Header.Get("Content-Type") + return contentType +} + +// Header returns the named header in the request. +func (c *Context) Header(key string) (string, bool) { + if values, ok := c.Request.Header[textproto.CanonicalMIMEHeaderKey(key)]; ok && len(values) > 0 { + return values[0], true + } + return "", false +} + +// Cookie returns the named cookie provided in the request. +func (c *Context) Cookie(name string) (string, bool) { + cookie, err := c.Request.Cookie(name) + if err != nil { + return "", false + } + if val, err := url.QueryUnescape(cookie.Value); nil == err { + return val, true + } + return cookie.Value, true +} + +// PathParam returns the named variables in the request. +func (c *Context) PathParam(name string) (string, bool) { + if ctx := FromRouteContext(c.Request.Context()); nil != ctx { + return ctx.URLParams.Get(name) + } + return "", false +} + +// QueryParam returns the named query in the request. +func (c *Context) QueryParam(name string) (string, bool) { + if values := c.Request.URL.Query(); nil != values { + if value, ok := values[name]; ok && len(value) > 0 { + return value[0], true + } + } + return "", false +} + +// FormParams returns the form in the request. +func (c *Context) FormParams() (url.Values, error) { + if err := c.Request.ParseForm(); nil != err { + return nil, err + } + return c.Request.Form, nil +} + +// MultipartParams returns a request body as multipart/form-data. +// The whole request body is parsed and up to a total of maxMemory bytes of its file parts are stored in memory, with the remainder stored on disk in temporary files. +func (c *Context) MultipartParams(maxMemory int64) (*multipart.Form, error) { + if !strings.Contains(c.ContentType(), binding.MIMEMultipartForm) { + return nil, fmt.Errorf("require `multipart/form-data` request") + } + + if nil == c.Request.MultipartForm { + if err := c.Request.ParseMultipartForm(maxMemory); nil != err { + return nil, err + } + } + return c.Request.MultipartForm, nil +} + +// RequestBody returns the request body. +func (c *Context) RequestBody() io.Reader { + return c.Request.Body +} + +// IsWebsocket returns true if the request headers indicate that a websocket +// handshake is being initiated by the client. +func (c *Context) IsWebsocket() bool { + if strings.Contains(strings.ToLower(c.Request.Header.Get("Connection")), "upgrade") && + strings.EqualFold(c.Request.Header.Get("Upgrade"), "websocket") { + return true + } + return false +} + +// SetSameSite with cookie +func (c *Context) SetSameSite(samesite http.SameSite) { + c.sameSite = samesite +} + +// Status sets the HTTP response code. +func (c *Context) Status(code int) { + c.Writer.WriteHeader(code) +} + +// SetHeader is an intelligent shortcut for c.Writer.Header().Set(key, value). +// It writes a header in the response. +// If value == "", this method removes the header `c.Writer.Header().Del(key)` +func (c *Context) SetHeader(key, value string) { + if value == "" { + c.Writer.Header().Del(key) + return + } + c.Writer.Header().Set(key, value) +} + +// SetCookie adds a Set-Cookie header to the ResponseWriter's headers. +// The provided cookie must have a valid Name. Invalid cookies may be +// silently dropped. +func (c *Context) SetCookie(name, value string, maxAge int, path, domain string, secure, httpOnly bool) { + if path == "" { + path = "/" + } + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: url.QueryEscape(value), + MaxAge: maxAge, + Path: path, + Domain: domain, + SameSite: c.sameSite, + Secure: secure, + HttpOnly: httpOnly, + }) +} + +// Bind checks the Method and Content-Type to select a binding engine automatically, +// Depending on the "Content-Type" header different bindings are used, for example: +// +// "application/json" --> JSON binding +// "application/xml" --> XML binding +func (c *Context) Bind(r interface{}) error { + return binding.Bind(r, c) +} + +// Render writes the response headers and calls render.Render to render data. +func (c *Context) Render(code int, render render.Renderer) error { + if code > 0 { + if len(c.Writer.Header().Get("Content-Type")) <= 0 { + if contentType := render.ContentType(); len(contentType) > 0 { + c.Writer.Header().Set("Content-Type", contentType) + } + } + c.Writer.WriteHeader(code) + } + + if !bodyAllowedForStatus(code) { + return nil + } + + return render.Render(c.Writer) +} + +// Redirect returns an HTTP redirect to the specific location. +func (c *Context) Redirect(code int, location string) error { + return c.Render(-1, render.RedirectRenderer{Code: code, Request: c.Request, Location: location}) +} + +// String writes the given string into the response body. +func (c *Context) String(code int, format string, args ...interface{}) error { + return c.Render(code, render.TextRenderer{Format: format, Args: args}) +} + +// Data writes some data into the body stream and updates the HTTP code. +func (c *Context) Data(code int, contentType string, data []byte) error { + return c.Render(code, render.BinaryRenderer{DataType: contentType, Data: data}) +} + +// JSON serializes the given struct as JSON into the response body. +// It also sets the Content-Type as "application/json". +func (c *Context) JSON(code int, obj interface{}) error { + return c.Render(code, render.JsonRenderer{Data: obj}) +} + +// IndentedJSON serializes the given struct as pretty JSON (indented + endlines) into the response body. +// It also sets the Content-Type as "application/json". +func (c *Context) IndentedJSON(code int, obj interface{}) error { + return c.Render(code, render.JsonRenderer{Data: obj, Indent: " "}) +} + +// XML serializes the given struct as XML into the response body. +// It also sets the Content-Type as "application/xml". +func (c *Context) XML(code int, obj interface{}) error { + return c.Render(code, render.XmlRenderer{Data: obj}) +} + +// IndentedXML serializes the given struct as pretty XML (indented + endlines) into the response body. +// It also sets the Content-Type as "application/xml". +func (c *Context) IndentedXML(code int, obj interface{}) error { + return c.Render(code, render.XmlRenderer{Data: obj, Indent: " "}) +} + +// File writes the specified file into the body stream in an efficient way. +func (c *Context) File(filepath string) { + http.ServeFile(c.Writer, c.Request, filepath) +} + +// FileAttachment writes the specified file into the body stream in an efficient way +// On the client side, the file will typically be downloaded with the given filename +func (c *Context) FileAttachment(filepath, filename string) { + if isASCII(filename) { + c.Writer.Header().Set("Content-Disposition", `attachment; filename="`+escapeQuotes(filename)+`"`) + } else { + c.Writer.Header().Set("Content-Disposition", `attachment; filename*=UTF-8''`+url.QueryEscape(filename)) + } + http.ServeFile(c.Writer, c.Request, filepath) +} + +// RemoteIP parses the IP from Request.RemoteAddr, normalizes and returns the IP (without the port). +func (c *Context) RemoteIP() string { + ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr)) + if err != nil { + return "" + } + return ip +} + +// ClientIP implements one best effort algorithm to return the real client IP. +// It calls c.RemoteIP() under the hood, to check if the remote IP is a trusted proxy or not. +// If it is it will then try to parse the headers defined in RemoteIPHeaders (defaulting to [X-Forwarded-For, X-Real-Ip]). +// If the headers are not syntactically valid OR the remote IP does not correspond to a trusted proxy, +// the remote IP (coming from Request.RemoteAddr) is returned. +func (c *Context) ClientIP() string { + // It also checks if the remoteIP is a trusted proxy or not. + // In order to perform this validation, it will see if the IP is contained within at least one of the CIDR blocks + // defined by Engine.SetTrustedProxies() + remoteIP := net.ParseIP(c.RemoteIP()) + if remoteIP == nil { + return "" + } + + for _, headerName := range []string{"X-Forwarded-For", "X-Real-Ip"} { + if ns := strings.Split(c.Request.Header.Get(headerName), ","); len(ns) > 0 && len(ns[0]) > 0 { + return ns[0] + } + } + return remoteIP.String() +} + +type routeContextKey struct{} + +func WithRouteContext(parent context.Context, ctx *RouteContext) context.Context { + return context.WithValue(parent, routeContextKey{}, ctx) +} + +func FromRouteContext(ctx context.Context) *RouteContext { + if v := ctx.Value(routeContextKey{}); v != nil { + return v.(*RouteContext) + } + return nil +} + +type RouteContext struct { + Routes Routes + // URLParams are the stack of routeParams captured during the + // routing lifecycle across a stack of sub-routers. + URLParams RouteParams + + // routeParams matched for the current sub-router. It is + // intentionally unexported so it can't be tampered. + routeParams RouteParams + + // Routing path/method override used during the route search. + RoutePath string + RouteMethod string + + // The endpoint routing pattern that matched the request URI path + // or `RoutePath` of the current sub-router. This value will update + // during the lifecycle of a request passing through a stack of + // sub-routers. + RoutePattern string + routePatterns []string + + methodNotAllowed bool + methodsAllowed []methodTyp +} + +// Reset context to initial state +func (c *RouteContext) Reset() { + c.Routes = nil + c.RoutePath = "" + c.RouteMethod = "" + c.RoutePattern = "" + c.routePatterns = c.routePatterns[:0] + c.URLParams.Keys = c.URLParams.Keys[:0] + c.URLParams.Values = c.URLParams.Values[:0] + c.routeParams.Keys = c.routeParams.Keys[:0] + c.routeParams.Values = c.routeParams.Values[:0] + c.methodNotAllowed = false + c.methodsAllowed = c.methodsAllowed[:0] +} + +// RouteParams is a structure to track URL routing parameters efficiently. +type RouteParams struct { + Keys, Values []string +} + +// Add will append a URL parameter to the end of the route param +func (s *RouteParams) Add(key, value string) { + s.Keys = append(s.Keys, key) + s.Values = append(s.Values, value) +} + +func (s *RouteParams) Get(key string) (value string, ok bool) { + for i := len(s.Keys) - 1; i >= 0; i-- { + if s.Keys[i] == key { + return s.Values[i], true + } + } + return "", false +} + +// https://stackoverflow.com/questions/53069040/checking-a-string-contains-only-ascii-characters +func isASCII(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] > unicode.MaxASCII { + return false + } + } + return true +} + +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + +func escapeQuotes(s string) string { + return quoteEscaper.Replace(s) +} + +// bodyAllowedForStatus is a copy of http.bodyAllowedForStatus non-exported function. +func bodyAllowedForStatus(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == http.StatusNoContent: + return false + case status == http.StatusNotModified: + return false + } + return true +} + +func notFound() http.Handler { + return http.NotFoundHandler() +} + +func notAllowed() http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + http.Error(writer, "405 method not allowed", http.StatusMethodNotAllowed) + }) +} diff --git a/context_test.go b/context_test.go new file mode 100644 index 0000000..c17f797 --- /dev/null +++ b/context_test.go @@ -0,0 +1,15 @@ +package web + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestContextBodyAllowedForStatus(t *testing.T) { + assert.False(t, false, bodyAllowedForStatus(http.StatusProcessing)) + assert.False(t, false, bodyAllowedForStatus(http.StatusNoContent)) + assert.False(t, false, bodyAllowedForStatus(http.StatusNotModified)) + assert.True(t, true, bodyAllowedForStatus(http.StatusInternalServerError)) +} diff --git a/error.go b/error.go new file mode 100644 index 0000000..2f3b601 --- /dev/null +++ b/error.go @@ -0,0 +1,39 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package web + +import ( + "fmt" + "net/http" +) + +type HttpError struct { + Code int + Message string +} + +func (e HttpError) Error() string { + return fmt.Sprintf("%d: %s", e.Code, e.Message) +} + +func Error(code int, format string, args ...interface{}) HttpError { + var message = http.StatusText(code) + if len(format) > 0 { + message = fmt.Sprintf(format, args...) + } + return HttpError{Code: code, Message: message} +} diff --git a/examples/go.mod b/examples/go.mod new file mode 100644 index 0000000..480d06d --- /dev/null +++ b/examples/go.mod @@ -0,0 +1,10 @@ +module examples + +go 1.21 + +replace go-spring.dev/web => ../ + +require ( + go-spring.dev/web v0.0.0-00010101000000-000000000000 + gopkg.in/validator.v2 v2.0.1 +) diff --git a/examples/go.sum b/examples/go.sum new file mode 100644 index 0000000..c99a5fa --- /dev/null +++ b/examples/go.sum @@ -0,0 +1,16 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/validator.v2 v2.0.1 h1:xF0KWyGWXm/LM2G1TrEjqOu4pa6coO9AlWSf3msVfDY= +gopkg.in/validator.v2 v2.0.1/go.mod h1:lIUZBlB3Im4s/eYp39Ry/wkR02yOPhZ9IwIRBjuPuG8= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/examples/greeting/main.go b/examples/greeting/main.go new file mode 100644 index 0000000..ebfc75f --- /dev/null +++ b/examples/greeting/main.go @@ -0,0 +1,18 @@ +package main + +import ( + "context" + "net/http" + + "go-spring.dev/web" +) + +func main() { + var router = web.NewRouter() + + router.Get("/greeting", func(ctx context.Context) string { + return "greeting!!!" + }) + + http.ListenAndServe(":8080", router) +} diff --git a/examples/middleware/main.go b/examples/middleware/main.go new file mode 100644 index 0000000..4114d26 --- /dev/null +++ b/examples/middleware/main.go @@ -0,0 +1,91 @@ +package main + +import ( + "context" + "log/slog" + "net/http" + "time" + + "go-spring.dev/web" +) + +func main() { + var router = web.NewRouter() + + // access log + router.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + t1 := time.Now() + next.ServeHTTP(writer, request) + slog.Info("access log", slog.String("path", request.URL.Path), slog.String("method", request.Method), slog.Duration("cost", time.Since(t1))) + }) + }) + + // cors + router.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + writer.Header().Set("Access-Control-Allow-Origin", "*") + writer.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS") + writer.Header().Set("Access-Control-Allow-Headers", "Origin, Content-Type") + + // preflight request + if request.Method == http.MethodOptions { + writer.WriteHeader(http.StatusNoContent) + return + } + + next.ServeHTTP(writer, request) + }) + }) + + router.Group("/public", func(r web.Router) { + r.Post("/register", func(ctx context.Context) string { return "register: do something" }) + r.Post("/forgot", func(ctx context.Context) string { return "forgot: do something" }) + r.Post("/login", func(ctx context.Context, req struct { + Username string `form:"username"` + Password string `form:"password"` + }) error { + if "admin" == req.Username && "admin123" == req.Password { + web.FromContext(ctx).SetCookie("token", req.Username, 600, "/", "", false, false) + return nil + } + return web.Error(400, "login failed") + }) + }) + + router.Group("/user", func(r web.Router) { + + // user login check + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + // check login state in cookies + // + if _, err := request.Cookie("token"); nil != err { + writer.WriteHeader(http.StatusForbidden) + return + } + + // login check success + next.ServeHTTP(writer, request) + }) + }) + + r.Get("/userInfo", func(ctx context.Context) interface{} { + // TODO: load user from database + // + return map[string]interface{}{ + "username": "admin", + "time": time.Now().String(), + } + }) + + r.Get("/logout", func(ctx context.Context) string { + // delete cookie + web.FromContext(ctx).SetCookie("token", "", -1, "/", "", false, false) + return "success" + }) + + }) + + http.ListenAndServe(":8080", router) +} diff --git a/examples/stdmux/main.go b/examples/stdmux/main.go new file mode 100644 index 0000000..02094b3 --- /dev/null +++ b/examples/stdmux/main.go @@ -0,0 +1,38 @@ +package main + +import ( + "context" + "log/slog" + "mime/multipart" + "net/http" + + "go-spring.dev/web" +) + +func main() { + http.Handle("/user/register", web.Bind(UserRegister, web.JsonRender())) + + http.ListenAndServe(":8080", nil) +} + +type UserRegisterModel struct { + Username string `form:"username"` // username + Password string `form:"password"` // password + Avatar *multipart.FileHeader `form:"avatar"` // avatar + Captcha string `form:"captcha"` // captcha + UserAgent string `header:"User-Agent"` // user agent + Ad string `query:"ad"` // advertising ID + Token string `cookie:"token"` // token +} + +func UserRegister(ctx context.Context, req UserRegisterModel) string { + slog.Info("user register", + slog.String("username", req.Username), + slog.String("password", req.Password), + slog.String("captcha", req.Captcha), + slog.String("userAgent", req.UserAgent), + slog.String("ad", req.Ad), + slog.String("token", req.Token), + ) + return "success" +} diff --git a/examples/validator/main.go b/examples/validator/main.go new file mode 100644 index 0000000..a44d01a --- /dev/null +++ b/examples/validator/main.go @@ -0,0 +1,47 @@ +package main + +import ( + "context" + "log/slog" + "mime/multipart" + "net/http" + + "go-spring.dev/web" + "go-spring.dev/web/binding" + "gopkg.in/validator.v2" +) + +var validatorInst = validator.NewValidator().WithTag("validate") + +func main() { + binding.RegisterValidator(func(i interface{}) error { + return validatorInst.Validate(i) + }) + + var router = web.NewRouter() + router.Post("/user/register", UserRegister) + + http.ListenAndServe(":8080", router) +} + +type UserRegisterModel struct { + Username string `form:"username" validate:"min=6,max=20"` // username + Password string `form:"password" validate:"min=10,max=20"` // password + Avatar *multipart.FileHeader `form:"avatar" validate:"nonzero"` // avatar + Captcha string `form:"captcha" validate:"min=4,max=4"` // captcha + UserAgent string `header:"User-Agent"` // user agent + Ad string `query:"ad"` // advertising ID + Token string `cookie:"token"` // token +} + +func UserRegister(ctx context.Context, req UserRegisterModel) string { + slog.Info("user register", + slog.String("username", req.Username), + slog.String("password", req.Password), + slog.String("captcha", req.Captcha), + slog.String("userAgent", req.UserAgent), + slog.String("ad", req.Ad), + slog.String("token", req.Token), + ) + return "success" +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..d702fce --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module go-spring.dev/web + +go 1.21 + +require github.com/stretchr/testify v1.8.4 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..fa4b6e6 --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..28386b2 --- /dev/null +++ b/middleware.go @@ -0,0 +1,62 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package web + +import "net/http" + +// MiddlewareFunc is a function which receives an http.Handler and returns another http.Handler. +// Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed +// to it, and then calls the handler passed as parameter to the MiddlewareFunc. +type MiddlewareFunc = func(next http.Handler) http.Handler + +// Middlewares type is a slice of standard middleware handlers with methods +// to compose middleware chains and http.Handler's. +type Middlewares []MiddlewareFunc + +// Handler builds and returns a http.Handler from the chain of middlewares, +// with `h http.Handler` as the final handler. +func (mws Middlewares) Handler(h http.Handler) http.Handler { + return &chainHandler{Endpoint: h, chain: mws.chain(h), Middlewares: mws} +} + +// HandlerFunc builds and returns a http.Handler from the chain of middlewares, +// with `h http.Handler` as the final handler. +func (mws Middlewares) HandlerFunc(h http.HandlerFunc) http.Handler { + return &chainHandler{Endpoint: h, chain: mws.chain(h), Middlewares: mws} +} + +// Build a http.Handler composed of an inline middlewares. +func (mws Middlewares) chain(handler http.Handler) http.Handler { + if 0 == len(mws) { + return handler + } + + for i := len(mws) - 1; i >= 0; i-- { + handler = mws[i](handler) + } + return handler +} + +type chainHandler struct { + Endpoint http.Handler + chain http.Handler + Middlewares Middlewares +} + +func (c *chainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.chain.ServeHTTP(w, r) +} diff --git a/options.go b/options.go new file mode 100644 index 0000000..7b9758f --- /dev/null +++ b/options.go @@ -0,0 +1,104 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package web + +import ( + "crypto/tls" + "time" +) + +type Options struct { + // Addr optionally specifies the TCP address for the server to listen on, + // in the form "host:port". If empty, ":http" (port 8080) is used. + // The service names are defined in RFC 6335 and assigned by IANA. + // See net.Dial for details of the address format. + Addr string `json:"addr" value:"${addr:=}"` + + // CertFile containing a certificate and matching private key for the + // server must be provided if neither the Server's + // TLSConfig.Certificates nor TLSConfig.GetCertificate are populated. + // If the certificate is signed by a certificate authority, the + // certFile should be the concatenation of the server's certificate, + // any intermediates, and the CA's certificate. + CertFile string `json:"cert-file" value:"${cert-file:=}"` + + // KeyFile containing a private key file. + KeyFile string `json:"key-file" value:"${key-file:=}"` + + // ReadTimeout is the maximum duration for reading the entire + // request, including the body. A zero or negative value means + // there will be no timeout. + // + // Because ReadTimeout does not let Handlers make per-request + // decisions on each request body's acceptable deadline or + // upload rate, most users will prefer to use + // ReadHeaderTimeout. It is valid to use them both. + ReadTimeout time.Duration `json:"read-timeout" value:"${read-timeout:=0s}"` + + // ReadHeaderTimeout is the amount of time allowed to read + // request headers. The connection's read deadline is reset + // after reading the headers and the Handler can decide what + // is considered too slow for the body. If ReadHeaderTimeout + // is zero, the value of ReadTimeout is used. If both are + // zero, there is no timeout. + ReadHeaderTimeout time.Duration `json:"read-header-timeout" value:"${read-header-timeout:=0s}"` + + // WriteTimeout is the maximum duration before timing out + // writes of the response. It is reset whenever a new + // request's header is read. Like ReadTimeout, it does not + // let Handlers make decisions on a per-request basis. + // A zero or negative value means there will be no timeout. + WriteTimeout time.Duration `json:"write-timeout" value:"${write-timeout:=0s}"` + + // IdleTimeout is the maximum amount of time to wait for the + // next request when keep-alives are enabled. If IdleTimeout + // is zero, the value of ReadTimeout is used. If both are + // zero, there is no timeout. + IdleTimeout time.Duration `json:"idle-timeout" value:"${idle-timeout:=0s}"` + + // MaxHeaderBytes controls the maximum number of bytes the + // server will read parsing the request header's keys and + // values, including the request line. It does not limit the + // size of the request body. + // If zero, DefaultMaxHeaderBytes is used. + MaxHeaderBytes int `json:"max-header-bytes" value:"${max-header-bytes:=0}"` + + // Router optionally specifies an external router. + Router Router `json:"-"` +} + +func (options Options) IsTls() bool { + return len(options.CertFile) > 0 && len(options.KeyFile) > 0 +} + +func (options Options) TlsConfig() *tls.Config { + if !options.IsTls() { + return nil + } + + return &tls.Config{ + GetCertificate: options.GetCertificate, + } +} + +func (options Options) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert, err := tls.LoadX509KeyPair(options.CertFile, options.KeyFile) + if err != nil { + return nil, err + } + return &cert, nil +} diff --git a/render/binary.go b/render/binary.go new file mode 100644 index 0000000..781c911 --- /dev/null +++ b/render/binary.go @@ -0,0 +1,39 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package render + +import ( + "net/http" +) + +type BinaryRenderer struct { + DataType string // Content-Type + Data []byte +} + +func (b BinaryRenderer) ContentType() string { + contentType := "application/octet-stream" + if len(b.DataType) > 0 { + contentType = b.DataType + } + return contentType +} + +func (b BinaryRenderer) Render(writer http.ResponseWriter) error { + _, err := writer.Write(b.Data) + return err +} diff --git a/render/binary_test.go b/render/binary_test.go new file mode 100644 index 0000000..05777f6 --- /dev/null +++ b/render/binary_test.go @@ -0,0 +1,42 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package render + +import ( + "crypto/rand" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBinaryRenderer(t *testing.T) { + + data := make([]byte, 1024) + if _, err := rand.Reader.Read(data); nil != err { + panic(err) + } + + w := httptest.NewRecorder() + + render := BinaryRenderer{DataType: "application/octet-stream", Data: data} + err := render.Render(w) + assert.Nil(t, err) + + assert.Equal(t, render.ContentType(), "application/octet-stream") + assert.Equal(t, w.Body.Bytes(), data) +} diff --git a/render/html.go b/render/html.go new file mode 100644 index 0000000..7ae2a4d --- /dev/null +++ b/render/html.go @@ -0,0 +1,39 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package render + +import ( + "html/template" + "net/http" +) + +type HTMLRenderer struct { + Template *template.Template + Name string + Data interface{} +} + +func (h HTMLRenderer) ContentType() string { + return "text/html; charset=utf-8" +} + +func (h HTMLRenderer) Render(writer http.ResponseWriter) error { + if len(h.Name) > 0 { + return h.Template.ExecuteTemplate(writer, h.Name, h.Data) + } + return h.Template.Execute(writer, h.Data) +} diff --git a/render/html_test.go b/render/html_test.go new file mode 100644 index 0000000..d28a4e4 --- /dev/null +++ b/render/html_test.go @@ -0,0 +1,38 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package render + +import ( + "html/template" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHTMLRenderer(t *testing.T) { + + w := httptest.NewRecorder() + templ := template.Must(template.New("t").Parse(`Hello {{.name}}`)) + + htmlRender := HTMLRenderer{Template: templ, Name: "t", Data: map[string]interface{}{"name": "asdklajhdasdd"}} + err := htmlRender.Render(w) + + assert.Nil(t, err) + assert.Equal(t, htmlRender.ContentType(), "text/html; charset=utf-8") + assert.Equal(t, w.Body.String(), "Hello asdklajhdasdd") +} diff --git a/render/json.go b/render/json.go new file mode 100644 index 0000000..c616978 --- /dev/null +++ b/render/json.go @@ -0,0 +1,40 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package render + +import ( + "encoding/json" + "net/http" +) + +type JsonRenderer struct { + Prefix string + Indent string + Data interface{} +} + +func (j JsonRenderer) ContentType() string { + return "application/json; charset=utf-8" +} + +func (j JsonRenderer) Render(writer http.ResponseWriter) error { + encoder := json.NewEncoder(writer) + if len(j.Prefix) > 0 || len(j.Indent) > 0 { + encoder.SetIndent(j.Prefix, j.Indent) + } + return encoder.Encode(j.Data) +} diff --git a/render/json_test.go b/render/json_test.go new file mode 100644 index 0000000..839e4e4 --- /dev/null +++ b/render/json_test.go @@ -0,0 +1,40 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package render + +import ( + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestJSONRenderer(t *testing.T) { + data := map[string]any{ + "foo": "bar", + "html": "", + } + + w := httptest.NewRecorder() + + render := JsonRenderer{Data: data} + err := render.Render(w) + assert.Nil(t, err) + + assert.Equal(t, render.ContentType(), "application/json; charset=utf-8") + assert.Equal(t, w.Body.String(), "{\"foo\":\"bar\",\"html\":\"\\u003cb\\u003e\"}\n") +} diff --git a/render/redirect.go b/render/redirect.go new file mode 100644 index 0000000..d758186 --- /dev/null +++ b/render/redirect.go @@ -0,0 +1,40 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package render + +import ( + "fmt" + "net/http" +) + +type RedirectRenderer struct { + Code int + Request *http.Request + Location string +} + +func (r RedirectRenderer) ContentType() string { + return "" +} + +func (r RedirectRenderer) Render(writer http.ResponseWriter) error { + if (r.Code < http.StatusMultipleChoices || r.Code > http.StatusPermanentRedirect) && r.Code != http.StatusCreated { + panic(fmt.Sprintf("Cannot redirect with status code %d", r.Code)) + } + http.Redirect(writer, r.Request, r.Location, r.Code) + return nil +} diff --git a/render/redirect_test.go b/render/redirect_test.go new file mode 100644 index 0000000..a8e526e --- /dev/null +++ b/render/redirect_test.go @@ -0,0 +1,63 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package render + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRedirectRenderer(t *testing.T) { + req, err := http.NewRequest("GET", "/test-redirect", nil) + assert.Nil(t, err) + + data1 := RedirectRenderer{ + Code: http.StatusMovedPermanently, + Request: req, + Location: "/new/location", + } + + w := httptest.NewRecorder() + err = data1.Render(w) + assert.Nil(t, err) + assert.Equal(t, data1.ContentType(), "") + + data2 := RedirectRenderer{ + Code: http.StatusOK, + Request: req, + Location: "/new/location", + } + + w = httptest.NewRecorder() + assert.Panics(t, func() { + err := data2.Render(w) + assert.Nil(t, err) + }, "Cannot redirect with status code 200") + + data3 := RedirectRenderer{ + Code: http.StatusCreated, + Request: req, + Location: "/new/location", + } + + w = httptest.NewRecorder() + err = data3.Render(w) + assert.Nil(t, err) +} diff --git a/render/renderer.go b/render/renderer.go new file mode 100644 index 0000000..969e716 --- /dev/null +++ b/render/renderer.go @@ -0,0 +1,27 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package render + +import ( + "net/http" +) + +// Renderer writes data with custom ContentType and headers. +type Renderer interface { + ContentType() string + Render(writer http.ResponseWriter) error +} diff --git a/render/text.go b/render/text.go new file mode 100644 index 0000000..f81d208 --- /dev/null +++ b/render/text.go @@ -0,0 +1,38 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package render + +import ( + "fmt" + "io" + "net/http" + "strings" +) + +type TextRenderer struct { + Format string + Args []interface{} +} + +func (t TextRenderer) ContentType() string { + return "text/plain; charset=utf-8" +} + +func (t TextRenderer) Render(writer http.ResponseWriter) error { + _, err := io.Copy(writer, strings.NewReader(fmt.Sprintf(t.Format, t.Args...))) + return err +} diff --git a/render/text_test.go b/render/text_test.go new file mode 100644 index 0000000..3f2cd11 --- /dev/null +++ b/render/text_test.go @@ -0,0 +1,39 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package render + +import ( + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTextRenderer(t *testing.T) { + w := httptest.NewRecorder() + + render := TextRenderer{ + Format: "hello %s %d", + Args: []any{"bob", 2}, + } + + err := render.Render(w) + + assert.Nil(t, err) + assert.Equal(t, render.ContentType(), "text/plain; charset=utf-8") + assert.Equal(t, w.Body.String(), "hello bob 2") +} diff --git a/render/xml.go b/render/xml.go new file mode 100644 index 0000000..19533b8 --- /dev/null +++ b/render/xml.go @@ -0,0 +1,40 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package render + +import ( + "encoding/xml" + "net/http" +) + +type XmlRenderer struct { + Prefix string + Indent string + Data interface{} +} + +func (x XmlRenderer) ContentType() string { + return "application/xml; charset=utf-8" +} + +func (x XmlRenderer) Render(writer http.ResponseWriter) error { + encoder := xml.NewEncoder(writer) + if len(x.Prefix) > 0 || len(x.Indent) > 0 { + encoder.Indent(x.Prefix, x.Indent) + } + return encoder.Encode(x.Data) +} diff --git a/render/xml_test.go b/render/xml_test.go new file mode 100644 index 0000000..a6cceb2 --- /dev/null +++ b/render/xml_test.go @@ -0,0 +1,63 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package render + +import ( + "encoding/xml" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +type xmlmap map[string]any + +// Allows type H to be used with xml.Marshal +func (h xmlmap) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + start.Name = xml.Name{ + Space: "", + Local: "map", + } + if err := e.EncodeToken(start); err != nil { + return err + } + for key, value := range h { + elem := xml.StartElement{ + Name: xml.Name{Space: "", Local: key}, + Attr: []xml.Attr{}, + } + if err := e.EncodeElement(value, elem); err != nil { + return err + } + } + + return e.EncodeToken(xml.EndElement{Name: start.Name}) +} + +func TestXmlRenderer(t *testing.T) { + w := httptest.NewRecorder() + data := xmlmap{ + "foo": "bar", + } + + render := (XmlRenderer{Data: data}) + err := render.Render(w) + + assert.Nil(t, err) + assert.Equal(t, render.ContentType(), "application/xml; charset=utf-8") + assert.Equal(t, w.Body.String(), "bar") +} diff --git a/router.go b/router.go new file mode 100644 index 0000000..ca0bc40 --- /dev/null +++ b/router.go @@ -0,0 +1,484 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package web + +import ( + "fmt" + "net/http" + "sync" +) + +// Router registers routes to be matched and dispatches a handler. +// +// Registers a new route with a matcher for the URL pattern. +// Automatic binding request to handler input params and validate params. +// +// Router.Any() Router.Get() Router.Head() Router.Post() Router.Put() Router.Patch() Router.Delete() Router.Connect() Router.Options() Router.Trace() +// +// The handler accepts the following functional signatures: +// +// func(ctx context.Context) +// +// func(ctx context.Context) R +// +// func(ctx context.Context) error +// +// func(ctx context.Context, req T) R +// +// func(ctx context.Context, req T) error +// +// func(ctx context.Context, req T) (R, error) +// +// It implements the http.Handler interface, so it can be registered to serve +// requests: +// +// var router = web.NewRouter() +// +// func main() { +// router.Get("/greeting", func(ctx context.Context) string { +// return "greeting!!!" +// }) +// http.Handle("/", router) +// } +type Router interface { + // Handler dispatches the handler registered in the matched route. + http.Handler + + // Use appends a MiddlewareFunc to the chain. + Use(mwf ...MiddlewareFunc) Router + + // Renderer to be used Response renderer in default. + Renderer(renderer Renderer) Router + + // Group creates a new router group. + Group(pattern string, fn ...func(r Router)) Router + + // Handle registers a new route with a matcher for the URL pattern. + Handle(pattern string, handler http.Handler) + + // HandleFunc registers a new route with a matcher for the URL pattern. + HandleFunc(pattern string, handler http.HandlerFunc) + + // Any registers a route that matches all the HTTP methods. + // GET, POST, PUT, PATCH, HEAD, OPTIONS, DELETE, CONNECT, TRACE. + Any(pattern string, handler interface{}) + + // Get registers a new GET route with a matcher for the URL path of the get method. + Get(pattern string, handler interface{}) + + // Head registers a new HEAD route with a matcher for the URL path of the head method. + Head(pattern string, handler interface{}) + + // Post registers a new POST route with a matcher for the URL path of the post method. + Post(pattern string, handler interface{}) + + // Put registers a new PUT route with a matcher for the URL path of the put method. + Put(pattern string, handler interface{}) + + // Patch registers a new PATCH route with a matcher for the URL path of the patch method. + Patch(pattern string, handler interface{}) + + // Delete registers a new DELETE route with a matcher for the URL path of the delete method. + Delete(pattern string, handler interface{}) + + // Connect registers a new CONNECT route with a matcher for the URL path of the connect method. + Connect(pattern string, handler interface{}) + + // Options registers a new OPTIONS route with a matcher for the URL path of the options method. + Options(pattern string, handler interface{}) + + // Trace registers a new TRACE route with a matcher for the URL path of the trace method. + Trace(pattern string, handler interface{}) + + // NotFound to be used when no route matches. + NotFound(handler http.HandlerFunc) + + // MethodNotAllowed to be used when the request method does not match the route. + MethodNotAllowed(handler http.HandlerFunc) +} + +type Routes interface { + // Routes returns the routing tree in an easily traversable structure. + Routes() []Route + + // Middlewares returns the list of middlewares in use by the router. + Middlewares() Middlewares + + // Match searches the routing tree for a handler that matches + // the method/path - similar to routing a http request, but without + // executing the handler thereafter. + Match(ctx *RouteContext, method, path string) bool +} + +// NewRouter returns a new router instance. +func NewRouter() Router { + return &routerGroup{ + tree: &node{}, + renderer: JsonRender(), + pool: &sync.Pool{New: func() interface{} { return &RouteContext{} }}, + } +} + +type routerGroup struct { + handler http.Handler + inline bool + tree *node + parent *routerGroup + middlewares Middlewares + renderer Renderer + notFoundHandler http.HandlerFunc + notAllowedHandler http.HandlerFunc + pool *sync.Pool +} + +// Use appends a MiddlewareFunc to the chain. +// Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router. +func (rg *routerGroup) Use(mwf ...MiddlewareFunc) Router { + if rg.handler != nil { + panic("middlewares must be defined before routes registers") + } + rg.middlewares = append(rg.middlewares, mwf...) + return rg +} + +// Renderer to be used Response renderer in default. +func (rg *routerGroup) Renderer(renderer Renderer) Router { + rg.renderer = renderer + return rg +} + +func (rg *routerGroup) NotFoundHandler() http.Handler { + if rg.notFoundHandler != nil { + return rg.notFoundHandler + } + return notFound() +} + +func (rg *routerGroup) NotAllowedHandler() http.Handler { + if rg.notAllowedHandler != nil { + return rg.notAllowedHandler + } + return notAllowed() +} + +// ServeHTTP dispatches the handler registered in the matched route. +func (rg *routerGroup) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if nil == rg.handler { + rg.NotFoundHandler().ServeHTTP(w, r) + return + } + + ctx := FromRouteContext(r.Context()) + if nil != ctx { + rg.handler.ServeHTTP(w, r) + return + } + + // get context from pool + ctx = rg.pool.Get().(*RouteContext) + ctx.Routes = rg + + // with context + r = r.WithContext(WithRouteContext(r.Context(), ctx)) + rg.handler.ServeHTTP(w, r) + + // put context to pool + ctx.Reset() + rg.pool.Put(ctx) + +} + +// Recursively update data on child routers. +func (rg *routerGroup) updateSubRoutes(fn func(subMux *routerGroup)) { + for _, r := range rg.tree.routes() { + subMux, ok := r.SubRoutes.(*routerGroup) + if !ok { + continue + } + fn(subMux) + } +} + +func (rg *routerGroup) nextRoutePath(ctx *RouteContext) string { + routePath := "/" + nx := len(ctx.routeParams.Keys) - 1 // index of last param in list + if nx >= 0 && ctx.routeParams.Keys[nx] == "*" && len(ctx.routeParams.Values) > nx { + routePath = "/" + ctx.routeParams.Values[nx] + } + return routePath +} + +// routeHTTP Routes a http.Request through the routing tree to serve +// the matching handler for a particular http method. +func (rg *routerGroup) routeHTTP(w http.ResponseWriter, r *http.Request) { + // Grab the route context object + ctx := FromRouteContext(r.Context()) + + // The request routing path + routePath := ctx.RoutePath + if routePath == "" { + if r.URL.RawPath != "" { + routePath = r.URL.RawPath + } else { + routePath = r.URL.Path + } + if routePath == "" { + routePath = "/" + } + } + + if ctx.RouteMethod == "" { + ctx.RouteMethod = r.Method + } + + method, ok := methodMap[ctx.RouteMethod] + if !ok { + rg.NotAllowedHandler().ServeHTTP(w, r) + return + } + + // Find the route + if _, _, h := rg.tree.FindRoute(ctx, method, routePath); h != nil { + h.ServeHTTP(w, r) + return + } + if ctx.methodNotAllowed { + rg.NotAllowedHandler().ServeHTTP(w, r) + } else { + rg.NotFoundHandler().ServeHTTP(w, r) + } +} + +// Group creates a new router group. +func (rg *routerGroup) Group(pattern string, fn ...func(r Router)) Router { + subRouter := &routerGroup{tree: &node{}, renderer: rg.renderer, pool: rg.pool} + for _, f := range fn { + f(subRouter) + } + rg.Mount(pattern, subRouter) + return subRouter +} + +// Mount attaches another http.Handler or RouterGroup as a subrouter along a routing +// path. It's very useful to split up a large API as many independent routers and +// compose them as a single service using Mount. +func (rg *routerGroup) Mount(pattern string, handler http.Handler) { + if handler == nil { + panic(fmt.Sprintf("attempting to Mount() a nil handler on '%s'", pattern)) + } + + // Provide runtime safety for ensuring a pattern isn't mounted on an existing + // routing pattern. + if rg.tree.findPattern(pattern+"*") || rg.tree.findPattern(pattern+"/*") { + panic(fmt.Sprintf("attempting to Mount() a handler on an existing path, '%s'", pattern)) + } + + // Assign sub-Router'rg with the parent not found & method not allowed handler if not specified. + subr, ok := handler.(*routerGroup) + if ok && subr.notFoundHandler == nil && rg.notFoundHandler != nil { + subr.NotFound(rg.notFoundHandler) + } + if ok && subr.notAllowedHandler == nil && rg.notAllowedHandler != nil { + subr.MethodNotAllowed(rg.notAllowedHandler) + } + + mountHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := FromRouteContext(r.Context()) + + // shift the url path past the previous subrouter + ctx.RoutePath = rg.nextRoutePath(ctx) + + // reset the wildcard URLParam which connects the subrouter + n := len(ctx.URLParams.Keys) - 1 + if n >= 0 && ctx.URLParams.Keys[n] == "*" && len(ctx.URLParams.Values) > n { + ctx.URLParams.Values[n] = "" + } + + handler.ServeHTTP(w, r) + }) + + if pattern == "" || pattern[len(pattern)-1] != '/' { + rg.handle(mALL|mSTUB, pattern, mountHandler) + rg.handle(mALL|mSTUB, pattern+"/", mountHandler) + pattern += "/" + } + + method := mALL + subroutes, _ := handler.(Routes) + if subroutes != nil { + method |= mSTUB + } + n := rg.handle(method, pattern+"*", mountHandler) + + if subroutes != nil { + n.subroutes = subroutes + } +} + +// bind a new route with a matcher for the URL pattern. +// Automatic binding request to handler input params and validate params. +func (rg *routerGroup) bind(method methodTyp, pattern string, handler interface{}) *node { + return rg.handle(method, pattern, Bind(handler, rg.renderer)) +} + +func (rg *routerGroup) handle(method methodTyp, pattern string, handler http.Handler) *node { + if len(pattern) == 0 || pattern[0] != '/' { + panic(fmt.Sprintf("routing pattern must begin with '/' in '%s'", pattern)) + } + if !rg.inline && rg.handler == nil { + rg.handler = rg.middlewares.HandlerFunc(rg.routeHTTP) + } + + if rg.inline { + rg.handler = http.HandlerFunc(rg.routeHTTP) + handler = rg.middlewares.Handler(handler) + } + + // Add the endpoint to the tree + return rg.tree.InsertRoute(method, pattern, handler) +} + +// Handle registers a new route with a matcher for the URL pattern. +func (rg *routerGroup) Handle(pattern string, handler http.Handler) { + rg.handle(mALL, pattern, handler) +} + +// HandleFunc registers a new route with a matcher for the URL pattern. +func (rg *routerGroup) HandleFunc(pattern string, handler http.HandlerFunc) { + rg.handle(mALL, pattern, handler) +} + +// Any registers a route that matches all the HTTP methods. +// GET, POST, PUT, PATCH, HEAD, OPTIONS, DELETE, CONNECT, TRACE. +func (rg *routerGroup) Any(pattern string, handler interface{}) { + rg.bind(mALL, pattern, handler) +} + +// Get registers a new GET route with a matcher for the URL pattern of the get method. +func (rg *routerGroup) Get(pattern string, handler interface{}) { + rg.bind(mGET, pattern, handler) +} + +// Head registers a new HEAD route with a matcher for the URL pattern of the get method. +func (rg *routerGroup) Head(pattern string, handler interface{}) { + rg.bind(mHEAD, pattern, handler) +} + +// Post registers a new POST route with a matcher for the URL pattern of the get method. +func (rg *routerGroup) Post(pattern string, handler interface{}) { + rg.bind(mPOST, pattern, handler) +} + +// Put registers a new PUT route with a matcher for the URL pattern of the get method. +func (rg *routerGroup) Put(pattern string, handler interface{}) { + rg.bind(mPUT, pattern, handler) +} + +// Patch registers a new PATCH route with a matcher for the URL pattern of the get method. +func (rg *routerGroup) Patch(pattern string, handler interface{}) { + rg.bind(mPATCH, pattern, handler) +} + +// Delete registers a new DELETE route with a matcher for the URL pattern of the get method. +func (rg *routerGroup) Delete(pattern string, handler interface{}) { + rg.bind(mDELETE, pattern, handler) +} + +// Connect registers a new CONNECT route with a matcher for the URL pattern of the get method. +func (rg *routerGroup) Connect(pattern string, handler interface{}) { + rg.bind(mCONNECT, pattern, handler) +} + +// Options registers a new OPTIONS route with a matcher for the URL pattern of the get method. +func (rg *routerGroup) Options(pattern string, handler interface{}) { + rg.bind(mOPTIONS, pattern, handler) +} + +// Trace registers a new TRACE route with a matcher for the URL pattern of the get method. +func (rg *routerGroup) Trace(pattern string, handler interface{}) { + rg.bind(mTRACE, pattern, handler) +} + +// NotFound to be used when no route matches. +// This can be used to render your own 404 Not Found errors. +func (rg *routerGroup) NotFound(handler http.HandlerFunc) { + // Build NotFound handler chain + m := rg + hFn := handler + if rg.inline && rg.parent != nil { + m = rg.parent + hFn = rg.middlewares.HandlerFunc(hFn).ServeHTTP + } + + // Update the notFoundHandler from this point forward + m.notFoundHandler = hFn + m.updateSubRoutes(func(subMux *routerGroup) { + if subMux.notFoundHandler == nil { + subMux.NotFound(hFn) + } + }) +} + +// MethodNotAllowed to be used when the request method does not match the route. +// This can be used to render your own 405 Method Not Allowed errors. +func (rg *routerGroup) MethodNotAllowed(handler http.HandlerFunc) { + // Build MethodNotAllowed handler chain + m := rg + hFn := handler + if rg.inline && rg.parent != nil { + m = rg.parent + hFn = rg.middlewares.HandlerFunc(hFn).ServeHTTP + } + + // Update the methodNotAllowedHandler from this point forward + m.notAllowedHandler = hFn + m.updateSubRoutes(func(subMux *routerGroup) { + if subMux.notAllowedHandler == nil { + subMux.MethodNotAllowed(hFn) + } + }) +} + +// Routes returns a slice of routing information from the tree, +// useful for traversing available Routes of a router. +func (rg *routerGroup) Routes() []Route { + return rg.tree.routes() +} + +// Middlewares returns a slice of middleware handler functions. +func (rg *routerGroup) Middlewares() Middlewares { + return rg.middlewares +} + +// Match searches the routing tree for a handler that matches the method/path. +// It's similar to routing a http request, but without executing the handler +// thereafter. +func (rg *routerGroup) Match(ctx *RouteContext, method, path string) bool { + m, ok := methodMap[method] + if !ok { + return false + } + + node, _, h := rg.tree.FindRoute(ctx, m, path) + + if node != nil && node.subroutes != nil { + ctx.RoutePath = rg.nextRoutePath(ctx) + return node.subroutes.Match(ctx, method, ctx.RoutePath) + } + + return h != nil +} diff --git a/router_test.go b/router_test.go new file mode 100644 index 0000000..1421771 --- /dev/null +++ b/router_test.go @@ -0,0 +1,1315 @@ +package web + +import ( + "bytes" + "context" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" +) + +func URLParam(r *http.Request, name string) string { + if ctx := FromRouteContext(r.Context()); nil != ctx { + v, _ := ctx.URLParams.Get(name) + return v + } + return "" +} + +func TestMuxBasic(t *testing.T) { + var count uint64 + countermw := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count++ + next.ServeHTTP(w, r) + }) + } + + usermw := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ctx = context.WithValue(ctx, ctxKey{"user"}, "peter") + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + }) + } + + exmw := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), ctxKey{"ex"}, "a") + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + }) + } + + logbuf := bytes.NewBufferString("") + logmsg := "logmw test" + logmw := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + logbuf.WriteString(logmsg) + next.ServeHTTP(w, r) + }) + } + + cxindex := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + user := ctx.Value(ctxKey{"user"}).(string) + w.WriteHeader(200) + w.Write([]byte(fmt.Sprintf("hi %s", user))) + } + + headPing := func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Ping", "1") + w.WriteHeader(200) + } + + createPing := func(w http.ResponseWriter, r *http.Request) { + // create .... + w.WriteHeader(201) + } + + pingAll2 := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("ping all2")) + } + + pingOne := func(w http.ResponseWriter, r *http.Request) { + idParam := URLParam(r, "id") + w.WriteHeader(200) + w.Write([]byte(fmt.Sprintf("ping one id: %s", idParam))) + } + + pingWoop := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("woop." + URLParam(r, "iidd"))) + } + + catchAll := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("catchall")) + } + + m := NewRouter() + m.Use(countermw) + m.Use(usermw) + m.Use(exmw) + m.Use(logmw) + m.Get("/", cxindex) + m.Get("/ping/all2", pingAll2) + + m.Head("/ping", headPing) + m.Post("/ping", createPing) + m.Get("/ping/{id}", pingWoop) + m.Get("/ping/{id}", pingOne) // expected to overwrite to pingOne handler + m.Get("/ping/{iidd}/woop", pingWoop) + m.HandleFunc("/admin/*", catchAll) + // m.Post("/admin/*", catchAll) + + ts := httptest.NewServer(m) + defer ts.Close() + + // GET / + if _, body := testRequest(t, ts, "GET", "/", nil); body != "hi peter" { + t.Fatalf(body) + } + tlogmsg, _ := logbuf.ReadString(0) + if tlogmsg != logmsg { + t.Error("expecting log message from middleware:", logmsg) + } + + // GET /ping/all2 + if _, body := testRequest(t, ts, "GET", "/ping/all2", nil); body != "ping all2" { + t.Fatalf(body) + } + + // GET /ping/123 + if _, body := testRequest(t, ts, "GET", "/ping/123", nil); body != "ping one id: 123" { + t.Fatalf(body) + } + + // GET /ping/allan + if _, body := testRequest(t, ts, "GET", "/ping/allan", nil); body != "ping one id: allan" { + t.Fatalf(body) + } + + // GET /ping/1/woop + if _, body := testRequest(t, ts, "GET", "/ping/1/woop", nil); body != "woop.1" { + t.Fatalf(body) + } + + // HEAD /ping + resp, err := http.Head(ts.URL + "/ping") + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != 200 { + t.Error("head failed, should be 200") + } + if resp.Header.Get("X-Ping") == "" { + t.Error("expecting X-Ping header") + } + + // GET /admin/catch-this + if _, body := testRequest(t, ts, "GET", "/admin/catch-thazzzzz", nil); body != "catchall" { + t.Fatalf(body) + } + + // POST /admin/catch-this + resp, err = http.Post(ts.URL+"/admin/casdfsadfs", "text/plain", bytes.NewReader([]byte{})) + if err != nil { + t.Fatal(err) + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + t.Error("POST failed, should be 200") + } + + if string(body) != "catchall" { + t.Error("expecting response body: 'catchall'") + } + + // Custom http method DIE /ping/1/woop + if resp, body := testRequest(t, ts, "DIE", "/ping/1/woop", nil); body != "405 method not allowed\n" || resp.StatusCode != 405 { + t.Fatalf(fmt.Sprintf("expecting 405 status and empty body, got %d '%s'", resp.StatusCode, body)) + } +} + +func TestMuxMounts(t *testing.T) { + r := NewRouter() + + r.Get("/{hash}", func(w http.ResponseWriter, r *http.Request) { + v := URLParam(r, "hash") + w.Write([]byte(fmt.Sprintf("/%s", v))) + }) + + (func(r Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + v := URLParam(r, "hash") + w.Write([]byte(fmt.Sprintf("/%s/share", v))) + }) + r.Get("/{network}", func(w http.ResponseWriter, r *http.Request) { + v := URLParam(r, "hash") + n := URLParam(r, "network") + w.Write([]byte(fmt.Sprintf("/%s/share/%s", v, n))) + }) + })(r.Group("/{hash}/share")) + + m := NewRouter().(*routerGroup) + m.Mount("/sharing", r) + + ts := httptest.NewServer(m) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/sharing/aBc", nil); body != "/aBc" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/sharing/aBc/share", nil); body != "/aBc/share" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/sharing/aBc/share/twitter", nil); body != "/aBc/share/twitter" { + t.Fatalf(body) + } +} + +func TestMuxPlain(t *testing.T) { + r := NewRouter() + r.Get("/hi", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("bye")) + }) + r.NotFound(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte("nothing here")) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "nothing here" { + t.Fatalf(body) + } +} + +func TestMuxEmptyRoutes(t *testing.T) { + mux := NewRouter() + + apiRouter := NewRouter() + // oops, we forgot to declare any route handlers + + mux.Handle("/api*", apiRouter) + + if _, body := testHandler(t, mux, "GET", "/", nil); body != "404 page not found\n" { + t.Fatalf(body) + } + + if _, body := testHandler(t, apiRouter, "GET", "/", nil); body != "404 page not found\n" { + t.Fatalf(body) + } +} + +// Test a mux that routes a trailing slash, see also middleware/strip_test.go +// for an example of using a middleware to handle trailing slashes. +func TestMuxTrailingSlash(t *testing.T) { + r := NewRouter().(*routerGroup) + r.NotFound(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte("nothing here")) + }) + + subRoutes := NewRouter() + indexHandler := func(w http.ResponseWriter, r *http.Request) { + accountID := URLParam(r, "accountID") + w.Write([]byte(accountID)) + } + subRoutes.Get("/", indexHandler) + + r.Mount("/accounts/{accountID}", subRoutes) + r.Get("/accounts/{accountID}/", indexHandler) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/accounts/admin", nil); body != "admin" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/accounts/admin/", nil); body != "admin" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "nothing here" { + t.Fatalf(body) + } +} + +func TestMethodNotAllowed(t *testing.T) { + r := NewRouter() + + r.Get("/hi", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi, get")) + }) + + r.Head("/hi", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi, head")) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + t.Run("Registered Method", func(t *testing.T) { + resp, _ := testRequest(t, ts, "GET", "/hi", nil) + if resp.StatusCode != 200 { + t.Fatal(resp.Status) + } + if resp.Header.Values("Allow") != nil { + t.Fatal("allow should be empty when method is registered") + } + }) + + t.Run("Unregistered Method", func(t *testing.T) { + resp, _ := testRequest(t, ts, "POST", "/hi", nil) + if resp.StatusCode != 405 { + t.Fatal(resp.Status) + } + }) +} + +func TestMuxNestedMethodNotAllowed(t *testing.T) { + r := NewRouter().(*routerGroup) + r.Get("/root", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("root")) + }) + r.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(405) + w.Write([]byte("root 405")) + }) + + sr1 := NewRouter() + sr1.Get("/sub1", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("sub1")) + }) + sr1.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(405) + w.Write([]byte("sub1 405")) + }) + + sr2 := NewRouter() + sr2.Get("/sub2", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("sub2")) + }) + + pathVar := NewRouter() + pathVar.Get("/{var}", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("pv")) + }) + pathVar.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(405) + w.Write([]byte("pv 405")) + }) + + r.Mount("/prefix1", sr1) + r.Mount("/prefix2", sr2) + r.Mount("/pathVar", pathVar) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/root", nil); body != "root" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "PUT", "/root", nil); body != "root 405" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/prefix1/sub1", nil); body != "sub1" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "PUT", "/prefix1/sub1", nil); body != "sub1 405" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/prefix2/sub2", nil); body != "sub2" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "PUT", "/prefix2/sub2", nil); body != "root 405" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/pathVar/myvar", nil); body != "pv" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "DELETE", "/pathVar/myvar", nil); body != "pv 405" { + t.Fatalf(body) + } +} + +func TestMuxComplicatedNotFound(t *testing.T) { + decorateRouter := func(r *routerGroup) { + // Root router with groups + r.Get("/auth", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("auth get")) + }) + (func(r Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("public get")) + }) + })(r.Group("/public")) + + // sub router with groups + sub0 := NewRouter() + (func(r Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("private get")) + }) + })(sub0.Group("/resource")) + r.Mount("/private", sub0) + + // sub router with groups + sub1 := NewRouter() + (func(r Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("private get")) + }) + })(sub1.Group("/resource")) + } + + testNotFound := func(t *testing.T, r *routerGroup) { + ts := httptest.NewServer(r) + defer ts.Close() + + // check that we didn't break correct routes + if _, body := testRequest(t, ts, "GET", "/auth", nil); body != "auth get" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/public", nil); body != "public get" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/public/", nil); body != "public get" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/private/resource", nil); body != "private get" { + t.Fatalf(body) + } + // check custom not-found on all levels + if _, body := testRequest(t, ts, "GET", "/nope", nil); body != "custom not-found" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/public/nope", nil); body != "custom not-found" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/private/nope", nil); body != "custom not-found" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/private/resource/nope", nil); body != "custom not-found" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/private_mw/nope", nil); body != "custom not-found" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/private_mw/resource/nope", nil); body != "custom not-found" { + t.Fatalf(body) + } + // check custom not-found on trailing slash routes + if _, body := testRequest(t, ts, "GET", "/auth/", nil); body != "custom not-found" { + t.Fatalf(body) + } + } + + t.Run("pre", func(t *testing.T) { + r := NewRouter().(*routerGroup) + r.NotFound(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("custom not-found")) + }) + decorateRouter(r) + testNotFound(t, r) + }) + + t.Run("post", func(t *testing.T) { + r := NewRouter().(*routerGroup) + decorateRouter(r) + r.NotFound(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("custom not-found")) + }) + testNotFound(t, r) + }) +} + +func TestMuxMiddlewareStack(t *testing.T) { + var stdmwInit, stdmwHandler uint64 + stdmw := func(next http.Handler) http.Handler { + stdmwInit++ + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + stdmwHandler++ + next.ServeHTTP(w, r) + }) + } + _ = stdmw + + var ctxmwInit, ctxmwHandler uint64 + ctxmw := func(next http.Handler) http.Handler { + ctxmwInit++ + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctxmwHandler++ + ctx := r.Context() + ctx = context.WithValue(ctx, ctxKey{"count.ctxmwHandler"}, ctxmwHandler) + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + }) + } + + r := NewRouter() + r.Use(stdmw) + r.Use(ctxmw) + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/ping" { + w.Write([]byte("pong")) + return + } + next.ServeHTTP(w, r) + }) + }) + + var handlerCount uint64 + + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + handlerCount++ + ctx := r.Context() + ctxmwHandlerCount := ctx.Value(ctxKey{"count.ctxmwHandler"}).(uint64) + w.Write([]byte(fmt.Sprintf("inits:%d reqs:%d ctxValue:%d", ctxmwInit, handlerCount, ctxmwHandlerCount))) + }) + + r.Get("/hi", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("wooot")) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + testRequest(t, ts, "GET", "/", nil) + testRequest(t, ts, "GET", "/", nil) + var body string + _, body = testRequest(t, ts, "GET", "/", nil) + if body != "inits:1 reqs:3 ctxValue:3" { + t.Fatalf("got: '%s'", body) + } + + _, body = testRequest(t, ts, "GET", "/ping", nil) + if body != "pong" { + t.Fatalf("got: '%s'", body) + } +} + +func TestMuxSubroutesBasic(t *testing.T) { + hIndex := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("index")) + }) + hArticlesList := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("articles-list")) + }) + hSearchArticles := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("search-articles")) + }) + hGetArticle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(fmt.Sprintf("get-article:%s", URLParam(r, "id")))) + }) + hSyncArticle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(fmt.Sprintf("sync-article:%s", URLParam(r, "id")))) + }) + + r := NewRouter() + // var rr1, rr2 *Mux + r.Get("/", hIndex) + (func(r Router) { + // rr1 = r.(*Mux) + r.Get("/", hArticlesList) + r.Get("/search", hSearchArticles) + (func(r Router) { + // rr2 = r.(*Mux) + r.Get("/", hGetArticle) + r.Get("/sync", hSyncArticle) + })(r.Group("/{id}")) + })(r.Group("/articles")) + + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + // debugPrintTree(0, 0, r.tree, 0) + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + // debugPrintTree(0, 0, rr1.tree, 0) + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + // debugPrintTree(0, 0, rr2.tree, 0) + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + + ts := httptest.NewServer(r) + defer ts.Close() + + var body, expected string + + _, body = testRequest(t, ts, "GET", "/", nil) + expected = "index" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + _, body = testRequest(t, ts, "GET", "/articles", nil) + expected = "articles-list" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + _, body = testRequest(t, ts, "GET", "/articles/search", nil) + expected = "search-articles" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + _, body = testRequest(t, ts, "GET", "/articles/123", nil) + expected = "get-article:123" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + _, body = testRequest(t, ts, "GET", "/articles/123/sync", nil) + expected = "sync-article:123" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } +} + +func TestMuxSubroutes(t *testing.T) { + hHubView1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hub1")) + }) + hHubView2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hub2")) + }) + hHubView3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hub3")) + }) + hAccountView1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("account1")) + }) + hAccountView2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("account2")) + }) + + r := NewRouter().(*routerGroup) + r.Get("/hubs/{hubID}/view", hHubView1) + r.Get("/hubs/{hubID}/view/*", hHubView2) + + sr := NewRouter().(*routerGroup) + sr.Get("/", hHubView3) + r.Mount("/hubs/{hubID}/users", sr) + r.Get("/hubs/{hubID}/users/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hub3 override")) + }) + + sr3 := NewRouter() + sr3.Get("/", hAccountView1) + sr3.Get("/hi", hAccountView2) + + // var sr2 *Mux + (func(r Router) { + rg := r.(*routerGroup) // sr2 + // r.Get("/", hAccountView1) + rg.Mount("/", sr3) + })(r.Group("/accounts/{accountID}")) + + // This is the same as the r.Route() call mounted on sr2 + // sr2 := NewRouter() + // sr2.Mount("/", sr3) + // r.Mount("/accounts/{accountID}", sr2) + + ts := httptest.NewServer(r) + defer ts.Close() + + var body, expected string + + _, body = testRequest(t, ts, "GET", "/hubs/123/view", nil) + expected = "hub1" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + _, body = testRequest(t, ts, "GET", "/hubs/123/view/index.html", nil) + expected = "hub2" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + _, body = testRequest(t, ts, "GET", "/hubs/123/users", nil) + expected = "hub3" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + _, body = testRequest(t, ts, "GET", "/hubs/123/users/", nil) + expected = "hub3 override" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + _, body = testRequest(t, ts, "GET", "/accounts/44", nil) + expected = "account1" + if body != expected { + t.Fatalf("request:%s expected:%s got:%s", "GET /accounts/44", expected, body) + } + _, body = testRequest(t, ts, "GET", "/accounts/44/hi", nil) + expected = "account2" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + + // Test that we're building the routingPatterns properly + router := r + req, _ := http.NewRequest("GET", "/accounts/44/hi", nil) + + rctx := &RouteContext{} + req = req.WithContext(context.WithValue(req.Context(), routeContextKey{}, rctx)) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + body = w.Body.String() + expected = "account2" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + + routePatterns := rctx.routePatterns + if len(rctx.routePatterns) != 3 { + t.Fatalf("expected 3 routing patterns, got:%d", len(rctx.routePatterns)) + } + expected = "/accounts/{accountID}/*" + if routePatterns[0] != expected { + t.Fatalf("routePattern, expected:%s got:%s", expected, routePatterns[0]) + } + expected = "/*" + if routePatterns[1] != expected { + t.Fatalf("routePattern, expected:%s got:%s", expected, routePatterns[1]) + } + expected = "/hi" + if routePatterns[2] != expected { + t.Fatalf("routePattern, expected:%s got:%s", expected, routePatterns[2]) + } + +} + +func TestSingleHandler(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + name := URLParam(r, "name") + w.Write([]byte("hi " + name)) + }) + + r, _ := http.NewRequest("GET", "/", nil) + rctx := &RouteContext{} + r = r.WithContext(context.WithValue(r.Context(), routeContextKey{}, rctx)) + rctx.URLParams.Add("name", "joe") + + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + + body := w.Body.String() + expected := "hi joe" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } +} + +// TODO: a Router wrapper test.. +// +// type ACLMux struct { +// *Mux +// XX string +// } +// +// func NewACLMux() *ACLMux { +// return &ACLMux{Mux: NewRouter(), XX: "hihi"} +// } +// +// // TODO: this should be supported... +// func TestWoot(t *testing.T) { +// var r Router = NewRouter() +// +// var r2 Router = NewACLMux() //NewRouter() +// r2.Get("/hi", func(w http.ResponseWriter, r *http.Request) { +// w.Write([]byte("hi")) +// }) +// +// r.Mount("/", r2) +// } + +func TestServeHTTPExistingContext(t *testing.T) { + r := NewRouter() + r.Get("/hi", func(w http.ResponseWriter, r *http.Request) { + s, _ := r.Context().Value(ctxKey{"testCtx"}).(string) + w.Write([]byte(s)) + }) + r.NotFound(func(w http.ResponseWriter, r *http.Request) { + s, _ := r.Context().Value(ctxKey{"testCtx"}).(string) + w.WriteHeader(404) + w.Write([]byte(s)) + }) + + testcases := []struct { + Ctx context.Context + Method string + Path string + ExpectedBody string + ExpectedStatus int + }{ + { + Method: "GET", + Path: "/hi", + Ctx: context.WithValue(context.Background(), ctxKey{"testCtx"}, "hi ctx"), + ExpectedStatus: 200, + ExpectedBody: "hi ctx", + }, + { + Method: "GET", + Path: "/hello", + Ctx: context.WithValue(context.Background(), ctxKey{"testCtx"}, "nothing here ctx"), + ExpectedStatus: 404, + ExpectedBody: "nothing here ctx", + }, + } + + for _, tc := range testcases { + resp := httptest.NewRecorder() + req, err := http.NewRequest(tc.Method, tc.Path, nil) + if err != nil { + t.Fatalf("%v", err) + } + req = req.WithContext(tc.Ctx) + r.ServeHTTP(resp, req) + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("%v", err) + } + if resp.Code != tc.ExpectedStatus { + t.Fatalf("%v != %v", tc.ExpectedStatus, resp.Code) + } + if string(b) != tc.ExpectedBody { + t.Fatalf("%s != %s", tc.ExpectedBody, b) + } + } +} + +func TestMiddlewarePanicOnLateUse(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello\n")) + } + + mw := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + }) + } + + defer func() { + if recover() == nil { + t.Error("expected panic()") + } + }() + + r := NewRouter() + r.Get("/", handler) + r.Use(mw) // Too late to apply middleware, we're expecting panic(). +} + +func TestMountingExistingPath(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) {} + + defer func() { + if recover() == nil { + t.Error("expected panic()") + } + }() + + r := NewRouter().(*routerGroup) + r.Get("/", handler) + r.Mount("/hi", http.HandlerFunc(handler)) + r.Mount("/hi", http.HandlerFunc(handler)) +} + +func TestMountingSimilarPattern(t *testing.T) { + r := NewRouter().(*routerGroup) + r.Get("/hi", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("bye")) + }) + + r2 := NewRouter() + r2.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("foobar")) + }) + + r3 := NewRouter() + r3.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("foo")) + }) + + r.Mount("/foobar", r2) + r.Mount("/foo", r3) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" { + t.Fatalf(body) + } +} + +func TestMuxEmptyParams(t *testing.T) { + r := NewRouter() + r.Get(`/users/{x}/{y}/{z}`, func(w http.ResponseWriter, r *http.Request) { + x := URLParam(r, "x") + y := URLParam(r, "y") + z := URLParam(r, "z") + w.Write([]byte(fmt.Sprintf("%s-%s-%s", x, y, z))) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/users/a/b/c", nil); body != "a-b-c" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/users///c", nil); body != "--c" { + t.Fatalf(body) + } +} + +func TestMuxMissingParams(t *testing.T) { + r := NewRouter() + r.Get(`/user/{userId:\d+}`, func(w http.ResponseWriter, r *http.Request) { + userID := URLParam(r, "userId") + w.Write([]byte(fmt.Sprintf("userId = '%s'", userID))) + }) + r.NotFound(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte("nothing here")) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/user/123", nil); body != "userId = '123'" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/user/", nil); body != "nothing here" { + t.Fatalf(body) + } +} + +func TestMuxWildcardRoute(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) {} + + defer func() { + if recover() == nil { + t.Error("expected panic()") + } + }() + + r := NewRouter() + r.Get("/*/wildcard/must/be/at/end", handler) +} + +func TestMuxWildcardRouteCheckTwo(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) {} + + defer func() { + if recover() == nil { + t.Error("expected panic()") + } + }() + + r := NewRouter() + r.Get("/*/wildcard/{must}/be/at/end", handler) +} + +func TestMuxRegexp(t *testing.T) { + r := NewRouter() + r.Group("/{param:[0-9]*}/test", func(r Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(fmt.Sprintf("Hi: %s", URLParam(r, "param")))) + }) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "//test", nil); body != "Hi: " { + t.Fatalf(body) + } +} + +func TestMuxRegexp2(t *testing.T) { + r := NewRouter() + r.Get("/foo-{suffix:[a-z]{2,3}}.json", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(URLParam(r, "suffix"))) + }) + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/foo-.json", nil); body != "" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/foo-abc.json", nil); body != "abc" { + t.Fatalf(body) + } +} + +func TestMuxRegexp3(t *testing.T) { + r := NewRouter() + r.Get("/one/{firstId:[a-z0-9-]+}/{secondId:[a-z]+}/first", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("first")) + }) + r.Get("/one/{firstId:[a-z0-9-_]+}/{secondId:[0-9]+}/second", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("second")) + }) + r.Delete("/one/{firstId:[a-z0-9-_]+}/{secondId:[0-9]+}/second", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("third")) + }) + + (func(r Router) { + r.Get("/{dns:[a-z-0-9_]+}", func(writer http.ResponseWriter, request *http.Request) { + writer.Write([]byte("_")) + }) + r.Get("/{dns:[a-z-0-9_]+}/info", func(writer http.ResponseWriter, request *http.Request) { + writer.Write([]byte("_")) + }) + r.Delete("/{id:[0-9]+}", func(writer http.ResponseWriter, request *http.Request) { + writer.Write([]byte("forth")) + }) + })(r.Group("/one")) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/one/hello/peter/first", nil); body != "first" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/one/hithere/123/second", nil); body != "second" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "DELETE", "/one/hithere/123/second", nil); body != "third" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "DELETE", "/one/123", nil); body != "forth" { + t.Fatalf(body) + } +} + +func TestMuxSubrouterWildcardParam(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "param:%v *:%v", URLParam(r, "param"), URLParam(r, "*")) + }) + + r := NewRouter() + + r.Get("/bare/{param}", h) + r.Get("/bare/{param}/*", h) + + (func(r Router) { + r.Get("/{param}", h) + r.Get("/{param}/*", h) + })(r.Group("/case0")) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/bare/hi", nil); body != "param:hi *:" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/bare/hi/yes", nil); body != "param:hi *:yes" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/case0/hi", nil); body != "param:hi *:" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/case0/hi/yes", nil); body != "param:hi *:yes" { + t.Fatalf(body) + } +} + +func TestMuxContextIsThreadSafe(t *testing.T) { + router := NewRouter() + router.Get("/{id}", func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), 1*time.Millisecond) + defer cancel() + + <-ctx.Done() + }) + + wg := sync.WaitGroup{} + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10000; j++ { + w := httptest.NewRecorder() + r, err := http.NewRequest("GET", "/ok", nil) + if err != nil { + t.Error(err) + return + } + + ctx, cancel := context.WithCancel(r.Context()) + r = r.WithContext(ctx) + + go func() { + cancel() + }() + router.ServeHTTP(w, r) + } + }() + } + wg.Wait() +} + +func TestEscapedURLParams(t *testing.T) { + m := NewRouter() + m.Get("/api/{identifier}/{region}/{size}/{rotation}/*", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + rctx := FromRouteContext(r.Context()) + if rctx == nil { + t.Error("no context") + return + } + identifier := URLParam(r, "identifier") + if identifier != "http:%2f%2fexample.com%2fimage.png" { + t.Errorf("identifier path parameter incorrect %s", identifier) + return + } + region := URLParam(r, "region") + if region != "full" { + t.Errorf("region path parameter incorrect %s", region) + return + } + size := URLParam(r, "size") + if size != "max" { + t.Errorf("size path parameter incorrect %s", size) + return + } + rotation := URLParam(r, "rotation") + if rotation != "0" { + t.Errorf("rotation path parameter incorrect %s", rotation) + return + } + w.Write([]byte("success")) + }) + + ts := httptest.NewServer(m) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/api/http:%2f%2fexample.com%2fimage.png/full/max/0/color.png", nil); body != "success" { + t.Fatalf(body) + } +} + +func TestMuxMatch(t *testing.T) { + r := NewRouter() + r.Get("/hi", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Test", "yes") + w.Write([]byte("bye")) + }) + (func(r Router) { + r.Get("/{id}", func(w http.ResponseWriter, r *http.Request) { + id := URLParam(r, "id") + w.Header().Set("X-Article", id) + w.Write([]byte("article:" + id)) + }) + })(r.Group("/articles")) + (func(r Router) { + r.Head("/{id}", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-User", "-") + w.Write([]byte("user")) + }) + r.Get("/{id}", func(w http.ResponseWriter, r *http.Request) { + id := URLParam(r, "id") + w.Header().Set("X-User", id) + w.Write([]byte("user:" + id)) + }) + })(r.Group("/users")) + + tctx := &RouteContext{} + + tctx.Reset() + if r.(Routes).Match(tctx, "GET", "/users/1") == false { + t.Fatal("expecting to find match for route:", "GET", "/users/1") + } + + tctx.Reset() + if r.(Routes).Match(tctx, "HEAD", "/articles/10") == true { + t.Fatal("not expecting to find match for route:", "HEAD", "/articles/10") + } +} + +func TestServerBaseContext(t *testing.T) { + r := NewRouter() + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + baseYes := r.Context().Value(ctxKey{"base"}).(string) + if _, ok := r.Context().Value(http.ServerContextKey).(*http.Server); !ok { + panic("missing server context") + } + if _, ok := r.Context().Value(http.LocalAddrContextKey).(net.Addr); !ok { + panic("missing local addr context") + } + w.Write([]byte(baseYes)) + }) + + // Setup http Server with a base context + ctx := context.WithValue(context.Background(), ctxKey{"base"}, "yes") + ts := httptest.NewUnstartedServer(r) + ts.Config.BaseContext = func(_ net.Listener) context.Context { + return ctx + } + ts.Start() + + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/", nil); body != "yes" { + t.Fatalf(body) + } +} + +func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) { + req, err := http.NewRequest(method, ts.URL+path, body) + if err != nil { + t.Fatal(err) + return nil, "" + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + return nil, "" + } + + respBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + return nil, "" + } + defer resp.Body.Close() + + return resp, string(respBody) +} + +func testHandler(t *testing.T, h http.Handler, method, path string, body io.Reader) (*http.Response, string) { + r, _ := http.NewRequest(method, path, body) + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + return w.Result(), w.Body.String() +} + +type ctxKey struct { + name string +} + +func (k ctxKey) String() string { + return "context value " + k.name +} + +func BenchmarkMux(b *testing.B) { + h1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + h3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + h4 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + h5 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + h6 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + mx := NewRouter() + mx.Get("/", h1) + mx.Get("/hi", h2) + mx.Post("/hi-post", h2) // used to benchmark 405 responses + mx.Get("/sup/{id}/and/{this}", h3) + mx.Get("/sup/{id}/{bar:foo}/{this}", h3) + + mx.Group("/sharing/{x}/{hash}", func(mx Router) { + mx.Get("/", h4) // subrouter-1 + mx.Get("/{network}", h5) // subrouter-1 + mx.Get("/twitter", h5) + mx.Group("/direct", func(mx Router) { + mx.Get("/", h6) // subrouter-2 + mx.Get("/download", h6) + }) + }) + + routes := []string{ + "/", + "/hi", + "/hi-post", + "/sup/123/and/this", + "/sup/123/foo/this", + "/sharing/z/aBc", // subrouter-1 + "/sharing/z/aBc/twitter", // subrouter-1 + "/sharing/z/aBc/direct", // subrouter-2 + "/sharing/z/aBc/direct/download", // subrouter-2 + } + + for _, path := range routes { + b.Run("route:"+path, func(b *testing.B) { + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", path, nil) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + mx.ServeHTTP(w, r) + } + }) + } +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..8893c43 --- /dev/null +++ b/server.go @@ -0,0 +1,86 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package web + +import ( + "context" + "net/http" +) + +// A Server defines parameters for running an HTTP server. +type Server struct { + options Options + httpSvr *http.Server + Router +} + +// NewServer returns a new server instance. +func NewServer(options Options) *Server { + + var addr = options.Addr + if 0 == len(addr) { + addr = ":8080" // default port: 8080 + } + + var router = options.Router + if nil == router { + router = NewRouter() + } + + svr := &Server{ + options: options, + httpSvr: &http.Server{ + Addr: addr, + Handler: router, + TLSConfig: options.TlsConfig(), + ReadTimeout: options.ReadTimeout, + ReadHeaderTimeout: options.ReadHeaderTimeout, + WriteTimeout: options.WriteTimeout, + IdleTimeout: options.IdleTimeout, + MaxHeaderBytes: options.MaxHeaderBytes, + }, + Router: router, + } + + return svr +} + +// Addr returns the server listen address. +func (s *Server) Addr() string { + return s.httpSvr.Addr +} + +// Run listens on the TCP network address Addr and then +// calls Serve to handle requests on incoming connections. +// Accepted connections are configured to enable TCP keep-alives. +func (s *Server) Run() error { + if nil != s.httpSvr.TLSConfig { + return s.httpSvr.ListenAndServeTLS(s.options.CertFile, s.options.KeyFile) + } + return s.httpSvr.ListenAndServe() +} + +// Shutdown gracefully shuts down the server without interrupting any +// active connections. Shutdown works by first closing all open +// listeners, then closing all idle connections, and then waiting +// indefinitely for connections to return to idle and then shut down. +// If the provided context expires before the shutdown is complete, +// Shutdown returns the context's error, otherwise it returns any +// error returned from closing the Server's underlying Listener(s). +func (s *Server) Shutdown(ctx context.Context) error { + return s.httpSvr.Shutdown(ctx) +} diff --git a/tree.go b/tree.go new file mode 100644 index 0000000..7f4bd02 --- /dev/null +++ b/tree.go @@ -0,0 +1,872 @@ +package web + +// Radix tree implementation below is a based on the original work by +// Armon Dadgar in https://github.com/armon/go-radix/blob/master/radix.go +// (MIT licensed). It's been heavily modified for use as a HTTP routing tree. + +import ( + "fmt" + "net/http" + "regexp" + "sort" + "strings" +) + +type methodTyp uint + +const ( + mSTUB methodTyp = 1 << iota + mCONNECT + mDELETE + mGET + mHEAD + mOPTIONS + mPATCH + mPOST + mPUT + mTRACE + + mALL = mCONNECT | mDELETE | mGET | mHEAD | + mOPTIONS | mPATCH | mPOST | mPUT | mTRACE +) + +var methodMap = map[string]methodTyp{ + http.MethodConnect: mCONNECT, + http.MethodDelete: mDELETE, + http.MethodGet: mGET, + http.MethodHead: mHEAD, + http.MethodOptions: mOPTIONS, + http.MethodPatch: mPATCH, + http.MethodPost: mPOST, + http.MethodPut: mPUT, + http.MethodTrace: mTRACE, +} + +var reverseMethodMap = map[methodTyp]string{ + mCONNECT: http.MethodConnect, + mDELETE: http.MethodDelete, + mGET: http.MethodGet, + mHEAD: http.MethodHead, + mOPTIONS: http.MethodOptions, + mPATCH: http.MethodPatch, + mPOST: http.MethodPost, + mPUT: http.MethodPut, + mTRACE: http.MethodTrace, +} + +type nodeTyp uint8 + +const ( + ntStatic nodeTyp = iota // /home + ntRegexp // /{id:[0-9]+} + ntParam // /{user} + ntCatchAll // /api/v1/* +) + +type node struct { + // subroutes on the leaf node + subroutes Routes + + // regexp matcher for regexp nodes + rex *regexp.Regexp + + // HTTP handler endpoints on the leaf node + endpoints endpoints + + // prefix is the common prefix we ignore + prefix string + + // child nodes should be stored in-order for iteration, + // in groups of the node type. + children [ntCatchAll + 1]nodes + + // first byte of the child prefix + tail byte + + // node type: static, regexp, param, catchAll + typ nodeTyp + + // first byte of the prefix + label byte +} + +// endpoints is a mapping of http method constants to handlers +// for a given route. +type endpoints map[methodTyp]*endpoint + +type endpoint struct { + // endpoint handler + handler http.Handler + + // pattern is the routing pattern for handler nodes + pattern string + + // parameter keys recorded on handler nodes + paramKeys []string +} + +func (s endpoints) Value(method methodTyp) *endpoint { + mh, ok := s[method] + if !ok { + mh = &endpoint{} + s[method] = mh + } + return mh +} + +func (n *node) InsertRoute(method methodTyp, pattern string, handler http.Handler) *node { + var parent *node + search := pattern + + for { + // Handle key exhaustion + if len(search) == 0 { + // Insert or update the node's leaf handler + n.setEndpoint(method, handler, pattern) + return n + } + + // We're going to be searching for a wild node next, + // in this case, we need to get the tail + var label = search[0] + var segTail byte + var segEndIdx int + var segTyp nodeTyp + var segRexpat string + if label == '{' || label == '*' { + segTyp, _, segRexpat, segTail, _, segEndIdx = patNextSegment(search) + } + + var prefix string + if segTyp == ntRegexp { + prefix = segRexpat + } + + // Look for the edge to attach to + parent = n + n = n.getEdge(segTyp, label, segTail, prefix) + + // No edge, create one + if n == nil { + child := &node{label: label, tail: segTail, prefix: search} + hn := parent.addChild(child, search) + hn.setEndpoint(method, handler, pattern) + + return hn + } + + // Found an edge to match the pattern + + if n.typ > ntStatic { + // We found a param node, trim the param from the search path and continue. + // This param/wild pattern segment would already be on the tree from a previous + // call to addChild when creating a new node. + search = search[segEndIdx:] + continue + } + + // Static nodes fall below here. + // Determine longest prefix of the search key on match. + commonPrefix := longestPrefix(search, n.prefix) + if commonPrefix == len(n.prefix) { + // the common prefix is as long as the current node's prefix we're attempting to insert. + // keep the search going. + search = search[commonPrefix:] + continue + } + + // Split the node + child := &node{ + typ: ntStatic, + prefix: search[:commonPrefix], + } + parent.replaceChild(search[0], segTail, child) + + // Restore the existing node + n.label = n.prefix[commonPrefix] + n.prefix = n.prefix[commonPrefix:] + child.addChild(n, n.prefix) + + // If the new key is a subset, set the method/handler on this node and finish. + search = search[commonPrefix:] + if len(search) == 0 { + child.setEndpoint(method, handler, pattern) + return child + } + + // Create a new edge for the node + subchild := &node{ + typ: ntStatic, + label: search[0], + prefix: search, + } + hn := child.addChild(subchild, search) + hn.setEndpoint(method, handler, pattern) + return hn + } +} + +// addChild appends the new `child` node to the tree using the `pattern` as the trie key. +// For a URL router, we split the static, param, regexp and wildcard segments +// into different nodes. In addition, addChild will recursively call itself until every +// pattern segment is added to the url pattern tree as individual nodes, depending on type. +func (n *node) addChild(child *node, prefix string) *node { + search := prefix + + // handler leaf node added to the tree is the child. + // this may be overridden later down the flow + hn := child + + // Parse next segment + segTyp, _, segRexpat, segTail, segStartIdx, segEndIdx := patNextSegment(search) + + // Add child depending on next up segment + switch segTyp { + + case ntStatic: + // Search prefix is all static (that is, has no params in path) + // noop + + default: + // Search prefix contains a param, regexp or wildcard + + if segTyp == ntRegexp { + rex, err := regexp.Compile(segRexpat) + if err != nil { + panic(fmt.Sprintf("invalid regexp pattern '%s' in route param", segRexpat)) + } + child.prefix = segRexpat + child.rex = rex + } + + if segStartIdx == 0 { + // Route starts with a param + child.typ = segTyp + + if segTyp == ntCatchAll { + segStartIdx = -1 + } else { + segStartIdx = segEndIdx + } + if segStartIdx < 0 { + segStartIdx = len(search) + } + child.tail = segTail // for params, we set the tail + + if segStartIdx != len(search) { + // add static edge for the remaining part, split the end. + // its not possible to have adjacent param nodes, so its certainly + // going to be a static node next. + + search = search[segStartIdx:] // advance search position + + nn := &node{ + typ: ntStatic, + label: search[0], + prefix: search, + } + hn = child.addChild(nn, search) + } + + } else if segStartIdx > 0 { + // Route has some param + + // starts with a static segment + child.typ = ntStatic + child.prefix = search[:segStartIdx] + child.rex = nil + + // add the param edge node + search = search[segStartIdx:] + + nn := &node{ + typ: segTyp, + label: search[0], + tail: segTail, + } + hn = child.addChild(nn, search) + + } + } + + n.children[child.typ] = append(n.children[child.typ], child) + n.children[child.typ].Sort() + return hn +} + +func (n *node) replaceChild(label, tail byte, child *node) { + for i := 0; i < len(n.children[child.typ]); i++ { + if n.children[child.typ][i].label == label && n.children[child.typ][i].tail == tail { + n.children[child.typ][i] = child + n.children[child.typ][i].label = label + n.children[child.typ][i].tail = tail + return + } + } + panic("replacing missing child") +} + +func (n *node) getEdge(ntyp nodeTyp, label, tail byte, prefix string) *node { + nds := n.children[ntyp] + for i := 0; i < len(nds); i++ { + if nds[i].label == label && nds[i].tail == tail { + if ntyp == ntRegexp && nds[i].prefix != prefix { + continue + } + return nds[i] + } + } + return nil +} + +func (n *node) setEndpoint(method methodTyp, handler http.Handler, pattern string) { + // Set the handler for the method type on the node + if n.endpoints == nil { + n.endpoints = make(endpoints) + } + + paramKeys := patParamKeys(pattern) + + if method&mSTUB == mSTUB { + n.endpoints.Value(mSTUB).handler = handler + } + if method&mALL == mALL { + h := n.endpoints.Value(mALL) + h.handler = handler + h.pattern = pattern + h.paramKeys = paramKeys + for _, m := range methodMap { + h := n.endpoints.Value(m) + h.handler = handler + h.pattern = pattern + h.paramKeys = paramKeys + } + } else { + h := n.endpoints.Value(method) + h.handler = handler + h.pattern = pattern + h.paramKeys = paramKeys + } +} + +func (n *node) FindRoute(rctx *RouteContext, method methodTyp, path string) (*node, endpoints, http.Handler) { + // Reset the context routing pattern and params + rctx.RoutePattern = "" + rctx.routeParams.Keys = rctx.routeParams.Keys[:0] + rctx.routeParams.Values = rctx.routeParams.Values[:0] + + // Find the routing handlers for the path + rn := n.findRoute(rctx, method, path) + if rn == nil { + return nil, nil, nil + } + + // Record the routing params in the request lifecycle + rctx.URLParams.Keys = append(rctx.URLParams.Keys, rctx.routeParams.Keys...) + rctx.URLParams.Values = append(rctx.URLParams.Values, rctx.routeParams.Values...) + + // Record the routing pattern in the request lifecycle + if rn.endpoints[method].pattern != "" { + rctx.RoutePattern = rn.endpoints[method].pattern + rctx.routePatterns = append(rctx.routePatterns, rctx.RoutePattern) + } + + return rn, rn.endpoints, rn.endpoints[method].handler +} + +// Recursive edge traversal by checking all nodeTyp groups along the way. +// It's like searching through a multi-dimensional radix trie. +func (n *node) findRoute(rctx *RouteContext, method methodTyp, path string) *node { + nn := n + search := path + + for t, nds := range nn.children { + ntyp := nodeTyp(t) + if len(nds) == 0 { + continue + } + + var xn *node + xsearch := search + + var label byte + if search != "" { + label = search[0] + } + + switch ntyp { + case ntStatic: + xn = nds.findEdge(label) + if xn == nil || !strings.HasPrefix(xsearch, xn.prefix) { + continue + } + xsearch = xsearch[len(xn.prefix):] + + case ntParam, ntRegexp: + // short-circuit and return no matching route for empty param values + if xsearch == "" { + continue + } + + // serially loop through each node grouped by the tail delimiter + for idx := 0; idx < len(nds); idx++ { + xn = nds[idx] + + // label for param nodes is the delimiter byte + p := strings.IndexByte(xsearch, xn.tail) + + if p < 0 { + if xn.tail == '/' { + p = len(xsearch) + } else { + continue + } + } else if ntyp == ntRegexp && p == 0 { + continue + } + + if ntyp == ntRegexp && xn.rex != nil { + if !xn.rex.MatchString(xsearch[:p]) { + continue + } + } else if strings.IndexByte(xsearch[:p], '/') != -1 { + // avoid a match across path segments + continue + } + + prevlen := len(rctx.routeParams.Values) + rctx.routeParams.Values = append(rctx.routeParams.Values, xsearch[:p]) + xsearch = xsearch[p:] + + if len(xsearch) == 0 { + if xn.isLeaf() { + h := xn.endpoints[method] + if h != nil && h.handler != nil { + rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...) + return xn + } + + for endpoints := range xn.endpoints { + if endpoints == mALL || endpoints == mSTUB { + continue + } + rctx.methodsAllowed = append(rctx.methodsAllowed, endpoints) + } + + // flag that the routing context found a route, but not a corresponding + // supported method + rctx.methodNotAllowed = true + } + } + + // recursively find the next node on this branch + fin := xn.findRoute(rctx, method, xsearch) + if fin != nil { + return fin + } + + // not found on this branch, reset vars + rctx.routeParams.Values = rctx.routeParams.Values[:prevlen] + xsearch = search + } + + rctx.routeParams.Values = append(rctx.routeParams.Values, "") + + default: + // catch-all nodes + rctx.routeParams.Values = append(rctx.routeParams.Values, search) + xn = nds[0] + xsearch = "" + } + + if xn == nil { + continue + } + + // did we find it yet? + if len(xsearch) == 0 { + if xn.isLeaf() { + h := xn.endpoints[method] + if h != nil && h.handler != nil { + rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...) + return xn + } + + for endpoints := range xn.endpoints { + if endpoints == mALL || endpoints == mSTUB { + continue + } + rctx.methodsAllowed = append(rctx.methodsAllowed, endpoints) + } + + // flag that the routing context found a route, but not a corresponding + // supported method + rctx.methodNotAllowed = true + } + } + + // recursively find the next node.. + fin := xn.findRoute(rctx, method, xsearch) + if fin != nil { + return fin + } + + // Did not find final handler, let's remove the param here if it was set + if xn.typ > ntStatic { + if len(rctx.routeParams.Values) > 0 { + rctx.routeParams.Values = rctx.routeParams.Values[:len(rctx.routeParams.Values)-1] + } + } + + } + + return nil +} + +func (n *node) findEdge(ntyp nodeTyp, label byte) *node { + nds := n.children[ntyp] + num := len(nds) + idx := 0 + + switch ntyp { + case ntStatic, ntParam, ntRegexp: + i, j := 0, num-1 + for i <= j { + idx = i + (j-i)/2 + if label > nds[idx].label { + i = idx + 1 + } else if label < nds[idx].label { + j = idx - 1 + } else { + i = num // breaks cond + } + } + if nds[idx].label != label { + return nil + } + return nds[idx] + + default: // catch all + return nds[idx] + } +} + +func (n *node) isLeaf() bool { + return n.endpoints != nil +} + +func (n *node) findPattern(pattern string) bool { + nn := n + for _, nds := range nn.children { + if len(nds) == 0 { + continue + } + + n = nn.findEdge(nds[0].typ, pattern[0]) + if n == nil { + continue + } + + var idx int + var xpattern string + + switch n.typ { + case ntStatic: + idx = longestPrefix(pattern, n.prefix) + if idx < len(n.prefix) { + continue + } + + case ntParam, ntRegexp: + idx = strings.IndexByte(pattern, '}') + 1 + + case ntCatchAll: + idx = longestPrefix(pattern, "*") + + default: + panic("unknown node type") + } + + xpattern = pattern[idx:] + if len(xpattern) == 0 { + return true + } + + return n.findPattern(xpattern) + } + return false +} + +func (n *node) routes() []Route { + rts := []Route{} + + n.walk(func(eps endpoints, subroutes Routes) bool { + if eps[mSTUB] != nil && eps[mSTUB].handler != nil && subroutes == nil { + return false + } + + // Group methodHandlers by unique patterns + pats := make(map[string]endpoints) + + for mt, h := range eps { + if h.pattern == "" { + continue + } + p, ok := pats[h.pattern] + if !ok { + p = endpoints{} + pats[h.pattern] = p + } + p[mt] = h + } + + for p, mh := range pats { + hs := make(map[string]http.Handler) + if mh[mALL] != nil && mh[mALL].handler != nil { + hs["*"] = mh[mALL].handler + } + + for mt, h := range mh { + if h.handler == nil { + continue + } + m := methodTypString(mt) + if m == "" { + continue + } + hs[m] = h.handler + } + + rt := Route{subroutes, hs, p} + rts = append(rts, rt) + } + + return false + }) + + return rts +} + +func (n *node) walk(fn func(eps endpoints, subroutes Routes) bool) bool { + // Visit the leaf values if any + if (n.endpoints != nil || n.subroutes != nil) && fn(n.endpoints, n.subroutes) { + return true + } + + // Recurse on the children + for _, ns := range n.children { + for _, cn := range ns { + if cn.walk(fn) { + return true + } + } + } + return false +} + +// patNextSegment returns the next segment details from a pattern: +// node type, param key, regexp string, param tail byte, param starting index, param ending index +func patNextSegment(pattern string) (nodeTyp, string, string, byte, int, int) { + ps := strings.Index(pattern, "{") + ws := strings.Index(pattern, "*") + + if ps < 0 && ws < 0 { + return ntStatic, "", "", 0, 0, len(pattern) // we return the entire thing + } + + // Sanity check + if ps >= 0 && ws >= 0 && ws < ps { + panic("wildcard '*' must be the last pattern in a route, otherwise use a '{param}'") + } + + var tail byte = '/' // Default endpoint tail to / byte + + if ps >= 0 { + // Param/Regexp pattern is next + nt := ntParam + + // Read to closing } taking into account opens and closes in curl count (cc) + cc := 0 + pe := ps + for i, c := range pattern[ps:] { + if c == '{' { + cc++ + } else if c == '}' { + cc-- + if cc == 0 { + pe = ps + i + break + } + } + } + if pe == ps { + panic("route param closing delimiter '}' is missing") + } + + key := pattern[ps+1 : pe] + pe++ // set end to next position + + if pe < len(pattern) { + tail = pattern[pe] + } + + var rexpat string + if idx := strings.Index(key, ":"); idx >= 0 { + nt = ntRegexp + rexpat = key[idx+1:] + key = key[:idx] + } + + if len(rexpat) > 0 { + if rexpat[0] != '^' { + rexpat = "^" + rexpat + } + if rexpat[len(rexpat)-1] != '$' { + rexpat += "$" + } + } + + return nt, key, rexpat, tail, ps, pe + } + + // Wildcard pattern as finale + if ws < len(pattern)-1 { + panic("wildcard '*' must be the last value in a route. trim trailing text or use a '{param}' instead") + } + return ntCatchAll, "*", "", 0, ws, len(pattern) +} + +func patParamKeys(pattern string) []string { + pat := pattern + paramKeys := []string{} + for { + ptyp, paramKey, _, _, _, e := patNextSegment(pat) + if ptyp == ntStatic { + return paramKeys + } + for i := 0; i < len(paramKeys); i++ { + if paramKeys[i] == paramKey { + panic(fmt.Sprintf("routing pattern '%s' contains duplicate param key, '%s'", pattern, paramKey)) + } + } + paramKeys = append(paramKeys, paramKey) + pat = pat[e:] + } +} + +// longestPrefix finds the length of the shared prefix +// of two strings +func longestPrefix(k1, k2 string) int { + max := len(k1) + if l := len(k2); l < max { + max = l + } + var i int + for i = 0; i < max; i++ { + if k1[i] != k2[i] { + break + } + } + return i +} + +func methodTypString(method methodTyp) string { + for s, t := range methodMap { + if method == t { + return s + } + } + return "" +} + +type nodes []*node + +// Sort the list of nodes by label +func (ns nodes) Sort() { sort.Sort(ns); ns.tailSort() } +func (ns nodes) Len() int { return len(ns) } +func (ns nodes) Swap(i, j int) { ns[i], ns[j] = ns[j], ns[i] } +func (ns nodes) Less(i, j int) bool { return ns[i].label < ns[j].label } + +// tailSort pushes nodes with '/' as the tail to the end of the list for param nodes. +// The list order determines the traversal order. +func (ns nodes) tailSort() { + for i := len(ns) - 1; i >= 0; i-- { + if ns[i].typ > ntStatic && ns[i].tail == '/' { + ns.Swap(i, len(ns)-1) + return + } + } +} + +func (ns nodes) findEdge(label byte) *node { + num := len(ns) + idx := 0 + i, j := 0, num-1 + for i <= j { + idx = i + (j-i)/2 + if label > ns[idx].label { + i = idx + 1 + } else if label < ns[idx].label { + j = idx - 1 + } else { + i = num // breaks cond + } + } + if ns[idx].label != label { + return nil + } + return ns[idx] +} + +// Route describes the details of a routing handler. +// Handlers map key is an HTTP method +type Route struct { + SubRoutes Routes + Handlers map[string]http.Handler + Pattern string +} + +// WalkFunc is the type of the function called for each method and route visited by Walk. +type WalkFunc func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error + +// Walk walks any router tree that implements Routes interface. +func Walk(r Routes, walkFn WalkFunc) error { + return walk(r, walkFn, "") +} + +func walk(r Routes, walkFn WalkFunc, parentRoute string, parentMw ...func(http.Handler) http.Handler) error { + for _, route := range r.Routes() { + mws := make([]func(http.Handler) http.Handler, len(parentMw)) + copy(mws, parentMw) + mws = append(mws, r.Middlewares()...) + + if route.SubRoutes != nil { + if err := walk(route.SubRoutes, walkFn, parentRoute+route.Pattern, mws...); err != nil { + return err + } + continue + } + + for method, handler := range route.Handlers { + if method == "*" { + // Ignore a "catchAll" method, since we pass down all the specific methods for each route. + continue + } + + fullRoute := parentRoute + route.Pattern + fullRoute = strings.Replace(fullRoute, "/*/", "/", -1) + + if chain, ok := handler.(*chainHandler); ok { + if err := walkFn(method, fullRoute, chain.Endpoint, append(mws, chain.Middlewares...)...); err != nil { + return err + } + } else { + if err := walkFn(method, fullRoute, handler, mws...); err != nil { + return err + } + } + } + } + + return nil +} diff --git a/tree_test.go b/tree_test.go new file mode 100644 index 0000000..c350c9c --- /dev/null +++ b/tree_test.go @@ -0,0 +1,509 @@ +package web + +import ( + "fmt" + "log" + "net/http" + "testing" +) + +func TestTree(t *testing.T) { + hStub := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hIndex := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hFavicon := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hArticleList := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hArticleNear := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hArticleShow := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hArticleShowRelated := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hArticleShowOpts := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hArticleSlug := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hArticleByUser := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hUserList := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hUserShow := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hAdminCatchall := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hAdminAppShow := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hAdminAppShowCatchall := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hUserProfile := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hUserSuper := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hUserAll := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hHubView1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hHubView2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hHubView3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + tr := &node{} + + tr.InsertRoute(mGET, "/", hIndex) + tr.InsertRoute(mGET, "/favicon.ico", hFavicon) + + tr.InsertRoute(mGET, "/pages/*", hStub) + + tr.InsertRoute(mGET, "/article", hArticleList) + tr.InsertRoute(mGET, "/article/", hArticleList) + + tr.InsertRoute(mGET, "/article/near", hArticleNear) + tr.InsertRoute(mGET, "/article/{id}", hStub) + tr.InsertRoute(mGET, "/article/{id}", hArticleShow) + tr.InsertRoute(mGET, "/article/{id}", hArticleShow) // duplicate will have no effect + tr.InsertRoute(mGET, "/article/@{user}", hArticleByUser) + + tr.InsertRoute(mGET, "/article/{sup}/{opts}", hArticleShowOpts) + tr.InsertRoute(mGET, "/article/{id}/{opts}", hArticleShowOpts) // overwrite above route, latest wins + + tr.InsertRoute(mGET, "/article/{iffd}/edit", hStub) + tr.InsertRoute(mGET, "/article/{id}//related", hArticleShowRelated) + tr.InsertRoute(mGET, "/article/slug/{month}/-/{day}/{year}", hArticleSlug) + + tr.InsertRoute(mGET, "/admin/user", hUserList) + tr.InsertRoute(mGET, "/admin/user/", hStub) // will get replaced by next route + tr.InsertRoute(mGET, "/admin/user/", hUserList) + + tr.InsertRoute(mGET, "/admin/user//{id}", hUserShow) + tr.InsertRoute(mGET, "/admin/user/{id}", hUserShow) + + tr.InsertRoute(mGET, "/admin/apps/{id}", hAdminAppShow) + tr.InsertRoute(mGET, "/admin/apps/{id}/*", hAdminAppShowCatchall) + + tr.InsertRoute(mGET, "/admin/*", hStub) // catchall segment will get replaced by next route + tr.InsertRoute(mGET, "/admin/*", hAdminCatchall) + + tr.InsertRoute(mGET, "/users/{userID}/profile", hUserProfile) + tr.InsertRoute(mGET, "/users/super/*", hUserSuper) + tr.InsertRoute(mGET, "/users/*", hUserAll) + + tr.InsertRoute(mGET, "/hubs/{hubID}/view", hHubView1) + tr.InsertRoute(mGET, "/hubs/{hubID}/view/*", hHubView2) + sr := NewRouter() + sr.Get("/users", hHubView3) + tr.InsertRoute(mGET, "/hubs/{hubID}/*", sr) + tr.InsertRoute(mGET, "/hubs/{hubID}/users", hHubView3) + + tests := []struct { + r string // input request path + h http.Handler // output matched handler + k []string // output param keys + v []string // output param values + }{ + {r: "/", h: hIndex, k: []string{}, v: []string{}}, + {r: "/favicon.ico", h: hFavicon, k: []string{}, v: []string{}}, + + {r: "/pages", h: nil, k: []string{}, v: []string{}}, + {r: "/pages/", h: hStub, k: []string{"*"}, v: []string{""}}, + {r: "/pages/yes", h: hStub, k: []string{"*"}, v: []string{"yes"}}, + + {r: "/article", h: hArticleList, k: []string{}, v: []string{}}, + {r: "/article/", h: hArticleList, k: []string{}, v: []string{}}, + {r: "/article/near", h: hArticleNear, k: []string{}, v: []string{}}, + {r: "/article/neard", h: hArticleShow, k: []string{"id"}, v: []string{"neard"}}, + {r: "/article/123", h: hArticleShow, k: []string{"id"}, v: []string{"123"}}, + {r: "/article/123/456", h: hArticleShowOpts, k: []string{"id", "opts"}, v: []string{"123", "456"}}, + {r: "/article/@peter", h: hArticleByUser, k: []string{"user"}, v: []string{"peter"}}, + {r: "/article/22//related", h: hArticleShowRelated, k: []string{"id"}, v: []string{"22"}}, + {r: "/article/111/edit", h: hStub, k: []string{"iffd"}, v: []string{"111"}}, + {r: "/article/slug/sept/-/4/2015", h: hArticleSlug, k: []string{"month", "day", "year"}, v: []string{"sept", "4", "2015"}}, + {r: "/article/:id", h: hArticleShow, k: []string{"id"}, v: []string{":id"}}, + + {r: "/admin/user", h: hUserList, k: []string{}, v: []string{}}, + {r: "/admin/user/", h: hUserList, k: []string{}, v: []string{}}, + {r: "/admin/user/1", h: hUserShow, k: []string{"id"}, v: []string{"1"}}, + {r: "/admin/user//1", h: hUserShow, k: []string{"id"}, v: []string{"1"}}, + {r: "/admin/hi", h: hAdminCatchall, k: []string{"*"}, v: []string{"hi"}}, + {r: "/admin/lots/of/:fun", h: hAdminCatchall, k: []string{"*"}, v: []string{"lots/of/:fun"}}, + {r: "/admin/apps/333", h: hAdminAppShow, k: []string{"id"}, v: []string{"333"}}, + {r: "/admin/apps/333/woot", h: hAdminAppShowCatchall, k: []string{"id", "*"}, v: []string{"333", "woot"}}, + + {r: "/hubs/123/view", h: hHubView1, k: []string{"hubID"}, v: []string{"123"}}, + {r: "/hubs/123/view/index.html", h: hHubView2, k: []string{"hubID", "*"}, v: []string{"123", "index.html"}}, + {r: "/hubs/123/users", h: hHubView3, k: []string{"hubID"}, v: []string{"123"}}, + + {r: "/users/123/profile", h: hUserProfile, k: []string{"userID"}, v: []string{"123"}}, + {r: "/users/super/123/okay/yes", h: hUserSuper, k: []string{"*"}, v: []string{"123/okay/yes"}}, + {r: "/users/123/okay/yes", h: hUserAll, k: []string{"*"}, v: []string{"123/okay/yes"}}, + } + + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + // debugPrintTree(0, 0, tr, 0) + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + + for i, tt := range tests { + rctx := &RouteContext{} + + _, handlers, _ := tr.FindRoute(rctx, mGET, tt.r) + + var handler http.Handler + if methodHandler, ok := handlers[mGET]; ok { + handler = methodHandler.handler + } + + paramKeys := rctx.routeParams.Keys + paramValues := rctx.routeParams.Values + + if fmt.Sprintf("%v", tt.h) != fmt.Sprintf("%v", handler) { + t.Errorf("input [%d]: find '%s' expecting handler:%v , got:%v", i, tt.r, tt.h, handler) + } + if !stringSliceEqual(tt.k, paramKeys) { + t.Errorf("input [%d]: find '%s' expecting paramKeys:(%d)%v , got:(%d)%v", i, tt.r, len(tt.k), tt.k, len(paramKeys), paramKeys) + } + if !stringSliceEqual(tt.v, paramValues) { + t.Errorf("input [%d]: find '%s' expecting paramValues:(%d)%v , got:(%d)%v", i, tt.r, len(tt.v), tt.v, len(paramValues), paramValues) + } + } +} + +func TestTreeMoar(t *testing.T) { + hStub := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub4 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub5 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub6 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub7 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub8 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub9 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub10 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub11 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub12 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub13 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub14 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub15 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub16 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + // TODO: panic if we see {id}{x} because we're missing a delimiter, its not possible. + // also {:id}* is not possible. + + tr := &node{} + + tr.InsertRoute(mGET, "/articlefun", hStub5) + tr.InsertRoute(mGET, "/articles/{id}", hStub) + tr.InsertRoute(mDELETE, "/articles/{slug}", hStub8) + tr.InsertRoute(mGET, "/articles/search", hStub1) + tr.InsertRoute(mGET, "/articles/{id}:delete", hStub8) + tr.InsertRoute(mGET, "/articles/{iidd}!sup", hStub4) + tr.InsertRoute(mGET, "/articles/{id}:{op}", hStub3) + tr.InsertRoute(mGET, "/articles/{id}:{op}", hStub2) // this route sets a new handler for the above route + tr.InsertRoute(mGET, "/articles/{slug:^[a-z]+}/posts", hStub) // up to tail '/' will only match if contents match the rex + tr.InsertRoute(mGET, "/articles/{id}/posts/{pid}", hStub6) // /articles/123/posts/1 + tr.InsertRoute(mGET, "/articles/{id}/posts/{month}/{day}/{year}/{slug}", hStub7) // /articles/123/posts/09/04/1984/juice + tr.InsertRoute(mGET, "/articles/{id}.json", hStub10) + tr.InsertRoute(mGET, "/articles/{id}/data.json", hStub11) + tr.InsertRoute(mGET, "/articles/files/{file}.{ext}", hStub12) + tr.InsertRoute(mPUT, "/articles/me", hStub13) + + // TODO: make a separate test case for this one.. + // tr.InsertRoute(mGET, "/articles/{id}/{id}", hStub1) // panic expected, we're duplicating param keys + + tr.InsertRoute(mGET, "/pages/*", hStub) + tr.InsertRoute(mGET, "/pages/*", hStub9) + + tr.InsertRoute(mGET, "/users/{id}", hStub14) + tr.InsertRoute(mGET, "/users/{id}/settings/{key}", hStub15) + tr.InsertRoute(mGET, "/users/{id}/settings/*", hStub16) + + tests := []struct { + h http.Handler + r string + k []string + v []string + m methodTyp + }{ + {m: mGET, r: "/articles/search", h: hStub1, k: []string{}, v: []string{}}, + {m: mGET, r: "/articlefun", h: hStub5, k: []string{}, v: []string{}}, + {m: mGET, r: "/articles/123", h: hStub, k: []string{"id"}, v: []string{"123"}}, + {m: mDELETE, r: "/articles/123mm", h: hStub8, k: []string{"slug"}, v: []string{"123mm"}}, + {m: mGET, r: "/articles/789:delete", h: hStub8, k: []string{"id"}, v: []string{"789"}}, + {m: mGET, r: "/articles/789!sup", h: hStub4, k: []string{"iidd"}, v: []string{"789"}}, + {m: mGET, r: "/articles/123:sync", h: hStub2, k: []string{"id", "op"}, v: []string{"123", "sync"}}, + {m: mGET, r: "/articles/456/posts/1", h: hStub6, k: []string{"id", "pid"}, v: []string{"456", "1"}}, + {m: mGET, r: "/articles/456/posts/09/04/1984/juice", h: hStub7, k: []string{"id", "month", "day", "year", "slug"}, v: []string{"456", "09", "04", "1984", "juice"}}, + {m: mGET, r: "/articles/456.json", h: hStub10, k: []string{"id"}, v: []string{"456"}}, + {m: mGET, r: "/articles/456/data.json", h: hStub11, k: []string{"id"}, v: []string{"456"}}, + + {m: mGET, r: "/articles/files/file.zip", h: hStub12, k: []string{"file", "ext"}, v: []string{"file", "zip"}}, + {m: mGET, r: "/articles/files/photos.tar.gz", h: hStub12, k: []string{"file", "ext"}, v: []string{"photos", "tar.gz"}}, + {m: mGET, r: "/articles/files/photos.tar.gz", h: hStub12, k: []string{"file", "ext"}, v: []string{"photos", "tar.gz"}}, + + {m: mPUT, r: "/articles/me", h: hStub13, k: []string{}, v: []string{}}, + {m: mGET, r: "/articles/me", h: hStub, k: []string{"id"}, v: []string{"me"}}, + {m: mGET, r: "/pages", h: nil, k: []string{}, v: []string{}}, + {m: mGET, r: "/pages/", h: hStub9, k: []string{"*"}, v: []string{""}}, + {m: mGET, r: "/pages/yes", h: hStub9, k: []string{"*"}, v: []string{"yes"}}, + + {m: mGET, r: "/users/1", h: hStub14, k: []string{"id"}, v: []string{"1"}}, + {m: mGET, r: "/users/", h: nil, k: []string{}, v: []string{}}, + {m: mGET, r: "/users/2/settings/password", h: hStub15, k: []string{"id", "key"}, v: []string{"2", "password"}}, + {m: mGET, r: "/users/2/settings/", h: hStub16, k: []string{"id", "*"}, v: []string{"2", ""}}, + } + + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + // debugPrintTree(0, 0, tr, 0) + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + + for i, tt := range tests { + rctx := &RouteContext{} + + _, handlers, _ := tr.FindRoute(rctx, tt.m, tt.r) + + var handler http.Handler + if methodHandler, ok := handlers[tt.m]; ok { + handler = methodHandler.handler + } + + paramKeys := rctx.routeParams.Keys + paramValues := rctx.routeParams.Values + + if fmt.Sprintf("%v", tt.h) != fmt.Sprintf("%v", handler) { + t.Errorf("input [%d]: find '%s' expecting handler:%v , got:%v", i, tt.r, tt.h, handler) + } + if !stringSliceEqual(tt.k, paramKeys) { + t.Errorf("input [%d]: find '%s' expecting paramKeys:(%d)%v , got:(%d)%v", i, tt.r, len(tt.k), tt.k, len(paramKeys), paramKeys) + } + if !stringSliceEqual(tt.v, paramValues) { + t.Errorf("input [%d]: find '%s' expecting paramValues:(%d)%v , got:(%d)%v", i, tt.r, len(tt.v), tt.v, len(paramValues), paramValues) + } + } +} + +func TestTreeRegexp(t *testing.T) { + hStub1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub4 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub5 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub6 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub7 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + tr := &node{} + tr.InsertRoute(mGET, "/articles/{rid:^[0-9]{5,6}}", hStub7) + tr.InsertRoute(mGET, "/articles/{zid:^0[0-9]+}", hStub3) + tr.InsertRoute(mGET, "/articles/{name:^@[a-z]+}/posts", hStub4) + tr.InsertRoute(mGET, "/articles/{op:^[0-9]+}/run", hStub5) + tr.InsertRoute(mGET, "/articles/{id:^[0-9]+}", hStub1) + tr.InsertRoute(mGET, "/articles/{id:^[1-9]+}-{aux}", hStub6) + tr.InsertRoute(mGET, "/articles/{slug}", hStub2) + + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + // debugPrintTree(0, 0, tr, 0) + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + + tests := []struct { + r string // input request path + h http.Handler // output matched handler + k []string // output param keys + v []string // output param values + }{ + {r: "/articles", h: nil, k: []string{}, v: []string{}}, + {r: "/articles/12345", h: hStub7, k: []string{"rid"}, v: []string{"12345"}}, + {r: "/articles/123", h: hStub1, k: []string{"id"}, v: []string{"123"}}, + {r: "/articles/how-to-build-a-router", h: hStub2, k: []string{"slug"}, v: []string{"how-to-build-a-router"}}, + {r: "/articles/0456", h: hStub3, k: []string{"zid"}, v: []string{"0456"}}, + {r: "/articles/@pk/posts", h: hStub4, k: []string{"name"}, v: []string{"@pk"}}, + {r: "/articles/1/run", h: hStub5, k: []string{"op"}, v: []string{"1"}}, + {r: "/articles/1122", h: hStub1, k: []string{"id"}, v: []string{"1122"}}, + {r: "/articles/1122-yes", h: hStub6, k: []string{"id", "aux"}, v: []string{"1122", "yes"}}, + } + + for i, tt := range tests { + rctx := &RouteContext{} + + _, handlers, _ := tr.FindRoute(rctx, mGET, tt.r) + + var handler http.Handler + if methodHandler, ok := handlers[mGET]; ok { + handler = methodHandler.handler + } + + paramKeys := rctx.routeParams.Keys + paramValues := rctx.routeParams.Values + + if fmt.Sprintf("%v", tt.h) != fmt.Sprintf("%v", handler) { + t.Errorf("input [%d]: find '%s' expecting handler:%v , got:%v", i, tt.r, tt.h, handler) + } + if !stringSliceEqual(tt.k, paramKeys) { + t.Errorf("input [%d]: find '%s' expecting paramKeys:(%d)%v , got:(%d)%v", i, tt.r, len(tt.k), tt.k, len(paramKeys), paramKeys) + } + if !stringSliceEqual(tt.v, paramValues) { + t.Errorf("input [%d]: find '%s' expecting paramValues:(%d)%v , got:(%d)%v", i, tt.r, len(tt.v), tt.v, len(paramValues), paramValues) + } + } +} + +func TestTreeRegexpRecursive(t *testing.T) { + hStub1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + tr := &node{} + tr.InsertRoute(mGET, "/one/{firstId:[a-z0-9-]+}/{secondId:[a-z0-9-]+}/first", hStub1) + tr.InsertRoute(mGET, "/one/{firstId:[a-z0-9-_]+}/{secondId:[a-z0-9-_]+}/second", hStub2) + + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + // debugPrintTree(0, 0, tr, 0) + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + + tests := []struct { + r string // input request path + h http.Handler // output matched handler + k []string // output param keys + v []string // output param values + }{ + {r: "/one/hello/world/first", h: hStub1, k: []string{"firstId", "secondId"}, v: []string{"hello", "world"}}, + {r: "/one/hi_there/ok/second", h: hStub2, k: []string{"firstId", "secondId"}, v: []string{"hi_there", "ok"}}, + {r: "/one///first", h: nil, k: []string{}, v: []string{}}, + {r: "/one/hi/123/second", h: hStub2, k: []string{"firstId", "secondId"}, v: []string{"hi", "123"}}, + } + + for i, tt := range tests { + rctx := &RouteContext{} + + _, handlers, _ := tr.FindRoute(rctx, mGET, tt.r) + + var handler http.Handler + if methodHandler, ok := handlers[mGET]; ok { + handler = methodHandler.handler + } + + paramKeys := rctx.routeParams.Keys + paramValues := rctx.routeParams.Values + + if fmt.Sprintf("%v", tt.h) != fmt.Sprintf("%v", handler) { + t.Errorf("input [%d]: find '%s' expecting handler:%v , got:%v", i, tt.r, tt.h, handler) + } + if !stringSliceEqual(tt.k, paramKeys) { + t.Errorf("input [%d]: find '%s' expecting paramKeys:(%d)%v , got:(%d)%v", i, tt.r, len(tt.k), tt.k, len(paramKeys), paramKeys) + } + if !stringSliceEqual(tt.v, paramValues) { + t.Errorf("input [%d]: find '%s' expecting paramValues:(%d)%v , got:(%d)%v", i, tt.r, len(tt.v), tt.v, len(paramValues), paramValues) + } + } +} + +func TestTreeRegexMatchWholeParam(t *testing.T) { + hStub1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + rctx := &RouteContext{} + tr := &node{} + tr.InsertRoute(mGET, "/{id:[0-9]+}", hStub1) + tr.InsertRoute(mGET, "/{x:.+}/foo", hStub1) + tr.InsertRoute(mGET, "/{param:[0-9]*}/test", hStub1) + + tests := []struct { + expectedHandler http.Handler + url string + }{ + {url: "/13", expectedHandler: hStub1}, + {url: "/a13", expectedHandler: nil}, + {url: "/13.jpg", expectedHandler: nil}, + {url: "/a13.jpg", expectedHandler: nil}, + {url: "/a/foo", expectedHandler: hStub1}, + {url: "//foo", expectedHandler: nil}, + {url: "//test", expectedHandler: hStub1}, + } + + for _, tc := range tests { + _, _, handler := tr.FindRoute(rctx, mGET, tc.url) + if fmt.Sprintf("%v", tc.expectedHandler) != fmt.Sprintf("%v", handler) { + t.Errorf("url %v: expecting handler:%v , got:%v", tc.url, tc.expectedHandler, handler) + } + } +} + +func TestTreeFindPattern(t *testing.T) { + hStub1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + hStub3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + tr := &node{} + tr.InsertRoute(mGET, "/pages/*", hStub1) + tr.InsertRoute(mGET, "/articles/{id}/*", hStub2) + tr.InsertRoute(mGET, "/articles/{slug}/{uid}/*", hStub3) + + if tr.findPattern("/pages") != false { + t.Errorf("find /pages failed") + } + if tr.findPattern("/pages*") != false { + t.Errorf("find /pages* failed - should be nil") + } + if tr.findPattern("/pages/*") == false { + t.Errorf("find /pages/* failed") + } + if tr.findPattern("/articles/{id}/*") == false { + t.Errorf("find /articles/{id}/* failed") + } + if tr.findPattern("/articles/{something}/*") == false { + t.Errorf("find /articles/{something}/* failed") + } + if tr.findPattern("/articles/{slug}/{uid}/*") == false { + t.Errorf("find /articles/{slug}/{uid}/* failed") + } +} + +func debugPrintTree(parent int, i int, n *node, label byte) bool { + numEdges := 0 + for _, nds := range n.children { + numEdges += len(nds) + } + + // if n.handlers != nil { + // log.Printf("[node %d parent:%d] typ:%d prefix:%s label:%s tail:%s numEdges:%d isLeaf:%v handler:%v pat:%s keys:%v\n", i, parent, n.typ, n.prefix, string(label), string(n.tail), numEdges, n.isLeaf(), n.handlers, n.pattern, n.paramKeys) + // } else { + // log.Printf("[node %d parent:%d] typ:%d prefix:%s label:%s tail:%s numEdges:%d isLeaf:%v pat:%s keys:%v\n", i, parent, n.typ, n.prefix, string(label), string(n.tail), numEdges, n.isLeaf(), n.pattern, n.paramKeys) + // } + if n.endpoints != nil { + log.Printf("[node %d parent:%d] typ:%d prefix:%s label:%s tail:%s numEdges:%d isLeaf:%v handler:%v\n", i, parent, n.typ, n.prefix, string(label), string(n.tail), numEdges, n.isLeaf(), n.endpoints) + } else { + log.Printf("[node %d parent:%d] typ:%d prefix:%s label:%s tail:%s numEdges:%d isLeaf:%v\n", i, parent, n.typ, n.prefix, string(label), string(n.tail), numEdges, n.isLeaf()) + } + parent = i + for _, nds := range n.children { + for _, e := range nds { + i++ + if debugPrintTree(parent, i, e, e.label) { + return true + } + } + } + return false +} + +func stringSliceEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if b[i] != a[i] { + return false + } + } + return true +} + +func BenchmarkTreeGet(b *testing.B) { + h1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + tr := &node{} + tr.InsertRoute(mGET, "/", h1) + tr.InsertRoute(mGET, "/ping", h2) + tr.InsertRoute(mGET, "/pingall", h2) + tr.InsertRoute(mGET, "/ping/{id}", h2) + tr.InsertRoute(mGET, "/ping/{id}/woop", h2) + tr.InsertRoute(mGET, "/ping/{id}/{opt}", h2) + tr.InsertRoute(mGET, "/pinggggg", h2) + tr.InsertRoute(mGET, "/hello", h1) + + mctx := &RouteContext{} + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + mctx.Reset() + tr.FindRoute(mctx, mGET, "/ping/123/456") + } +}