Skip to content

Commit

Permalink
Improve performance of radix.FindCaseInsensitivePath
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergio Andres Virviescas Santana committed Mar 29, 2020
1 parent 510879a commit 6a00bf3
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 69 deletions.
74 changes: 41 additions & 33 deletions radix/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"sort"
"strings"

"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
)

Expand Down Expand Up @@ -322,49 +323,53 @@ walk:
}
}

func (n *node) find(path string, buf []byte) ([]byte, bool) {
func (n *node) find(path string, buf *bytebufferpool.ByteBuffer) (bool, bool) {
if len(path) > len(n.path) {
if strings.EqualFold(path[:len(n.path)], n.path) {

path = path[len(n.path):]
buf = append(buf, n.path...)
buf.WriteString(n.path)

if len(path) == 1 {
if path == "/" && n.handler != nil {
if n.tsr {
buf = append(buf, '/')
buf.WriteByte('/')

return buf, false
return true, false
}

return buf, true
return true, true
}
}

return n.findChild(path, buf)
found, tsr := n.findChild(path, buf)
if found {
return found, tsr
}

bufferRemoveString(buf, n.path)
}
} else if strings.EqualFold(path, n.path) && n.handler != nil {
buf = append(buf, n.path...)
buf.WriteString(n.path)

if n.tsr {
buf = append(buf, '/')

return buf, true
buf.WriteByte('/')
return true, true
}

return buf, false
return true, false
}

return nil, false
return false, false
}

func (n *node) findChild(path string, buf []byte) ([]byte, bool) {
func (n *node) findChild(path string, buf *bytebufferpool.ByteBuffer) (bool, bool) {
for _, child := range n.children {
switch child.nType {
case static:
buf2, tsr := child.find(path, buf)
if buf2 != nil || tsr {
return buf2, tsr
found, tsr := child.find(path, buf)
if found {
return found, tsr
}

case param:
Expand All @@ -377,49 +382,52 @@ func (n *node) findChild(path string, buf []byte) ([]byte, bool) {
}
}

buf.WriteString(path[:end])

if child.handler != nil {
if end == len(path) {
buf = append(buf, path...)

if child.tsr {
buf = append(buf, '/')
buf.WriteByte('/')

return buf, true
return true, true
}

return buf, false
} else if path[end:] == "/" {
buf = append(buf, path[:end]...)
return true, false

} else if path[end:] == "/" {
if child.tsr {
buf = append(buf, '/')
buf.WriteByte('/')

return buf, false
return true, false
}

return buf, true
return true, true
}
} else if len(path[end:]) == 0 {
return nil, false
bufferRemoveString(buf, path[:end])

return false, false
}

buf2, tsr := child.findChild(path[end:], append(buf, path[:end]...))
if buf2 != nil || tsr {
return buf2, tsr
found, tsr := child.findChild(path[end:], buf)
if found {
return found, tsr
}

bufferRemoveString(buf, path[:end])

default:
panic("invalid node type")
}
}

if n.wildcard != nil {
buf = append(buf, path...)
buf.WriteString(path)

return buf, false
return true, false
}

return nil, false
return false, false
}

// clone clones the current node in a new pointer
Expand Down
40 changes: 27 additions & 13 deletions radix/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"strings"
"testing"

"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
)

Expand Down Expand Up @@ -469,24 +470,30 @@ func TestTreeFindCaseInsensitivePath(t *testing.T) {
}
}

buf := bytebufferpool.Get()

// Check out == in for all registered routes
// With fixTrailingSlash = true
for _, route := range routes {
out, found := tree.FindCaseInsensitivePath(route, true)
found := tree.FindCaseInsensitivePath(route, true, buf)
if !found {
t.Errorf("Route '%s' not found!", route)
} else if string(out) != route {
t.Errorf("Wrong result for route '%s': %s", route, string(out))
} else if out := buf.String(); out != route {
t.Errorf("Wrong result for route '%s': %s", route, out)
}

buf.Reset()
}
// With fixTrailingSlash = false
for _, route := range routes {
out, found := tree.FindCaseInsensitivePath(route, false)
found := tree.FindCaseInsensitivePath(route, false, buf)
if !found {
t.Errorf("Route '%s' not found!", route)
} else if string(out) != route {
t.Errorf("Wrong result for route '%s': %s", route, string(out))
} else if out := buf.String(); out != route {
t.Errorf("Wrong result for route '%s': %s", route, out)
}

buf.Reset()
}

tests := []struct {
Expand Down Expand Up @@ -514,6 +521,7 @@ func TestTreeFindCaseInsensitivePath(t *testing.T) {
{"/RegEx/a1b2_test/DaTA", "/regex/a1b2_test/data", true, false},
{"/RegEx/A1B2_test/DaTA/", "/regex/A1B2_test/data", true, true},
{"/RegEx/blabla/DaTA/", "", false, false},
{"/RegEx/blabla_test/fail", "", false, false},
{"/x/Y", "/x/y", true, false},
{"/x/Y/", "/x/y", true, true},
{"/X/y", "/x/y", true, false},
Expand Down Expand Up @@ -559,25 +567,29 @@ func TestTreeFindCaseInsensitivePath(t *testing.T) {
}
// With fixTrailingSlash = true
for _, test := range tests {
out, found := tree.FindCaseInsensitivePath(test.in, true)
if found != test.found || (found && (string(out) != test.out)) {
found := tree.FindCaseInsensitivePath(test.in, true, buf)
if out := buf.String(); found != test.found || (found && (out != test.out)) {
t.Errorf("Wrong result for '%s': got %s, %t; want %s, %t",
test.in, string(out), found, test.out, test.found)
}

buf.Reset()
}
// With fixTrailingSlash = false
for _, test := range tests {
out, found := tree.FindCaseInsensitivePath(test.in, false)
found := tree.FindCaseInsensitivePath(test.in, false, buf)
if test.slash {
if found { // test needs a trailingSlash fix. It must not be found!
t.Errorf("Found without fixTrailingSlash: %s; got %s", test.in, string(out))
t.Errorf("Found without fixTrailingSlash: %s; got %s", test.in, buf.String())
}
} else {
if found != test.found || (found && (string(out) != test.out)) {
if out := buf.String(); found != test.found || (found && (out != test.out)) {
t.Errorf("Wrong result for '%s': got %s, %t; want %s, %t",
test.in, string(out), found, test.out, test.found)
test.in, out, found, test.out, test.found)
}
}

buf.Reset()
}
}

Expand All @@ -599,9 +611,11 @@ func TestTreeInvalidNodeType(t *testing.T) {
t.Fatalf("Expected panic '"+panicMsg+"', got '%v'", recv)
}

buf := bytebufferpool.Get()

// case-insensitive lookup
recv = catchPanic(func() {
tree.FindCaseInsensitivePath("/test", true)
tree.FindCaseInsensitivePath("/test", true, buf)
})
if rs, ok := recv.(string); !ok || rs != panicMsg {
t.Fatalf("Expected panic '"+panicMsg+"', got '%v'", recv)
Expand Down
23 changes: 7 additions & 16 deletions radix/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package radix
import (
"strings"

"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
)

Expand Down Expand Up @@ -111,24 +112,14 @@ func (t *Tree) Get(path string, ctx *fasthttp.RequestCtx) (fasthttp.RequestHandl
// It can optionally also fix trailing slashes.
// It returns the case-corrected path and a bool indicating whether the lookup
// was successful.
func (t *Tree) FindCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) {
// Use a static sized buffer on the stack in the common case.
// If the path is too long, allocate a buffer on the heap instead.
buf := make([]byte, 0, stackBufSize)
if l := len(path) + 1; l > stackBufSize {
buf = make([]byte, 0, l)
}

tsr := false
func (t *Tree) FindCaseInsensitivePath(path string, fixTrailingSlash bool, buf *bytebufferpool.ByteBuffer) bool {
found, tsr := t.root.find(path, buf)

buf, tsr = t.root.find(path, buf)
if !found || (tsr && !fixTrailingSlash) {
buf.Reset()

switch {
case buf == nil:
return nil, false
case tsr && !fixTrailingSlash:
return nil, false
return false
}

return buf, true
return true
}
5 changes: 4 additions & 1 deletion radix/tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/savsgio/gotils"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
)

Expand Down Expand Up @@ -283,9 +284,11 @@ func Benchmark_FindCaseInsensitivePath(b *testing.B) {
tree := New()
tree.Add("/endpoint", generateHandler())

buf := bytebufferpool.Get()

b.ResetTimer()

for i := 0; i < b.N; i++ {
tree.FindCaseInsensitivePath("/ENdpOiNT", false)
tree.FindCaseInsensitivePath("/ENdpOiNT", false, buf)
}
}
6 changes: 6 additions & 0 deletions radix/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"regexp"
"strings"
"unicode/utf8"

"github.com/valyala/bytebufferpool"
)

func min(a, b int) int {
Expand All @@ -13,6 +15,10 @@ func min(a, b int) int {
return b
}

func bufferRemoveString(buf *bytebufferpool.ByteBuffer, s string) {
buf.B = buf.B[:len(buf.B)-len(s)]
}

// func isIndexEqual(a, b string) bool {
// ra, _ := utf8.DecodeRuneInString(a)
// rb, _ := utf8.DecodeRuneInString(b)
Expand Down
18 changes: 12 additions & 6 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ import (

var (
defaultContentType = []byte("text/plain; charset=utf-8")
questionMark = []byte("?")
questionMark = byte('?')
)

// MatchedRoutePathParam is the param name under which the path of the matched
Expand Down Expand Up @@ -462,7 +462,7 @@ func (r *Router) Handler(ctx *fasthttp.RequestCtx) {

queryBuf := ctx.URI().QueryString()
if len(queryBuf) > 0 {
uri.Write(questionMark)
uri.WriteByte(questionMark)
uri.Write(queryBuf)
}

Expand All @@ -474,18 +474,24 @@ func (r *Router) Handler(ctx *fasthttp.RequestCtx) {

// Try to fix the request path
if r.RedirectFixedPath {
fixedPath, found := root.FindCaseInsensitivePath(
uri := bytebufferpool.Get()
found := root.FindCaseInsensitivePath(
CleanPath(path),
r.RedirectTrailingSlash,
uri,
)

if found {
queryBuf := ctx.URI().QueryString()
if len(queryBuf) > 0 {
fixedPath = append(fixedPath, questionMark...)
fixedPath = append(fixedPath, queryBuf...)
uri.WriteByte(questionMark)
uri.Write(queryBuf)
}

ctx.RedirectBytes(fixedPath, code)
ctx.RedirectBytes(uri.Bytes(), code)

bytebufferpool.Put(uri)

return
}
}
Expand Down

0 comments on commit 6a00bf3

Please sign in to comment.