From bea6e56d10053885b22131508bcb7c9e70a64163 Mon Sep 17 00:00:00 2001 From: Adam Hamrick Date: Tue, 28 Jan 2025 20:16:08 -0500 Subject: [PATCH] Tests and examples --- parrot/.changeset/v0.2.0.md | 2 +- parrot/README.md | 6 +- parrot/cage.go | 4 +- parrot/cage_test.go | 2 +- parrot/examples_test.go | 101 +++++++++++++++++++++ parrot/parrot.go | 59 ++++++++---- parrot/parrot_benchmark_test.go | 8 ++ parrot/parrot_test.go | 156 ++++++++++++++++++++++++++++++-- 8 files changed, 304 insertions(+), 34 deletions(-) diff --git a/parrot/.changeset/v0.2.0.md b/parrot/.changeset/v0.2.0.md index db9b8da8a..0fbdfa7c0 100644 --- a/parrot/.changeset/v0.2.0.md +++ b/parrot/.changeset/v0.2.0.md @@ -27,4 +27,4 @@ BenchmarkRegisterRoute-14 3647503 313.8 ns/op BenchmarkRouteResponse-14 19143 62011 ns/op BenchmarkSave-14 5244 218697 ns/op BenchmarkLoad-14 1101 1049399 ns/op -``` \ No newline at end of file +``` diff --git a/parrot/README.md b/parrot/README.md index cb9fc8011..3e5cef3fb 100644 --- a/parrot/README.md +++ b/parrot/README.md @@ -4,9 +4,9 @@ A simple, high-performing mockserver that can dynamically build new routes with ## Features -* Simplistic and fast design -* Run within your Go code, through a small binary, or in a minimal Docker container -* Easily record all incoming requests to the server to programmatically react to +* Run as an imported package, through a small binary, or in a minimal Docker container +* Record all incoming requests to the server and programmatically react +* Match wildcard routes and methods ## Use diff --git a/parrot/cage.go b/parrot/cage.go index 0a5b200a9..c6f003ba4 100644 --- a/parrot/cage.go +++ b/parrot/cage.go @@ -197,9 +197,7 @@ func (cl *cageLevel) route(routeSegment, routeMethod string) (route *Route, foun if route, found = cl.routes[routeSegment][routeMethod]; found { return route, true, nil } - } - if _, ok := cl.wildCardRoutes[routeSegment]; ok { - if route, found = cl.wildCardRoutes[routeSegment][MethodAny]; found { + if route, found = cl.routes[routeSegment][MethodAny]; found { // Fallthrough to any method if it's designed return route, true, nil } } diff --git a/parrot/cage_test.go b/parrot/cage_test.go index 8e83f344b..165fdcd9e 100644 --- a/parrot/cage_test.go +++ b/parrot/cage_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestCageNewRoutes(t *testing.T) { +func TestCage(t *testing.T) { t.Parallel() testCases := []struct { diff --git a/parrot/examples_test.go b/parrot/examples_test.go index 65e096dbc..63b8eb5dd 100644 --- a/parrot/examples_test.go +++ b/parrot/examples_test.go @@ -70,6 +70,107 @@ func ExampleServer_Register_internal() { // 0 } +func ExampleServer_Register_wildcards() { + // Create a new parrot instance with no logging and a custom save file + saveFile := "register_example.json" + p, err := parrot.Wake(parrot.WithLogLevel(zerolog.NoLevel), parrot.WithSaveFile(saveFile)) + if err != nil { + panic(err) + } + defer func() { // Cleanup the parrot instance + err = p.Shutdown(context.Background()) // Gracefully shutdown the parrot instance + if err != nil { + panic(err) + } + p.WaitShutdown() // Wait for the parrot instance to shutdown. Usually unnecessary, but we want to clean up the save file + os.Remove(saveFile) // Cleanup the save file for the example + }() + + // You can use the MethodAny constant to match any HTTP method + anyMethodRoute := &parrot.Route{ + Method: parrot.MethodAny, + Path: "/any-method", + RawResponseBody: "Any Method", + ResponseStatusCode: http.StatusOK, + } + + err = p.Register(anyMethodRoute) + if err != nil { + panic(err) + } + resp, err := p.Call(http.MethodGet, "/any-method") + if err != nil { + panic(err) + } + fmt.Println(resp.Request.Method, string(resp.Body())) + resp, err = p.Call(http.MethodPost, "/any-method") + if err != nil { + panic(err) + } + fmt.Println(resp.Request.Method, string(resp.Body())) + + // A * in the path will match any characters in the path + basicWildCard := &parrot.Route{ + Method: parrot.MethodAny, + Path: "/wildcard/*", + RawResponseBody: "Basic Wildcard", + ResponseStatusCode: http.StatusOK, + } + + err = p.Register(basicWildCard) + if err != nil { + panic(err) + } + resp, err = p.Call(http.MethodGet, "/wildcard/anything") + if err != nil { + panic(err) + } + fmt.Println(resp.Request.RawRequest.URL.Path, string(resp.Body())) + + // Wild cards can be nested + nestedWildCardRoute := &parrot.Route{ + Method: parrot.MethodAny, + Path: "/wildcard/*/nested/*", + RawResponseBody: "Nested Wildcard", + ResponseStatusCode: http.StatusOK, + } + + err = p.Register(nestedWildCardRoute) + if err != nil { + panic(err) + } + resp, err = p.Call(http.MethodGet, "/wildcard/anything/nested/else") + if err != nil { + panic(err) + } + fmt.Println(resp.Request.RawRequest.URL.Path, string(resp.Body())) + + // Wild cards can also be partials + partialWildCardRoute := &parrot.Route{ + Method: parrot.MethodAny, + Path: "/partial*/wildcard", + RawResponseBody: "Partial Wildcard", + ResponseStatusCode: http.StatusOK, + } + + err = p.Register(partialWildCardRoute) + if err != nil { + panic(err) + } + resp, err = p.Call(http.MethodGet, "/partial_anything/wildcard") + if err != nil { + panic(err) + } + fmt.Println(resp.Request.RawRequest.URL.Path, string(resp.Body())) + + // Output: + // GET Any Method + // POST Any Method + // /wildcard/anything Basic Wildcard + // /wildcard/anything/nested/else Nested Wildcard + // /partial_anything/wildcard Partial Wildcard +} + func ExampleServer_Register_external() { var ( saveFile = "route_example.json" diff --git a/parrot/parrot.go b/parrot/parrot.go index 1d702d877..61bf0f9d6 100644 --- a/parrot/parrot.go +++ b/parrot/parrot.go @@ -11,7 +11,6 @@ import ( "net/url" "os" "path/filepath" - "regexp" "strconv" "strings" "sync" @@ -322,8 +321,8 @@ func (p *Server) Register(route *Route) error { if route == nil { return ErrNilRoute } - if !isValidPath(route.Path) { - return newDynamicError(ErrInvalidPath, fmt.Sprintf("'%s'", route.Path)) + if err := checkPath(route.Path); err != nil { + return newDynamicError(ErrInvalidPath, err.Error()) } if route.Method == "" { return ErrNoMethod @@ -400,6 +399,10 @@ func (p *Server) Delete(route *Route) error { } // Call makes a request to the parrot server +// The method is the HTTP method to use (GET, POST, PUT, DELETE, etc.) +// The path is the URL path to call +// The response is returned as a resty.Response +// Errors are returned if the server is shut down or if the request fails, not if the response is an error func (p *Server) Call(method, path string) (*resty.Response, error) { if p.shutDown { return nil, ErrServerShutdown @@ -491,7 +494,7 @@ func (p *Server) dynamicHandler(w http.ResponseWriter, r *http.Request) { route, err := p.cage.getRoute(r.URL.Path, r.Method) if err != nil { - if errors.Is(err, ErrRouteNotFound) { + if errors.Is(err, ErrRouteNotFound) || errors.Is(err, ErrCageNotFound) { http.Error(w, "Route not found", http.StatusNotFound) dynamicLogger.Debug().Msg("Route not found") return @@ -754,36 +757,58 @@ func (p *Server) loggingMiddleware(next http.Handler) http.Handler { return h(accessHandler(next)) } -var validPathRegex = regexp.MustCompile(`^\/[a-zA-Z0-9\-._~%!$&'()+,;=:@\/]`) - -func isValidPath(path string) bool { +func checkPath(path string) error { switch path { case "", "/", "//", healthRoute, recordRoute, routesRoute, "/..": - return false + return fmt.Errorf("cannot match special paths: '%s'", path) } if strings.Contains(path, "/..") { - return false + return fmt.Errorf("cannot match parent directory traversal: '%s'", path) } if strings.Contains(path, "/.") { - return false + return fmt.Errorf("cannot match hidden files: '%s'", path) } if strings.Contains(path, "//") { - return false + return fmt.Errorf("cannot match double slashes: '%s'", path) } if !strings.HasPrefix(path, "/") { - return false + return fmt.Errorf("path must start with a forward slash: '%s'", path) } if strings.HasSuffix(path, "/") { - return false + return fmt.Errorf("path cannot end with a forward slash: '%s'", path) } if strings.HasPrefix(path, recordRoute) { - return false + return fmt.Errorf("cannot match record route: '%s'", path) } if strings.HasPrefix(path, healthRoute) { - return false + return fmt.Errorf("cannot match health route: '%s'", path) } if strings.HasPrefix(path, routesRoute) { - return false + return fmt.Errorf("cannot match routes route: '%s'", path) + } + match, err := filepath.Match(path, healthRoute) + if err != nil { + return fmt.Errorf("failed to match: '%s'", path) + } + if match { + return fmt.Errorf("cannot match health route: '%s'", path) + } + + match, err = filepath.Match(path, recordRoute) + if err != nil { + return fmt.Errorf("failed to match: '%s'", path) } - return validPathRegex.MatchString(path) + if match { + return fmt.Errorf("cannot match record route: '%s'", path) + } + + match, err = filepath.Match(path, routesRoute) + if err != nil { + return fmt.Errorf("failed to match: '%s'", path) + } + if match { + return fmt.Errorf("cannot match routes route: '%s'", path) + } + + return nil } diff --git a/parrot/parrot_benchmark_test.go b/parrot/parrot_benchmark_test.go index 68d941b13..792355c84 100644 --- a/parrot/parrot_benchmark_test.go +++ b/parrot/parrot_benchmark_test.go @@ -41,6 +41,10 @@ func BenchmarkRegisterRoute(b *testing.B) { b.StopTimer() } +func BenchmarkRegisterWildCardRoute(b *testing.B) { + // TODO: Implement +} + func BenchmarkRouteResponse(b *testing.B) { saveFile := b.Name() + ".json" p, err := Wake(WithLogLevel(testLogLevel), WithSaveFile(saveFile)) @@ -72,6 +76,10 @@ func BenchmarkRouteResponse(b *testing.B) { b.StopTimer() } +func BenchmarkWildCardRouteResponse(b *testing.B) { + +} + func BenchmarkSave(b *testing.B) { var ( routes = []*Route{} diff --git a/parrot/parrot_test.go b/parrot/parrot_test.go index a53228ce3..725d05bc2 100644 --- a/parrot/parrot_test.go +++ b/parrot/parrot_test.go @@ -109,8 +109,6 @@ func TestRegisterRoutes(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - t.Parallel() - err := p.Register(tc.route) require.NoError(t, err, "error registering route") @@ -129,6 +127,143 @@ func TestRegisterRoutes(t *testing.T) { } } +func TestWildCardRoute(t *testing.T) { + t.Parallel() + + p := newParrot(t) + + simpleWildCardRoute := &Route{ + Method: http.MethodGet, + Path: "/wildcard/*", + RawResponseBody: "First Wildcard", + ResponseStatusCode: http.StatusOK, + } + + nestedWildCardRoute := &Route{ + Method: http.MethodGet, + Path: "/wildcard/*/nested/*", + RawResponseBody: "Nested Wildcard", + ResponseStatusCode: http.StatusOK, + } + + // Try to register a route that will confuse the wildcard route + confusingWildCardRoute := &Route{ + Method: http.MethodGet, + Path: "/wildcard/*/*", + RawResponseBody: "Confusing Wildcard", + ResponseStatusCode: http.StatusOK, + } + + baseWildCardRoute := &Route{ + Method: http.MethodGet, + Path: "/*/base/*", + RawResponseBody: "Base Wildcard", + ResponseStatusCode: http.StatusOK, + } + + partialWildCardRoute := &Route{ + Method: http.MethodGet, + Path: "/partial*/after", + RawResponseBody: "Partial Wildcard", + ResponseStatusCode: http.StatusOK, + } + + err := p.Register(simpleWildCardRoute) + require.NoError(t, err, "error registering route") + err = p.Register(nestedWildCardRoute) + require.NoError(t, err, "error registering route") + err = p.Register(confusingWildCardRoute) + require.NoError(t, err, "error registering route") + err = p.Register(baseWildCardRoute) + require.NoError(t, err, "error registering route") + err = p.Register(partialWildCardRoute) + require.NoError(t, err, "error registering route") + + testCases := []struct { + callingPath string + matchRoute *Route + expectErrStatusCode bool + }{ + { + callingPath: "/wildcard/anything", + matchRoute: simpleWildCardRoute, + expectErrStatusCode: false, + }, + { + callingPath: "/wildcard/anything/nested/thing", + matchRoute: nestedWildCardRoute, + expectErrStatusCode: false, + }, + { + callingPath: "/wildcard/anything/anything", + matchRoute: confusingWildCardRoute, + expectErrStatusCode: false, + }, + { + callingPath: "/route", + matchRoute: nil, + expectErrStatusCode: true, + }, + { + callingPath: "/wildcard/anything/nested/thing/extra", + matchRoute: nil, + expectErrStatusCode: true, + }, + { + callingPath: "/base/base/anything", + matchRoute: baseWildCardRoute, + expectErrStatusCode: false, + }, + { + callingPath: "/partialanything/after", + matchRoute: partialWildCardRoute, + expectErrStatusCode: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.callingPath[1:], func(t *testing.T) { + resp, err := p.Call(http.MethodGet, tc.callingPath) + require.NoError(t, err, "error calling parrot") + + if tc.expectErrStatusCode { + assert.Equal(t, + http.StatusNotFound, resp.StatusCode(), + fmt.Sprintf("expected route not found, got body ''%s''", string(resp.Body())), + ) + } else { + assert.Equal(t, tc.matchRoute.ResponseStatusCode, resp.StatusCode(), "status code mismatch") + assert.Equal(t, tc.matchRoute.RawResponseBody, string(resp.Body()), "response body mismatch") + } + }) + } +} + +func TestAnyMethodRoute(t *testing.T) { + t.Parallel() + + p := newParrot(t) + + route := &Route{ + Method: MethodAny, + Path: "/any", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + } + + err := p.Register(route) + require.NoError(t, err, "error registering route") + + methods := []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch} + + for _, method := range methods { + resp, err := p.Call(method, route.Path) + require.NoError(t, err, "error calling parrot") + assert.Equal(t, route.ResponseStatusCode, resp.StatusCode(), "status code mismatch") + assert.Equal(t, route.RawResponseBody, string(resp.Body()), "response body mismatch") + } +} + func TestGetRoutes(t *testing.T) { t.Parallel() @@ -168,27 +303,30 @@ func TestIsValidPath(t *testing.T) { }{ { name: "valid paths", - paths: []string{"/hello", "/hello/there", "/wildcard/*", "/wildcard/*/nested", "/wildcard/*/nested/*"}, + paths: []string{"/hello", "/hello/there", "/wildcard/*", "/wildcard/*/nested", "/wildcard/*/nested/*", "/*/nested/*"}, valid: true, }, { name: "no protected paths", - paths: []string{healthRoute, routesRoute, recordRoute, fmt.Sprintf("%s/%s", routesRoute, "route-id"), fmt.Sprintf("%s/%s", healthRoute, "recorder-id"), fmt.Sprintf("%s/%s", recordRoute, "recorder-id")}, + paths: []string{healthRoute, routesRoute, recordRoute, fmt.Sprintf("%s/%s", routesRoute, "route-id"), fmt.Sprintf("%s/%s", healthRoute, "recorder-id"), fmt.Sprintf("%s/%s", recordRoute, "recorder-id"), "/*"}, valid: false, }, { name: "invalid paths", - paths: []string{"", "/", " ", " /", "/ ", " / ", "/invalid//", "/invalid/../x", "/invalid/", "invalid", "invalid path"}, + paths: []string{"", "/", " ", " /", " / ", "/invalid//", "/invalid/../x", "/invalid/", "invalid", "invalid path"}, + valid: false, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - t.Parallel() - for _, path := range tc.paths { - valid := isValidPath(path) - assert.Equal(t, tc.valid, valid) + pathErr := checkPath(path) + if tc.valid { + assert.NoError(t, pathErr, "expected path to be valid") + } else { + assert.Error(t, pathErr, fmt.Sprintf("expected path '%s' to be invalid", path)) + } } }) }