From 336f12b15db96968847b171f9f650a61e9a55472 Mon Sep 17 00:00:00 2001 From: Nikolay Tkachenko Date: Tue, 9 Apr 2024 15:34:44 +0300 Subject: [PATCH 01/12] Add suffix support. Add tests --- .../internal/updater/wallarm_api2_update.db | Bin 98304 -> 98304 bytes internal/platform/validator/internal.go | 23 ++++- internal/platform/validator/internal_test.go | 90 ++++++++++++++++++ .../platform/validator/req_resp_decoder.go | 57 +++++++++-- .../validator/req_resp_decoder_test.go | 12 +++ 5 files changed, 169 insertions(+), 13 deletions(-) create mode 100644 internal/platform/validator/internal_test.go diff --git a/cmd/api-firewall/internal/updater/wallarm_api2_update.db b/cmd/api-firewall/internal/updater/wallarm_api2_update.db index 31a9d14c27dc418efd93f4f03cb6b29fb97cce9a..b6705d62f8a9f6f9e47199ba4f15f0057b4a0b91 100644 GIT binary patch delta 34 qcmZo@U~6b#n;^}2cA|_kCun|I5)2rC8n|I5)2r*7+GFWE77ytmt*9!~) diff --git a/internal/platform/validator/internal.go b/internal/platform/validator/internal.go index bbe7dd1..588b5c1 100644 --- a/internal/platform/validator/internal.go +++ b/internal/platform/validator/internal.go @@ -7,12 +7,25 @@ import ( "github.com/valyala/fastjson" ) -func parseMediaType(contentType string) string { - i := strings.IndexByte(contentType, ';') - if i < 0 { - return contentType +// parseMediaType func parses content type and returns media type and suffix +func parseMediaType(contentType string) (string, string) { + + var mtSubtype, suffix string + mediaType := contentType + + if i := strings.IndexByte(mediaType, ';'); i >= 0 { + mediaType = strings.TrimSpace(mediaType[:i]) + } + + if i := strings.IndexByte(mediaType, '/'); i >= 0 { + mtSubtype = mediaType[i+1:] } - return contentType[:i] + + if i := strings.LastIndexByte(mtSubtype, '+'); i >= 0 { + suffix = mtSubtype[i:] + } + + return mediaType, suffix } func isNilValue(value any) bool { diff --git a/internal/platform/validator/internal_test.go b/internal/platform/validator/internal_test.go new file mode 100644 index 0000000..76d626e --- /dev/null +++ b/internal/platform/validator/internal_test.go @@ -0,0 +1,90 @@ +package validator + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_parseMediaType(t *testing.T) { + type response struct { + mediaType string + suffix string + } + tests := []struct { + name string + contentType string + response response + }{ + { + name: "json", + contentType: "application/json", + response: response{ + mediaType: "application/json", + suffix: "", + }, + }, + { + name: "json with charset", + contentType: "application/json; charset=utf-8", + response: response{ + mediaType: "application/json", + suffix: "", + }, + }, + { + name: "json with suffix", + contentType: "application/vnd.mycompany.myapp.v2+json", + response: response{ + mediaType: "application/vnd.mycompany.myapp.v2+json", + suffix: "+json", + }, + }, + { + name: "xml", + contentType: "application/xml", + response: response{ + mediaType: "application/xml", + suffix: "", + }, + }, + { + name: "xml with charset", + contentType: "application/xml; charset=utf-8", + response: response{ + mediaType: "application/xml", + suffix: "", + }, + }, + { + name: "xml with suffix", + contentType: "application/vnd.openstreetmap.data+xml", + response: response{ + mediaType: "application/vnd.openstreetmap.data+xml", + suffix: "+xml", + }, + }, + { + name: "json with suffix 2", + contentType: "application/test+myapp+json; charset=utf8", + response: response{ + mediaType: "application/test+myapp+json", + suffix: "+json", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + mt, suffix := parseMediaType(tt.contentType) + if tt.response.mediaType != mt { + require.Error(t, fmt.Errorf("test name - %s: content type is invalid. Expected: %s. Got: %s", tt.name, tt.response.mediaType, mt)) + } + + if tt.response.suffix != suffix { + require.Error(t, fmt.Errorf("test name - %s: content type suffix is invalid. Expected: %s. Got: %s", tt.name, tt.response.suffix, suffix)) + } + }) + } +} diff --git a/internal/platform/validator/req_resp_decoder.go b/internal/platform/validator/req_resp_decoder.go index 3cb3691..3604286 100644 --- a/internal/platform/validator/req_resp_decoder.go +++ b/internal/platform/validator/req_resp_decoder.go @@ -970,6 +970,18 @@ func RegisterBodyDecoder(contentType string, decoder BodyDecoder) { bodyDecoders[contentType] = decoder } +// RegisterBodyDecoderSuffix registers a request body's decoder for a content type suffix. +// This call is not thread-safe: body decoders should not be created/destroyed by multiple goroutines. +func RegisterBodyDecoderSuffix(suffix string, decoder BodyDecoder) { + if suffix == "" { + panic("content type suffix is empty") + } + if decoder == nil { + panic("decoder is not defined") + } + bodyDecoders[suffix] = decoder +} + // UnregisterBodyDecoder dissociates a body decoder from a content type. // // Decoding this content type will result in an error. @@ -985,6 +997,29 @@ var headerCT = http.CanonicalHeaderKey("Content-Type") const prefixUnsupportedCT = "unsupported content type" +// getBodyDecoder searches by media type or suffix and returns body decoder or error +func getBodyDecoder(mediaType, suffix string) (BodyDecoder, error) { + var decoder BodyDecoder + var ok bool + + if suffix != "" { + decoder, ok = bodyDecoders[suffix] + if ok { + return decoder, nil + } + } + + decoder, ok = bodyDecoders[mediaType] + if !ok { + return nil, &ParseError{ + Kind: KindUnsupportedFormat, + Reason: fmt.Sprintf("%s %q", prefixUnsupportedCT, mediaType), + } + } + + return decoder, nil +} + // decodeBody returns a decoded body. // The function returns ParseError when a body is invalid. func decodeBody(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn, jsonParser *fastjson.Parser) ( @@ -1000,18 +1035,17 @@ func decodeBody(body io.Reader, header http.Header, schema *openapi3.SchemaRef, if err != nil { return "", nil, err } - return parseMediaType(contentType), value, nil + mediaType, _ := parseMediaType(contentType) + return mediaType, value, nil } } - mediaType := parseMediaType(contentType) - decoder, ok := bodyDecoders[mediaType] - if !ok { - return "", nil, &ParseError{ - Kind: KindUnsupportedFormat, - Reason: fmt.Sprintf("%s %q", prefixUnsupportedCT, mediaType), - } + mediaType, suffix := parseMediaType(contentType) + decoder, err := getBodyDecoder(mediaType, suffix) + if err != nil { + return "", nil, err } + value, err := decoder(body, header, schema, encFn, jsonParser) if err != nil { return "", nil, err @@ -1020,6 +1054,13 @@ func decodeBody(body io.Reader, header http.Header, schema *openapi3.SchemaRef, } func init() { + RegisterBodyDecoderSuffix("+json", jsonBodyDecoder) + RegisterBodyDecoderSuffix("+xml", xmlBodyDecoder) + RegisterBodyDecoderSuffix("+yaml", yamlBodyDecoder) + RegisterBodyDecoderSuffix("+csv", csvBodyDecoder) + RegisterBodyDecoderSuffix("+plain", plainBodyDecoder) + RegisterBodyDecoderSuffix("+zip", zipFileBodyDecoder) + RegisterBodyDecoder("application/json", jsonBodyDecoder) RegisterBodyDecoder("application/xml", xmlBodyDecoder) RegisterBodyDecoder("application/json-patch+json", jsonBodyDecoder) diff --git a/internal/platform/validator/req_resp_decoder_test.go b/internal/platform/validator/req_resp_decoder_test.go index efae67a..118beaa 100644 --- a/internal/platform/validator/req_resp_decoder_test.go +++ b/internal/platform/validator/req_resp_decoder_test.go @@ -1167,6 +1167,12 @@ func TestDecodeBody(t *testing.T) { body: strings.NewReader("\"foo\""), want: "foo", }, + { + name: "json-suffix", + mime: "application/test.content-type+json", + body: strings.NewReader("\"foo\""), + want: "foo", + }, { name: "x-yaml", mime: "application/x-yaml", @@ -1179,6 +1185,12 @@ func TestDecodeBody(t *testing.T) { body: strings.NewReader("foo"), want: "foo", }, + { + name: "yaml-suffix", + mime: "application/test.content-type+yaml", + body: strings.NewReader("foo"), + want: "foo", + }, { name: "urlencoded form", mime: "application/x-www-form-urlencoded", From 62722126ae7dea2ff52d33918d97dfa35f839298 Mon Sep 17 00:00:00 2001 From: Nikolay Tkachenko Date: Thu, 11 Apr 2024 16:38:35 +0300 Subject: [PATCH 02/12] Update the APIFW version --- Makefile | 2 +- demo/docker-compose/docker-compose-api-mode.yml | 2 +- demo/docker-compose/docker-compose-graphql-mode.yml | 2 +- demo/docker-compose/docker-compose.yml | 2 +- helm/api-firewall/Chart.yaml | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index cc774bc..c351f1a 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -VERSION := 0.7.0 +VERSION := 0.7.1 .DEFAULT_GOAL := build diff --git a/demo/docker-compose/docker-compose-api-mode.yml b/demo/docker-compose/docker-compose-api-mode.yml index 63fe26f..358d0be 100644 --- a/demo/docker-compose/docker-compose-api-mode.yml +++ b/demo/docker-compose/docker-compose-api-mode.yml @@ -2,7 +2,7 @@ version: '3.8' services: api-firewall: container_name: api-firewall - image: wallarm/api-firewall:v0.7.0 + image: wallarm/api-firewall:v0.7.1 restart: on-failure environment: APIFW_MODE: "api" diff --git a/demo/docker-compose/docker-compose-graphql-mode.yml b/demo/docker-compose/docker-compose-graphql-mode.yml index cdf3113..1d46bae 100644 --- a/demo/docker-compose/docker-compose-graphql-mode.yml +++ b/demo/docker-compose/docker-compose-graphql-mode.yml @@ -2,7 +2,7 @@ version: '3.8' services: api-firewall: container_name: api-firewall - image: wallarm/api-firewall:v0.7.0 + image: wallarm/api-firewall:v0.7.1 restart: on-failure environment: APIFW_MODE: "graphql" diff --git a/demo/docker-compose/docker-compose.yml b/demo/docker-compose/docker-compose.yml index ef93b93..251a808 100644 --- a/demo/docker-compose/docker-compose.yml +++ b/demo/docker-compose/docker-compose.yml @@ -2,7 +2,7 @@ version: "3.8" services: api-firewall: container_name: api-firewall - image: wallarm/api-firewall:v0.7.0 + image: wallarm/api-firewall:v0.7.1 restart: on-failure environment: APIFW_URL: "http://0.0.0.0:8080" diff --git a/helm/api-firewall/Chart.yaml b/helm/api-firewall/Chart.yaml index 6ade3d8..e5a6524 100644 --- a/helm/api-firewall/Chart.yaml +++ b/helm/api-firewall/Chart.yaml @@ -1,7 +1,7 @@ apiVersion: v1 name: api-firewall -version: 0.7.0 -appVersion: 0.7.0 +version: 0.7.1 +appVersion: 0.7.1 description: Wallarm OpenAPI-based API Firewall home: https://github.com/wallarm/api-firewall icon: https://static.wallarm.com/wallarm-logo.svg From 6eb27a015264e383d25eb96668129206f1c60d19 Mon Sep 17 00:00:00 2001 From: Nikolay Tkachenko Date: Thu, 11 Apr 2024 16:41:55 +0300 Subject: [PATCH 03/12] Fix/route conflict (#90) * Update router (chi) in API mode. Add tests --- .../api-firewall/internal/handlers/api/app.go | 219 +++-- .../internal/handlers/api/openapi.go | 7 +- .../internal/handlers/api/routes.go | 31 +- .../internal/updater/wallarm_api2_update.db | Bin 98304 -> 98304 bytes cmd/api-firewall/tests/main_api_mode_test.go | 8 +- internal/mid/allowiplist.go | 2 +- internal/platform/chi/LICENSE | 20 + internal/platform/chi/chi.go | 30 + internal/platform/chi/context.go | 162 ++++ internal/platform/chi/context_test.go | 87 ++ internal/platform/chi/mux.go | 82 ++ internal/platform/chi/mux_test.go | 518 ++++++++++ internal/platform/chi/path_value.go | 22 + internal/platform/chi/path_value_fallback.go | 21 + internal/platform/chi/path_value_test.go | 77 ++ internal/platform/chi/tree.go | 915 ++++++++++++++++++ internal/platform/chi/tree_test.go | 643 ++++++++++++ internal/platform/database/v1.go | 5 +- internal/platform/database/v2.go | 5 +- internal/platform/router/router.go | 6 +- internal/platform/web/apiMode.go | 36 + internal/platform/web/middleware.go | 4 +- internal/platform/web/web.go | 8 +- 23 files changed, 2769 insertions(+), 139 deletions(-) rename internal/platform/web/webAPIMode.go => cmd/api-firewall/internal/handlers/api/app.go (53%) create mode 100644 internal/platform/chi/LICENSE create mode 100644 internal/platform/chi/chi.go create mode 100644 internal/platform/chi/context.go create mode 100644 internal/platform/chi/context_test.go create mode 100644 internal/platform/chi/mux.go create mode 100644 internal/platform/chi/mux_test.go create mode 100644 internal/platform/chi/path_value.go create mode 100644 internal/platform/chi/path_value_fallback.go create mode 100644 internal/platform/chi/path_value_test.go create mode 100644 internal/platform/chi/tree.go create mode 100644 internal/platform/chi/tree_test.go create mode 100644 internal/platform/web/apiMode.go diff --git a/internal/platform/web/webAPIMode.go b/cmd/api-firewall/internal/handlers/api/app.go similarity index 53% rename from internal/platform/web/webAPIMode.go rename to cmd/api-firewall/internal/handlers/api/app.go index 4141233..c0fdcb9 100644 --- a/internal/platform/web/webAPIMode.go +++ b/cmd/api-firewall/internal/handlers/api/app.go @@ -1,29 +1,24 @@ -package web +package api import ( "errors" "fmt" + "net/http" "os" + "runtime/debug" strconv2 "strconv" "strings" "sync" "syscall" - "github.com/fasthttp/router" "github.com/google/uuid" "github.com/savsgio/gotils/strconv" "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttpadaptor" + "github.com/wallarm/api-firewall/internal/platform/chi" "github.com/wallarm/api-firewall/internal/platform/database" -) - -const ( - APIModePostfixStatusCode = "_status_code" - APIModePostfixValidationErrors = "_validation_errors" - - GlobalResponseStatusCodeKey = "global_response_status_code" - - RequestSchemaID = "__wallarm_apifw_request_schema_id" + "github.com/wallarm/api-firewall/internal/platform/web" ) var ( @@ -31,81 +26,30 @@ var ( statusInternalError = fasthttp.StatusInternalServerError ) -type FieldTypeError struct { - Name string `json:"name"` - ExpectedType string `json:"expected_type,omitempty"` - Pattern string `json:"pattern,omitempty"` - CurrentValue string `json:"current_value,omitempty"` -} - -type ValidationError struct { - Message string `json:"message"` - Code string `json:"code"` - SchemaVersion string `json:"schema_version,omitempty"` - SchemaID *int `json:"schema_id"` - Fields []string `json:"related_fields,omitempty"` - FieldsDetails []FieldTypeError `json:"related_fields_details,omitempty"` -} - -type APIModeResponseSummary struct { - SchemaID *int `json:"schema_id"` - StatusCode *int `json:"status_code"` -} - -type APIModeResponse struct { - Summary []*APIModeResponseSummary `json:"summary"` - Errors []*ValidationError `json:"errors,omitempty"` -} - // APIModeApp is the entrypoint into our application and what configures our context // object for each of our http handlers. Feel free to add any configuration // data/logic on this App struct type APIModeApp struct { - Routers map[int]*router.Router + Routers map[int]*chi.Mux Log *logrus.Logger passOPTIONS bool shutdown chan os.Signal - mw []Middleware + mw []web.Middleware storedSpecs database.DBOpenAPILoader lock *sync.RWMutex } -func (a *APIModeApp) SetDefaultBehavior(schemaID int, handler Handler, mw ...Middleware) { - // First wrap handler specific middleware around this handler. - handler = wrapMiddleware(mw, handler) - - // Add the application's general middleware to the handler chain. - handler = wrapMiddleware(a.mw, handler) - - customHandler := func(ctx *fasthttp.RequestCtx) { - - // Add request ID - ctx.SetUserValue(RequestID, uuid.NewString()) - - if err := handler(ctx); err != nil { - a.SignalShutdown() - return - } - - } - - // Set NOT FOUND behavior - a.Routers[schemaID].NotFound = customHandler - - // Set Method Not Allowed behavior - a.Routers[schemaID].MethodNotAllowed = customHandler -} - // NewAPIModeApp creates an APIModeApp value that handle a set of routes for the set of application. -func NewAPIModeApp(lock *sync.RWMutex, passOPTIONS bool, storedSpecs database.DBOpenAPILoader, shutdown chan os.Signal, logger *logrus.Logger, mw ...Middleware) *APIModeApp { +func NewAPIModeApp(lock *sync.RWMutex, passOPTIONS bool, storedSpecs database.DBOpenAPILoader, shutdown chan os.Signal, logger *logrus.Logger, mw ...web.Middleware) *APIModeApp { schemaIDs := storedSpecs.SchemaIDs() // Init routers - routers := make(map[int]*router.Router) + routers := make(map[int]*chi.Mux) for _, schemaID := range schemaIDs { - routers[schemaID] = router.New() - routers[schemaID].HandleOPTIONS = passOPTIONS + //routers[schemaID] = make(map[string]*mux.Router) + routers[schemaID] = chi.NewRouter() + //routers[schemaID].HandleOPTIONS = passOPTIONS } app := APIModeApp{ @@ -123,25 +67,29 @@ func NewAPIModeApp(lock *sync.RWMutex, passOPTIONS bool, storedSpecs database.DB // Handle is our mechanism for mounting Handlers for a given HTTP verb and path // pair, this makes for really easy, convenient routing. -func (a *APIModeApp) Handle(schemaID int, method string, path string, handler Handler, mw ...Middleware) { +func (a *APIModeApp) Handle(schemaID int, method string, path string, handler web.Handler, mw ...web.Middleware) error { // First wrap handler specific middleware around this handler. - handler = wrapMiddleware(mw, handler) + handler = web.WrapMiddleware(mw, handler) // Add the application's general middleware to the handler chain. - handler = wrapMiddleware(a.mw, handler) + handler = web.WrapMiddleware(a.mw, handler) // The function to execute for each request. - h := func(ctx *fasthttp.RequestCtx) { + h := func(ctx *fasthttp.RequestCtx) error { if err := handler(ctx); err != nil { a.SignalShutdown() - return + return err } + return nil } // Add this handler for the specified verb and route. - a.Routers[schemaID].Handle(method, path, h) + if err := a.Routers[schemaID].AddEndpoint(method, path, h); err != nil { + return err + } + return nil } // getWallarmSchemaID returns lists of found schema IDs in the DB, not found schema IDs in the DB and errors @@ -152,7 +100,7 @@ func getWallarmSchemaID(ctx *fasthttp.RequestCtx, storedSpecs database.DBOpenAPI } // Get Wallarm Schema ID - xWallarmSchemaIDsStr := string(ctx.Request.Header.Peek(XWallarmSchemaIDHeader)) + xWallarmSchemaIDsStr := string(ctx.Request.Header.Peek(web.XWallarmSchemaIDHeader)) if xWallarmSchemaIDsStr == "" { return nil, nil, errors.New("required X-WALLARM-SCHEMA-ID header is missing") } @@ -184,43 +132,42 @@ func getWallarmSchemaID(ctx *fasthttp.RequestCtx, storedSpecs database.DBOpenAPI return } -// APIModeHandler routes request to the appropriate handler according to the OpenAPI specification schema ID -func (a *APIModeApp) APIModeHandler(ctx *fasthttp.RequestCtx) { - - // Add request ID - ctx.SetUserValue(RequestID, uuid.NewString()) +// APIModeRouteHandler routes request to the appropriate handler according to the OpenAPI specification schema ID +func (a *APIModeApp) APIModeRouteHandler(ctx *fasthttp.RequestCtx) { + // handle panic defer func() { - // If pass request with OPTIONS method is enabled then log request - if ctx.Response.StatusCode() == fasthttp.StatusOK && a.passOPTIONS && strconv.B2S(ctx.Method()) == fasthttp.MethodOptions { - a.Log.WithFields(logrus.Fields{ - "request_id": ctx.UserValue(RequestID), - "host": string(ctx.Request.Header.Host()), - "path": string(ctx.Path()), - "method": string(ctx.Request.Header.Method()), - }).Info("Pass request with OPTIONS method") + if r := recover(); r != nil { + a.Log.Errorf("panic: %v", r) + + // Log the Go stack trace for this panic'd goroutine. + a.Log.Debugf("%s", debug.Stack()) + return } }() + // Add request ID + ctx.SetUserValue(web.RequestID, uuid.NewString()) + schemaIDs, notFoundSchemaIDs, err := getWallarmSchemaID(ctx, a.storedSpecs) if err != nil { - defer LogRequestResponseAtTraceLevel(ctx, a.Log) + defer web.LogRequestResponseAtTraceLevel(ctx, a.Log) a.Log.WithFields(logrus.Fields{ "error": err, "host": string(ctx.Request.Header.Host()), "path": string(ctx.Path()), "method": string(ctx.Request.Header.Method()), - "request_id": ctx.UserValue(RequestID), + "request_id": ctx.UserValue(web.RequestID), }).Error("error while getting schema ID") - if err := RespondError(ctx, fasthttp.StatusInternalServerError, ""); err != nil { + if err := web.RespondError(ctx, fasthttp.StatusInternalServerError, ""); err != nil { a.Log.WithFields(logrus.Fields{ "error": err, "host": string(ctx.Request.Header.Host()), "path": string(ctx.Path()), "method": string(ctx.Request.Header.Method()), - "request_id": ctx.UserValue(RequestID), + "request_id": ctx.UserValue(web.RequestID), }).Error("error while sending response") } @@ -228,34 +175,92 @@ func (a *APIModeApp) APIModeHandler(ctx *fasthttp.RequestCtx) { } // Delete internal header - ctx.Request.Header.Del(XWallarmSchemaIDHeader) + ctx.Request.Header.Del(web.XWallarmSchemaIDHeader) a.lock.RLock() defer a.lock.RUnlock() + //w := NewFastHTTPResponseAdapter(ctx) // Validate requests against list of schemas - for _, schemaID := range schemaIDs { + for _, sID := range schemaIDs { + schemaID := sID // Save schema IDs - ctx.SetUserValue(RequestSchemaID, strconv2.Itoa(schemaID)) - a.Routers[schemaID].Handler(ctx) + ctx.SetUserValue(web.RequestSchemaID, strconv2.Itoa(schemaID)) + var r http.Request + if err := fasthttpadaptor.ConvertRequest(ctx, &r, true); err != nil { + a.Log.WithFields(logrus.Fields{ + "error": err, + "host": strconv.B2S(ctx.Request.Header.Host()), + "path": strconv.B2S(ctx.Path()), + "method": strconv.B2S(ctx.Request.Header.Method()), + "request_id": ctx.UserValue(web.RequestID), + }).Error("error converting request") + return + } + + // find the handler with the OAS information + rctx := chi.NewRouteContext() + handler := a.Routers[schemaID].Find(rctx, strconv.B2S(ctx.Method()), strconv.B2S(ctx.Request.URI().Path())) + + // handler not found in the OAS + if handler == nil { + keyValidationErrors := strconv2.Itoa(schemaID) + web.APIModePostfixValidationErrors + keyStatusCode := strconv2.Itoa(schemaID) + web.APIModePostfixStatusCode + + // OPTIONS methods are passed if the passOPTIONS is set to true + if a.passOPTIONS == true && strconv.B2S(ctx.Method()) == fasthttp.MethodOptions { + ctx.SetUserValue(keyStatusCode, fasthttp.StatusOK) + a.Log.WithFields(logrus.Fields{ + "host": strconv.B2S(ctx.Request.Header.Host()), + "path": strconv.B2S(ctx.Path()), + "method": strconv.B2S(ctx.Request.Header.Method()), + "request_id": ctx.UserValue(web.RequestID), + }).Debug("Pass request with OPTIONS method") + continue + } + + // Method or Path were not found + a.Log.WithFields(logrus.Fields{ + "host": strconv.B2S(ctx.Request.Header.Host()), + "path": strconv.B2S(ctx.Path()), + "method": strconv.B2S(ctx.Request.Header.Method()), + "request_id": ctx.UserValue(web.RequestID), + }).Debug("Method or path were not found") + ctx.SetUserValue(keyValidationErrors, []*web.ValidationError{{Message: ErrMethodAndPathNotFound.Error(), Code: ErrCodeMethodAndPathNotFound, SchemaID: &schemaID}}) + ctx.SetUserValue(keyStatusCode, fasthttp.StatusForbidden) + continue + } + + // add router context to get URL params in the Handler + ctx.SetUserValue(chi.RouteCtxKey, rctx) + + if err := handler(ctx); err != nil { + a.Log.WithFields(logrus.Fields{ + "error": err, + "host": strconv.B2S(ctx.Request.Header.Host()), + "path": strconv.B2S(ctx.Path()), + "method": strconv.B2S(ctx.Request.Header.Method()), + "request_id": ctx.UserValue(web.RequestID), + }).Error("error in the request handler") + } } - responseSummary := make([]*APIModeResponseSummary, 0, len(schemaIDs)) - responseErrors := make([]*ValidationError, 0) + responseSummary := make([]*web.APIModeResponseSummary, 0, len(schemaIDs)) + responseErrors := make([]*web.ValidationError, 0) for i := 0; i < len(schemaIDs); i++ { - if statusCode, ok := ctx.UserValue(GlobalResponseStatusCodeKey).(int); ok { + if statusCode, ok := ctx.UserValue(web.GlobalResponseStatusCodeKey).(int); ok { ctx.Response.Header.Reset() ctx.Response.Header.SetStatusCode(statusCode) return } - statusCode, ok := ctx.UserValue(strconv2.Itoa(schemaIDs[i]) + APIModePostfixStatusCode).(int) + statusCode, ok := ctx.UserValue(strconv2.Itoa(schemaIDs[i]) + web.APIModePostfixStatusCode).(int) if !ok { // set summary for the schema ID in pass Options mode if a.passOPTIONS && strconv.B2S(ctx.Method()) == fasthttp.MethodOptions { - responseSummary = append(responseSummary, &APIModeResponseSummary{ + responseSummary = append(responseSummary, &web.APIModeResponseSummary{ SchemaID: &schemaIDs[i], StatusCode: &statusOK, }) @@ -268,19 +273,19 @@ func (a *APIModeApp) APIModeHandler(ctx *fasthttp.RequestCtx) { statusCode = fasthttp.StatusInternalServerError } - responseSummary = append(responseSummary, &APIModeResponseSummary{ + responseSummary = append(responseSummary, &web.APIModeResponseSummary{ SchemaID: &schemaIDs[i], StatusCode: &statusCode, }) - if validationErrors, ok := ctx.UserValue(strconv2.Itoa(schemaIDs[i]) + APIModePostfixValidationErrors).([]*ValidationError); ok && validationErrors != nil { + if validationErrors, ok := ctx.UserValue(strconv2.Itoa(schemaIDs[i]) + web.APIModePostfixValidationErrors).([]*web.ValidationError); ok && validationErrors != nil { responseErrors = append(responseErrors, validationErrors...) } } // Add schema IDs that were not found in the DB to the response for i := 0; i < len(notFoundSchemaIDs); i++ { - responseSummary = append(responseSummary, &APIModeResponseSummary{ + responseSummary = append(responseSummary, &web.APIModeResponseSummary{ SchemaID: ¬FoundSchemaIDs[i], StatusCode: &statusInternalError, }) @@ -294,9 +299,9 @@ func (a *APIModeApp) APIModeHandler(ctx *fasthttp.RequestCtx) { ctx.Request.Header.SetMethod(fasthttp.MethodGet) } - if err := Respond(ctx, APIModeResponse{Summary: responseSummary, Errors: responseErrors}, fasthttp.StatusOK); err != nil { + if err := web.Respond(ctx, web.APIModeResponse{Summary: responseSummary, Errors: responseErrors}, fasthttp.StatusOK); err != nil { a.Log.WithFields(logrus.Fields{ - "request_id": ctx.UserValue(RequestID), + "request_id": ctx.UserValue(web.RequestID), "host": string(ctx.Request.Header.Host()), "path": string(ctx.Path()), "method": string(ctx.Request.Header.Method()), diff --git a/cmd/api-firewall/internal/handlers/api/openapi.go b/cmd/api-firewall/internal/handlers/api/openapi.go index 9512c0d..c24e44d 100644 --- a/cmd/api-firewall/internal/handlers/api/openapi.go +++ b/cmd/api-firewall/internal/handlers/api/openapi.go @@ -3,6 +3,7 @@ package api import ( "context" "fmt" + "github.com/wallarm/api-firewall/internal/platform/chi" "net/http" strconv2 "strconv" "strings" @@ -106,11 +107,7 @@ func (s *APIMode) APIModeHandler(ctx *fasthttp.RequestCtx) error { var pathParams map[string]string if s.CustomRoute.ParametersNumberInPath > 0 { - pathParams = make(map[string]string) - - ctx.VisitUserValues(func(key []byte, value interface{}) { - pathParams[strconv.B2S(key)] = value.(string) - }) + pathParams = chi.AllURLParams(ctx) } // Convert fasthttp request to net/http request diff --git a/cmd/api-firewall/internal/handlers/api/routes.go b/cmd/api-firewall/internal/handlers/api/routes.go index 9e82e65..598d5c7 100644 --- a/cmd/api-firewall/internal/handlers/api/routes.go +++ b/cmd/api-firewall/internal/handlers/api/routes.go @@ -3,6 +3,7 @@ package api import ( "net/url" "os" + "runtime/debug" "sync" "github.com/corazawaf/coraza/v3" @@ -18,6 +19,18 @@ import ( ) func Handlers(lock *sync.RWMutex, cfg *config.APIMode, shutdown chan os.Signal, logger *logrus.Logger, storedSpecs database.DBOpenAPILoader, AllowedIPCache *allowiplist.AllowedIPsType, waf coraza.WAF) fasthttp.RequestHandler { + + // handle panic + defer func() { + if r := recover(); r != nil { + logger.Errorf("panic: %v", r) + + // Log the Go stack trace for this panic'd goroutine. + logger.Debugf("%s", debug.Stack()) + return + } + }() + // define FastJSON parsers pool var parserPool fastjson.ParserPool schemaIDs := storedSpecs.SchemaIDs() @@ -37,7 +50,7 @@ func Handlers(lock *sync.RWMutex, cfg *config.APIMode, shutdown chan os.Signal, } // Construct the web.App which holds all routes as well as common Middleware. - apps := web.NewAPIModeApp(lock, cfg.PassOptionsRequests, storedSpecs, shutdown, logger, mid.IPAllowlist(&ipAllowlistOptions), mid.WAFModSecurity(&modSecOptions), mid.Logger(logger), mid.MIMETypeIdentifier(logger), mid.Errors(logger), mid.Panics(logger)) + apps := NewAPIModeApp(lock, cfg.PassOptionsRequests, storedSpecs, shutdown, logger, mid.IPAllowlist(&ipAllowlistOptions), mid.WAFModSecurity(&modSecOptions), mid.Logger(logger), mid.MIMETypeIdentifier(logger), mid.Errors(logger), mid.Panics(logger)) for _, schemaID := range schemaIDs { @@ -86,20 +99,12 @@ func Handlers(lock *sync.RWMutex, cfg *config.APIMode, shutdown chan os.Signal, s.Log.Debugf("handler: Schema ID %d: OpenAPI version %s: Loaded path %s - %s", schemaID, storedSpecs.SpecificationVersion(schemaID), newSwagRouter.Routes[i].Method, updRoutePath) - apps.Handle(schemaID, newSwagRouter.Routes[i].Method, updRoutePath, s.APIModeHandler) + if err := apps.Handle(schemaID, newSwagRouter.Routes[i].Method, updRoutePath, s.APIModeHandler); err != nil { + logger.WithFields(logrus.Fields{"error": err, "schema_id": schemaID}).Error("Registration of the OAS failed") + } } - //set handler for default behavior (404, 405) - s := APIMode{ - CustomRoute: nil, - Log: logger, - Cfg: cfg, - ParserPool: &parserPool, - OpenAPIRouter: newSwagRouter, - SchemaID: schemaID, - } - apps.SetDefaultBehavior(schemaID, s.APIModeHandler) } - return apps.APIModeHandler + return apps.APIModeRouteHandler } diff --git a/cmd/api-firewall/internal/updater/wallarm_api2_update.db b/cmd/api-firewall/internal/updater/wallarm_api2_update.db index 31a9d14c27dc418efd93f4f03cb6b29fb97cce9a..603b995af0300b6f2561ae7b5fb24cd9a10ea460 100644 GIT binary patch delta 36 qcmZo@U~6b#n;^x+Bsx*X2}o{Ch>~SA+PquVMTl`zlfg0r#sC1d^$Gw0 delta 36 qcmZo@U~6b#n;^x+xO<|E6Oi1P5GBhfvU#_xixA_KCWB=Li~#`ABMS}y diff --git a/cmd/api-firewall/tests/main_api_mode_test.go b/cmd/api-firewall/tests/main_api_mode_test.go index 45663a8..739ed92 100644 --- a/cmd/api-firewall/tests/main_api_mode_test.go +++ b/cmd/api-firewall/tests/main_api_mode_test.go @@ -2646,9 +2646,9 @@ func (s *APIModeServiceTests) testAPIModeInvalidRouteInRequest(t *testing.T) { t.Errorf("Incorrect error code. Expected: %d and got %d", DefaultSchemaID, *apifwResponse.Summary[0].SchemaID) } - if *apifwResponse.Summary[0].StatusCode != fasthttp.StatusInternalServerError { + if *apifwResponse.Summary[0].StatusCode != fasthttp.StatusForbidden { t.Errorf("Incorrect result status. Expected: %d and got %d", - fasthttp.StatusInternalServerError, *apifwResponse.Summary[0].StatusCode) + fasthttp.StatusForbidden, *apifwResponse.Summary[0].StatusCode) } } } @@ -2703,9 +2703,9 @@ func (s *APIModeServiceTests) testAPIModeInvalidRouteInRequestInMultipleSchemas( t.Errorf("Incorrect error code. Expected: %d or %d and got %d", DefaultSchemaID, DefaultCopySchemaID, *apifwResponse.Summary[0].SchemaID) } - if *apifwResponse.Summary[i].StatusCode != fasthttp.StatusInternalServerError { + if *apifwResponse.Summary[i].StatusCode != fasthttp.StatusForbidden { t.Errorf("Incorrect result status. Expected: %d and got %d", - fasthttp.StatusInternalServerError, *apifwResponse.Summary[0].StatusCode) + fasthttp.StatusForbidden, *apifwResponse.Summary[0].StatusCode) } } } diff --git a/internal/mid/allowiplist.go b/internal/mid/allowiplist.go index 6b29ab2..01dd802 100644 --- a/internal/mid/allowiplist.go +++ b/internal/mid/allowiplist.go @@ -23,7 +23,7 @@ type IPAllowListOptions struct { var errAccessDeniedIP = errors.New("access denied to this IP") -// This function checks if an IP is allowed else gives error +// The IPAllowlist function checks if an IP is allowed else gives error func IPAllowlist(options *IPAllowListOptions) web.Middleware { // This is the actual middleware function to be executed. diff --git a/internal/platform/chi/LICENSE b/internal/platform/chi/LICENSE new file mode 100644 index 0000000..d99f02f --- /dev/null +++ b/internal/platform/chi/LICENSE @@ -0,0 +1,20 @@ +Copyright (c) 2015-present Peter Kieltyka (https://github.com/pkieltyka), Google Inc. + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/internal/platform/chi/chi.go b/internal/platform/chi/chi.go new file mode 100644 index 0000000..0d7a5c2 --- /dev/null +++ b/internal/platform/chi/chi.go @@ -0,0 +1,30 @@ +package chi + +import "github.com/wallarm/api-firewall/internal/platform/web" + +// NewRouter returns a new Mux object that implements the Router interface. +func NewRouter() *Mux { + return NewMux() +} + +// Router consisting of the core routing methods used by chi's Mux, +// using only the standard net/http. +type Router interface { + Routes + + // AddEndpoint adds routes for `pattern` that matches + // the `method` HTTP method. + AddEndpoint(method, pattern string, handler web.Handler) error +} + +// Routes interface adds two methods for router traversal, which is also +// used by the `docgen` subpackage to generation documentation for Routers. +type Routes interface { + // Routes returns the routing tree in an easily traversable structure. + Routes() []Route + + // Find searches the routing tree for a handler that matches + // the method/path - similar to routing a http request, but without + // executing the handler thereafter. + Find(rctx *Context, method, path string) web.Handler +} diff --git a/internal/platform/chi/context.go b/internal/platform/chi/context.go new file mode 100644 index 0000000..f5c506a --- /dev/null +++ b/internal/platform/chi/context.go @@ -0,0 +1,162 @@ +package chi + +import ( + "strings" + + "github.com/valyala/fasthttp" +) + +// URLParam returns the url parameter from a fasthttp.Request object. +func URLParam(ctx *fasthttp.RequestCtx, key string) string { + if rctx := RouteContext(ctx); rctx != nil { + return rctx.URLParam(key) + } + return "" +} + +// AllURLParams returns the map of the url parameters from a fasthttp.Request object. +func AllURLParams(ctx *fasthttp.RequestCtx) map[string]string { + if rctx := RouteContext(ctx); rctx != nil { + params := make(map[string]string) + for i := range rctx.URLParams.Keys { + params[rctx.URLParams.Keys[i]] = rctx.URLParams.Values[i] + } + return params + } + + return nil +} + +// RouteContext returns chi's routing Context object from a +// http.Request Context. +func RouteContext(ctx *fasthttp.RequestCtx) *Context { + val, _ := ctx.Value(RouteCtxKey).(*Context) + return val +} + +// NewRouteContext returns a new routing Context object. +func NewRouteContext() *Context { + return &Context{} +} + +var ( + // RouteCtxKey is the context.Context key to store the request context. + RouteCtxKey = &contextKey{"RouteContext"} +) + +// Context is the default routing context set on the root node of a +// request context to track route patterns, URL parameters and +// an optional routing path. +type Context struct { + Routes Routes + + // Routing path/method override used during the route search. + // See Mux#routeHTTP method. + RoutePath string + RouteMethod string + + // URLParams are the stack of routeParams captured during the + // routing lifecycle across a stack of sub-routers. + URLParams RouteParams + + // Route parameters matched for the current sub-router. It is + // intentionally unexported so it can't be tampered. + routeParams RouteParams + + // The endpoint routing pattern that matched the request URI path + // or `RoutePath` of the current sub-router. This value will update + // during the lifecycle of a request passing through a stack of + // sub-routers. + routePattern string + + // Routing pattern stack throughout the lifecycle of the request, + // across all connected routers. It is a record of all matching + // patterns across a stack of sub-routers. + RoutePatterns []string + + // methodNotAllowed hint + methodNotAllowed bool + methodsAllowed []methodTyp // allowed methods in case of a 405 +} + +// Reset a routing context to its initial state. +func (x *Context) Reset() { + x.Routes = nil + x.RoutePath = "" + x.RouteMethod = "" + x.RoutePatterns = x.RoutePatterns[:0] + x.URLParams.Keys = x.URLParams.Keys[:0] + x.URLParams.Values = x.URLParams.Values[:0] + + x.routePattern = "" + x.routeParams.Keys = x.routeParams.Keys[:0] + x.routeParams.Values = x.routeParams.Values[:0] + x.methodNotAllowed = false + x.methodsAllowed = x.methodsAllowed[:0] +} + +// URLParam returns the corresponding URL parameter value from the request +// routing context. +func (x *Context) URLParam(key string) string { + for k := len(x.URLParams.Keys) - 1; k >= 0; k-- { + if x.URLParams.Keys[k] == key { + return x.URLParams.Values[k] + } + } + return "" +} + +// RoutePattern builds the routing pattern string for the particular +// request, at the particular point during routing. This means, the value +// will change throughout the execution of a request in a router. That is +// why its advised to only use this value after calling the next handler. +// +// For example, +// +// func Instrument(next web.Handler) web.Handler { +// return web.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// next.ServeHTTP(w, r) +// routePattern := chi.RouteContext(r.Context()).RoutePattern() +// measure(w, r, routePattern) +// }) +// } +func (x *Context) RoutePattern() string { + routePattern := strings.Join(x.RoutePatterns, "") + routePattern = replaceWildcards(routePattern) + if routePattern != "/" { + routePattern = strings.TrimSuffix(routePattern, "//") + routePattern = strings.TrimSuffix(routePattern, "/") + } + return routePattern +} + +// replaceWildcards takes a route pattern and recursively replaces all +// occurrences of "/*/" to "/". +func replaceWildcards(p string) string { + if strings.Contains(p, "/*/") { + return replaceWildcards(strings.Replace(p, "/*/", "/", -1)) + } + return p +} + +// RouteParams is a structure to track URL routing parameters efficiently. +type RouteParams struct { + Keys, Values []string +} + +// Add will append a URL parameter to the end of the route param +func (s *RouteParams) Add(key, value string) { + s.Keys = append(s.Keys, key) + s.Values = append(s.Values, value) +} + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. This technique +// for defining context keys was copied from Go 1.7's new use of context in net/http. +type contextKey struct { + name string +} + +func (k *contextKey) String() string { + return "chi context value " + k.name +} diff --git a/internal/platform/chi/context_test.go b/internal/platform/chi/context_test.go new file mode 100644 index 0000000..4731c70 --- /dev/null +++ b/internal/platform/chi/context_test.go @@ -0,0 +1,87 @@ +package chi + +import "testing" + +// TestRoutePattern tests correct in-the-middle wildcard removals. +// If user organizes a router like this: +// +// (router.go) +// +// r.Route("/v1", func(r chi.Router) { +// r.Mount("/resources", resourcesController{}.Router()) +// } +// +// (resources_controller.go) +// +// r.Route("/", func(r chi.Router) { +// r.Get("/{resource_id}", getResource()) +// // other routes... +// } +// +// This test checks how the route pattern is calculated +// "/v1/resources/{resource_id}" (right) +// "/v1/resources/*/{resource_id}" (wrong) +func TestRoutePattern(t *testing.T) { + routePatterns := []string{ + "/v1/*", + "/resources/*", + "/{resource_id}", + } + + x := &Context{ + RoutePatterns: routePatterns, + } + + if p := x.RoutePattern(); p != "/v1/resources/{resource_id}" { + t.Fatal("unexpected route pattern: " + p) + } + + x.RoutePatterns = []string{ + "/v1/*", + "/resources/*", + // Additional wildcard, depending on the router structure of the user + "/*", + "/{resource_id}", + } + + // Correctly removes in-the-middle wildcards instead of "/v1/resources/*/{resource_id}" + if p := x.RoutePattern(); p != "/v1/resources/{resource_id}" { + t.Fatal("unexpected route pattern: " + p) + } + + x.RoutePatterns = []string{ + "/v1/*", + "/resources/*", + // Even with many wildcards + "/*", + "/*", + "/*", + "/{resource_id}/*", // Keeping trailing wildcard + } + + // Correctly removes in-the-middle wildcards instead of "/v1/resources/*/*/{resource_id}/*" + if p := x.RoutePattern(); p != "/v1/resources/{resource_id}/*" { + t.Fatal("unexpected route pattern: " + p) + } + + x.RoutePatterns = []string{ + "/v1/*", + "/resources/*", + // And respects asterisks as part of the paths + "/*special_path/*", + "/with_asterisks*/*", + "/{resource_id}", + } + + // Correctly removes in-the-middle wildcards instead of "/v1/resourcesspecial_path/with_asterisks{resource_id}" + if p := x.RoutePattern(); p != "/v1/resources/*special_path/with_asterisks*/{resource_id}" { + t.Fatal("unexpected route pattern: " + p) + } + + // Testing for the root route pattern + x.RoutePatterns = []string{"/"} + // It should just return "/" as the pattern + if p := x.RoutePattern(); p != "/" { + t.Fatal("unexpected route pattern for root: " + p) + } +} diff --git a/internal/platform/chi/mux.go b/internal/platform/chi/mux.go new file mode 100644 index 0000000..c4d3fd9 --- /dev/null +++ b/internal/platform/chi/mux.go @@ -0,0 +1,82 @@ +package chi + +import ( + "fmt" + "strings" + + "github.com/wallarm/api-firewall/internal/platform/web" +) + +var _ Router = &Mux{} + +// Mux is a simple fastHTTP route multiplexer that parses a request path, +// records any URL params, and searched for the appropriate web.Handler. It implements +// the web.Handler interface and is friendly with the standard library. +type Mux struct { + // The radix trie router + tree *node +} + +// NewMux returns a newly initialized Mux object that implements the Router +// interface. +func NewMux() *Mux { + mux := &Mux{tree: &node{}} + return mux +} + +// AddEndpoint adds the route `pattern` that matches `method` http method to +// execute the `handler` web.Handler. +func (mx *Mux) AddEndpoint(method, pattern string, handler web.Handler) error { + m, ok := methodMap[strings.ToUpper(method)] + if !ok { + return fmt.Errorf("'%s' http method is not supported", method) + } + + if _, err := mx.handle(m, pattern, handler); err != nil { + return err + } + + return nil +} + +// Routes returns a slice of routing information from the tree, +// useful for traversing available routes of a router. +func (mx *Mux) Routes() []Route { + return mx.tree.routes() +} + +func (mx *Mux) Find(rctx *Context, method, path string) web.Handler { + m, ok := methodMap[method] + if !ok { + return nil + } + + node, _, h := mx.tree.FindRoute(rctx, m, path) + + if node != nil && node.subroutes != nil { + rctx.RoutePath = mx.nextRoutePath(rctx) + return node.subroutes.Find(rctx, method, rctx.RoutePath) + } + + return h +} + +// handle registers a web.Handler in the routing tree for a particular http method +// and routing pattern. +func (mx *Mux) handle(method methodTyp, pattern string, handler web.Handler) (*node, error) { + if len(pattern) == 0 || pattern[0] != '/' { + return nil, fmt.Errorf("routing pattern must begin with '/' in '%s'", pattern) + } + + // Add the endpoint to the tree and return the node + return mx.tree.InsertRoute(method, pattern, handler) +} + +func (mx *Mux) nextRoutePath(rctx *Context) string { + routePath := "/" + nx := len(rctx.routeParams.Keys) - 1 // index of last param in list + if nx >= 0 && rctx.routeParams.Keys[nx] == "*" && len(rctx.routeParams.Values) > nx { + routePath = "/" + rctx.routeParams.Values[nx] + } + return routePath +} diff --git a/internal/platform/chi/mux_test.go b/internal/platform/chi/mux_test.go new file mode 100644 index 0000000..9a4789e --- /dev/null +++ b/internal/platform/chi/mux_test.go @@ -0,0 +1,518 @@ +package chi + +import ( + "bytes" + "fmt" + "github.com/valyala/fasthttp" + "github.com/wallarm/api-firewall/internal/platform/web" + "io" + "net/http" + "testing" +) + +func TestMuxBasic(t *testing.T) { + + cxindex := func(ctx *fasthttp.RequestCtx) error { + ctx.SetStatusCode(200) + ctx.SetBody([]byte("hi peter")) + return nil + } + + ping := func(ctx *fasthttp.RequestCtx) error { + ctx.SetStatusCode(200) + ctx.SetBody([]byte(".")) + return nil + } + + headPing := func(ctx *fasthttp.RequestCtx) error { + ctx.Response.Header.Set("X-Ping", "1") + ctx.SetStatusCode(200) + return nil + } + + createPing := func(ctx *fasthttp.RequestCtx) error { + // create .... + ctx.SetStatusCode(201) + return nil + } + + pingAll := func(ctx *fasthttp.RequestCtx) error { + ctx.SetStatusCode(200) + ctx.SetBody([]byte("ping all")) + return nil + } + + pingAll2 := func(ctx *fasthttp.RequestCtx) error { + ctx.SetStatusCode(200) + ctx.SetBody([]byte("ping all2")) + return nil + } + + pingOne := func(ctx *fasthttp.RequestCtx) error { + ctx.SetStatusCode(200) + ctx.SetBody([]byte("ping one id: " + URLParam(ctx, "id"))) + return nil + } + + pingWoop := func(ctx *fasthttp.RequestCtx) error { + ctx.SetStatusCode(200) + ctx.SetBody([]byte("woop." + URLParam(ctx, "iidd"))) + return nil + } + + catchAll := func(ctx *fasthttp.RequestCtx) error { + ctx.SetStatusCode(200) + ctx.SetBody([]byte("catchall")) + return nil + } + + m := NewRouter() + m.AddEndpoint("GET", "/", cxindex) + m.AddEndpoint("GET", "/ping", ping) + + m.AddEndpoint("GET", "/pingall", pingAll) + m.AddEndpoint("get", "/ping/all", pingAll) + m.AddEndpoint("GET", "/ping/all2", pingAll2) + m.AddEndpoint("HEAD", "/ping", headPing) + m.AddEndpoint("POST", "/ping", createPing) + m.AddEndpoint("GET", "/ping/{id}", pingWoop) + m.AddEndpoint("POST", "/ping/{id}", pingOne) + m.AddEndpoint("GET", "/ping/{iidd}/woop", pingWoop) + m.AddEndpoint("POST", "/admin/*", catchAll) + + // GET / + if _, body := testRequest(t, m, "GET", "/", nil); body != "hi peter" { + t.Fatalf(body) + } + + // GET /ping + if _, body := testRequest(t, m, "GET", "/ping", nil); body != "." { + t.Fatalf(body) + } + + // GET /pingall + if _, body := testRequest(t, m, "GET", "/pingall", nil); body != "ping all" { + t.Fatalf(body) + } + + // GET /ping/all + if _, body := testRequest(t, m, "GET", "/ping/all", nil); body != "ping all" { + t.Fatalf(body) + } + + // GET /ping/all2 + if _, body := testRequest(t, m, "GET", "/ping/all2", nil); body != "ping all2" { + t.Fatalf(body) + } + + // POST /ping/123 + if _, body := testRequest(t, m, "POST", "/ping/123", nil); body != "ping one id: 123" { + t.Fatalf(body) + } + + // GET /ping/allan + if _, body := testRequest(t, m, "POST", "/ping/allan", nil); body != "ping one id: allan" { + t.Fatalf(body) + } + + // GET /ping/1/woop + if _, body := testRequest(t, m, "GET", "/ping/1/woop", nil); body != "woop.1" { + t.Fatalf(body) + } + + if status, _ := testRequest(t, m, "HEAD", "/ping", nil); status != 200 { + t.Fatal("wrong status code") + } + + // GET /admin/catch-this + if status, body := testRequest(t, m, "GET", "/admin/catch-thazzzzz", nil); body != "" && status != 0 { + t.Fatalf("method not found failed") + } + + // POST /admin/catch-this + if _, body := testRequest(t, m, "POST", "/admin/casdfsadfs", bytes.NewReader([]byte{})); body != "catchall" { + t.Fatalf(body) + } +} + +func TestMuxHandlePatternValidation(t *testing.T) { + testCases := []struct { + name string + pattern string + shouldPanic bool + method string // Method to be used for the test request + path string // Path to be used for the test request + expectedBody string // Expected response body + expectedStatus int // Expected HTTP status code + }{ + // Valid patterns + { + name: "Valid pattern without HTTP GET", + pattern: "/user/{id}", + shouldPanic: false, + method: "GET", + path: "/user/123", + expectedBody: "without-prefix GET", + expectedStatus: http.StatusOK, + }, + { + name: "Valid pattern with HTTP POST", + pattern: "POST /products/{id}", + shouldPanic: false, + method: "POST", + path: "/products/456", + expectedBody: "with-prefix POST", + expectedStatus: http.StatusOK, + }, + // Invalid patterns + { + name: "Invalid pattern with no method", + pattern: "INVALID/user/{id}", + shouldPanic: true, + }, + { + name: "Invalid pattern with supported method", + pattern: "GET/user/{id}", + shouldPanic: true, + }, + { + name: "Invalid pattern with unsupported method", + pattern: "UNSUPPORTED /unsupported-method", + shouldPanic: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil && !tc.shouldPanic { + t.Errorf("Unexpected panic for pattern %s:\n%v", tc.pattern, r) + } + }() + + r := NewRouter() + r.AddEndpoint(tc.method, tc.path, func(ctx *fasthttp.RequestCtx) error { + ctx.SetStatusCode(200) + ctx.SetBody([]byte(tc.expectedBody)) + return nil + }) + + if !tc.shouldPanic { + statusCode, body := testRequest(t, r, tc.method, tc.path, nil) + if body != tc.expectedBody || statusCode != tc.expectedStatus { + t.Errorf("Expected status %d and body %s; got status %d and body %s for pattern %s", + tc.expectedStatus, tc.expectedBody, statusCode, body, tc.pattern) + } + } + }) + } +} + +func TestMuxEmptyParams(t *testing.T) { + r := NewRouter() + if err := r.AddEndpoint("GET", "/users/{x}/{y}/{z}", func(ctx *fasthttp.RequestCtx) error { + x := URLParam(ctx, "x") + y := URLParam(ctx, "y") + z := URLParam(ctx, "z") + ctx.SetBody([]byte(fmt.Sprintf("%s-%s-%s", x, y, z))) + + return nil + }); err != nil { + t.Fatal(err) + } + + if _, body := testRequest(t, r, "GET", "/users/a/b/c", nil); body != "a-b-c" { + t.Fatalf(body) + } + if _, body := testRequest(t, r, "GET", "/users///c", nil); body != "--c" { + t.Fatalf(body) + } +} + +func TestMuxWildcardRoute(t *testing.T) { + handler := func(ctx *fasthttp.RequestCtx) error { return nil } + + r := NewRouter() + if err := r.AddEndpoint("GET", "/*/wildcard/must/be/at/end", handler); err == nil { + t.Fatal("expected error") + } +} + +func TestMuxWildcardRouteCheckTwo(t *testing.T) { + handler := func(ctx *fasthttp.RequestCtx) error { return nil } + + r := NewRouter() + if err := r.AddEndpoint("GET", "/*/wildcard/{must}/be/at/end", handler); err == nil { + t.Fatal("expected error") + } + +} + +func TestMuxRegexp(t *testing.T) { + r := NewRouter() + + if err := r.AddEndpoint("GET", "/{param:[0-9]*}/test", func(ctx *fasthttp.RequestCtx) error { + ctx.SetBody([]byte(fmt.Sprintf("Hi: %s", URLParam(ctx, "param")))) + return nil + }); err != nil { + t.Fatal(err) + } + + if _, body := testRequest(t, r, "GET", "//test", nil); body != "Hi: " { + t.Fatal(body) + } +} + +func TestMuxRegexp2(t *testing.T) { + r := NewRouter() + if err := r.AddEndpoint("GET", "/foo-{suffix:[a-z]{2,3}}.json", func(ctx *fasthttp.RequestCtx) error { + ctx.SetBody([]byte(URLParam(ctx, "suffix"))) + return nil + }); err != nil { + t.Fatal(err) + } + + if _, body := testRequest(t, r, "GET", "/foo-.json", nil); body != "" { + t.Fatalf(body) + } + if _, body := testRequest(t, r, "GET", "/foo-abc.json", nil); body != "abc" { + t.Fatalf(body) + } +} + +func TestMuxRegexp3(t *testing.T) { + r := NewRouter() + if err := r.AddEndpoint("GET", "/one/{firstId:[a-z0-9-]+}/{secondId:[a-z]+}/first", func(ctx *fasthttp.RequestCtx) error { + ctx.SetBody([]byte("first")) + return nil + }); err != nil { + t.Fatal(err) + } + if err := r.AddEndpoint("GET", "/one/{firstId:[a-z0-9-_]+}/{secondId:[0-9]+}/second", func(ctx *fasthttp.RequestCtx) error { + ctx.SetBody([]byte("second")) + return nil + }); err != nil { + t.Fatal(err) + } + + if err := r.AddEndpoint("DELETE", "/one/{firstId:[a-z0-9-_]+}/{secondId:[0-9]+}/second", func(ctx *fasthttp.RequestCtx) error { + ctx.SetBody([]byte("third")) + return nil + }); err != nil { + t.Fatal(err) + } + + if _, body := testRequest(t, r, "GET", "/one/hello/peter/first", nil); body != "first" { + t.Fatalf(body) + } + if _, body := testRequest(t, r, "GET", "/one/hithere/123/second", nil); body != "second" { + t.Fatalf(body) + } + if _, body := testRequest(t, r, "DELETE", "/one/hithere/123/second", nil); body != "third" { + t.Fatalf(body) + } +} + +func TestMuxSubrouterWildcardParam(t *testing.T) { + h := web.Handler(func(ctx *fasthttp.RequestCtx) error { + ctx.SetBody([]byte(fmt.Sprintf("param:%v *:%v", URLParam(ctx, "param"), URLParam(ctx, "*")))) + return nil + }) + + r := NewRouter() + + if err := r.AddEndpoint("GET", "/bare/{param}", h); err != nil { + t.Fatal(err) + } + if err := r.AddEndpoint("GET", "/bare/{param}/*", h); err != nil { + t.Fatal(err) + } + + if err := r.AddEndpoint("GET", "/case0/{param}", h); err != nil { + t.Fatal(err) + } + if err := r.AddEndpoint("GET", "/case0/{param}/*", h); err != nil { + t.Fatal(err) + } + + if _, body := testRequest(t, r, "GET", "/bare/hi", nil); body != "param:hi *:" { + t.Fatalf(body) + } + if _, body := testRequest(t, r, "GET", "/bare/hi/yes", nil); body != "param:hi *:yes" { + t.Fatalf(body) + } + if _, body := testRequest(t, r, "GET", "/case0/hi", nil); body != "param:hi *:" { + t.Fatalf(body) + } + if _, body := testRequest(t, r, "GET", "/case0/hi/yes", nil); body != "param:hi *:yes" { + t.Fatalf(body) + } +} + +func TestEscapedURLParams(t *testing.T) { + m := NewRouter() + if err := m.AddEndpoint("GET", "/api/{identifier}/{region}/{size}/{rotation}/*", func(ctx *fasthttp.RequestCtx) error { + ctx.SetStatusCode(200) + rctx := RouteContext(ctx) + if rctx == nil { + t.Error("no context") + return nil + } + identifier := URLParam(ctx, "identifier") + if identifier != "http:%2f%2fexample.com%2fimage.png" { + t.Errorf("identifier path parameter incorrect %s", identifier) + return nil + } + region := URLParam(ctx, "region") + if region != "full" { + t.Errorf("region path parameter incorrect %s", region) + return nil + } + size := URLParam(ctx, "size") + if size != "max" { + t.Errorf("size path parameter incorrect %s", size) + return nil + } + rotation := URLParam(ctx, "rotation") + if rotation != "0" { + t.Errorf("rotation path parameter incorrect %s", rotation) + return nil + } + ctx.SetBody([]byte("success")) + return nil + }); err != nil { + t.Fatal(err) + } + + if _, body := testRequest(t, m, "GET", "/api/http:%2f%2fexample.com%2fimage.png/full/max/0/color.png", nil); body != "success" { + t.Fatalf(body) + } +} + +func TestCustomHTTPMethod(t *testing.T) { + // first we must register this method to be accepted, then we + // can define method handlers on the router below + if err := RegisterMethod("BOO"); err != nil { + t.Fatal(err) + } + + r := NewRouter() + if err := r.AddEndpoint("GET", "/", func(ctx *fasthttp.RequestCtx) error { + ctx.SetBody([]byte(".")) + return nil + }); err != nil { + t.Fatal(err) + } + + // note the custom BOO method for route /hi + if err := r.AddEndpoint("BOO", "/hi", func(ctx *fasthttp.RequestCtx) error { + ctx.SetBody([]byte("custom method")) + return nil + }); err != nil { + t.Fatal(err) + } + + if _, body := testRequest(t, r, "GET", "/", nil); body != "." { + t.Fatalf(body) + } + if _, body := testRequest(t, r, "BOO", "/hi", nil); body != "custom method" { + t.Fatalf(body) + } +} + +func testRequest(t *testing.T, mux *Mux, method, path string, body io.Reader) (int, string) { + + rctx := NewRouteContext() + handler := mux.Find(rctx, method, path) + + if handler == nil { + return 0, "" + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI(path) + req.Header.SetMethod(method) + + if body != nil { + reqBody, err := io.ReadAll(body) + if err != nil { + t.Fatal(err) + return 0, "" + } + req.SetBody(reqBody) + } + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + // add url params + reqCtx.SetUserValue(RouteCtxKey, rctx) + + if err := handler(&reqCtx); err != nil { + t.Fatal(err) + return 0, "" + } + + return reqCtx.Response.StatusCode(), string(reqCtx.Response.Body()) +} + +type ctxKey struct { + name string +} + +func (k ctxKey) String() string { + return "context value " + k.name +} + +//func BenchmarkMux(b *testing.B) { +// h1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) +// h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) +// h3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) +// h4 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) +// h5 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) +// h6 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) +// +// mx := NewRouter() +// mx.Get("/", h1) +// mx.Get("/hi", h2) +// mx.Post("/hi-post", h2) // used to benchmark 405 responses +// mx.Get("/sup/{id}/and/{this}", h3) +// mx.Get("/sup/{id}/{bar:foo}/{this}", h3) +// +// mx.Route("/sharing/{x}/{hash}", func(mx Router) { +// mx.Get("/", h4) // subrouter-1 +// mx.Get("/{network}", h5) // subrouter-1 +// mx.Get("/twitter", h5) +// mx.Route("/direct", func(mx Router) { +// mx.Get("/", h6) // subrouter-2 +// mx.Get("/download", h6) +// }) +// }) +// +// routes := []string{ +// "/", +// "/hi", +// "/hi-post", +// "/sup/123/and/this", +// "/sup/123/foo/this", +// "/sharing/z/aBc", // subrouter-1 +// "/sharing/z/aBc/twitter", // subrouter-1 +// "/sharing/z/aBc/direct", // subrouter-2 +// "/sharing/z/aBc/direct/download", // subrouter-2 +// } +// +// for _, path := range routes { +// b.Run("route:"+path, func(b *testing.B) { +// w := httptest.NewRecorder() +// r, _ := http.NewRequest("GET", path, nil) +// +// b.ReportAllocs() +// b.ResetTimer() +// +// for i := 0; i < b.N; i++ { +// mx.ServeHTTP(w, r) +// } +// }) +// } +//} diff --git a/internal/platform/chi/path_value.go b/internal/platform/chi/path_value.go new file mode 100644 index 0000000..8ab89cd --- /dev/null +++ b/internal/platform/chi/path_value.go @@ -0,0 +1,22 @@ +//go:build go1.22 +// +build go1.22 + +package chi + +import ( + "github.com/valyala/fasthttp" +) + +// supportsPathValue is true if the Go version is 1.22 and above. +// +// If this is true, `net/http.Request` has methods `SetPathValue` and `PathValue`. +const supportsPathValue = true + +// setPathValue sets the path values in the Request value +// based on the provided request context. +func setPathValue(rctx *Context, ctx *fasthttp.RequestCtx) { + for i, key := range rctx.URLParams.Keys { + value := rctx.URLParams.Values[i] + ctx.SetUserValue(key, value) + } +} diff --git a/internal/platform/chi/path_value_fallback.go b/internal/platform/chi/path_value_fallback.go new file mode 100644 index 0000000..9f0288b --- /dev/null +++ b/internal/platform/chi/path_value_fallback.go @@ -0,0 +1,21 @@ +//go:build !go1.22 +// +build !go1.22 + +package chi + +import ( + "github.com/valyala/fasthttp" +) + +// supportsPathValue is true if the Go version is 1.22 and above. +// +// If this is true, `net/http.Request` has methods `SetPathValue` and `PathValue`. +const supportsPathValue = false + +// setPathValue sets the path values in the Request value +// based on the provided request context. +// +// setPathValue is only supported in Go 1.22 and above so +// this is just a blank function so that it compiles. +func setPathValue(rctx *Context, ctx *fasthttp.RequestCtx) { +} diff --git a/internal/platform/chi/path_value_test.go b/internal/platform/chi/path_value_test.go new file mode 100644 index 0000000..5a48698 --- /dev/null +++ b/internal/platform/chi/path_value_test.go @@ -0,0 +1,77 @@ +//go:build go1.22 +// +build go1.22 + +package chi + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestPathValue(t *testing.T) { + testCases := []struct { + name string + pattern string + method string + pathKeys []string + requestPath string + expectedBody string + }{ + { + name: "Basic path value", + pattern: "/hubs/{hubID}", + method: "GET", + pathKeys: []string{"hubID"}, + requestPath: "/hubs/392", + expectedBody: "392", + }, + { + name: "Two path values", + pattern: "/users/{userID}/conversations/{conversationID}", + method: "POST", + pathKeys: []string{"userID", "conversationID"}, + requestPath: "/users/Gojo/conversations/2948", + expectedBody: "Gojo 2948", + }, + { + name: "Wildcard path", + pattern: "/users/{userID}/friends/*", + method: "POST", + pathKeys: []string{"userID", "*"}, + requestPath: "/users/Gojo/friends/all-of-them/and/more", + expectedBody: "Gojo all-of-them/and/more", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := NewRouter() + + r.Handle(tc.method+" "+tc.pattern, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + pathValues := []string{} + for _, pathKey := range tc.pathKeys { + pathValue := r.PathValue(pathKey) + if pathValue == "" { + pathValue = "NOT_FOUND:" + pathKey + } + + pathValues = append(pathValues, pathValue) + } + + body := strings.Join(pathValues, " ") + + w.Write([]byte(body)) + })) + + ts := httptest.NewServer(r) + defer ts.Close() + + _, body := testRequest(t, ts, tc.method, tc.requestPath, nil) + if body != tc.expectedBody { + t.Fatalf("expecting %q, got %q", tc.expectedBody, body) + } + }) + } +} diff --git a/internal/platform/chi/tree.go b/internal/platform/chi/tree.go new file mode 100644 index 0000000..762dded --- /dev/null +++ b/internal/platform/chi/tree.go @@ -0,0 +1,915 @@ +package chi + +// Radix tree implementation below is a based on the original work by +// Armon Dadgar in https://github.com/armon/go-radix/blob/master/radix.go +// (MIT licensed). It's been heavily modified for use as a HTTP routing tree. + +import ( + "fmt" + "regexp" + "sort" + "strconv" + "strings" + + "github.com/valyala/fasthttp" + "github.com/wallarm/api-firewall/internal/platform/web" +) + +type methodTyp uint + +const ( + mSTUB methodTyp = 1 << iota + mCONNECT + mDELETE + mGET + mHEAD + mOPTIONS + mPATCH + mPOST + mPUT + mTRACE +) + +var mALL = mCONNECT | mDELETE | mGET | mHEAD | + mOPTIONS | mPATCH | mPOST | mPUT | mTRACE + +var methodMap = map[string]methodTyp{ + fasthttp.MethodConnect: mCONNECT, + fasthttp.MethodDelete: mDELETE, + fasthttp.MethodGet: mGET, + fasthttp.MethodHead: mHEAD, + fasthttp.MethodOptions: mOPTIONS, + fasthttp.MethodPatch: mPATCH, + fasthttp.MethodPost: mPOST, + fasthttp.MethodPut: mPUT, + fasthttp.MethodTrace: mTRACE, +} + +// RegisterMethod adds support for custom HTTP method handlers, available +// via Router#Method and Router#MethodFunc +func RegisterMethod(method string) error { + if method == "" { + return nil + } + method = strings.ToUpper(method) + if _, ok := methodMap[method]; ok { + return nil + } + n := len(methodMap) + if n > strconv.IntSize-2 { + return fmt.Errorf("max number of methods reached (%d)", strconv.IntSize) + } + mt := methodTyp(2 << n) + methodMap[method] = mt + mALL |= mt + + return nil +} + +type nodeTyp uint8 + +const ( + ntStatic nodeTyp = iota // /home + ntRegexp // /{id:[0-9]+} + ntParam // /{user} + ntCatchAll // /api/v1/* +) + +type node struct { + // subroutes on the leaf node + subroutes Routes + + // regexp matcher for regexp nodes + rex *regexp.Regexp + + // HTTP handler endpoints on the leaf node + endpoints endpoints + + // prefix is the common prefix we ignore + prefix string + + // child nodes should be stored in-order for iteration, + // in groups of the node type. + children [ntCatchAll + 1]nodes + + // first byte of the child prefix + tail byte + + // node type: static, regexp, param, catchAll + typ nodeTyp + + // first byte of the prefix + label byte +} + +// endpoints is a mapping of http method constants to handlers +// for a given route. +type endpoints map[methodTyp]*endpoint + +type endpoint struct { + // endpoint handler + handler web.Handler + + // pattern is the routing pattern for handler nodes + pattern string + + // parameter keys recorded on handler nodes + paramKeys []string +} + +func (s endpoints) Value(method methodTyp) *endpoint { + mh, ok := s[method] + if !ok { + mh = &endpoint{} + s[method] = mh + } + return mh +} + +func (n *node) InsertRoute(method methodTyp, pattern string, handler web.Handler) (*node, error) { + var parent *node + search := pattern + + for { + // Handle key exhaustion + if len(search) == 0 { + // Insert or update the node's leaf handler + if err := n.setEndpoint(method, handler, pattern); err != nil { + return nil, err + } + return n, nil + } + + // We're going to be searching for a wild node next, + // in this case, we need to get the tail + var label = search[0] + var segTail byte + var segEndIdx int + var segTyp nodeTyp + var segRexpat string + if label == '{' || label == '*' { + var err error + segTyp, _, segRexpat, segTail, _, segEndIdx, err = patNextSegment(search) + if err != nil { + return nil, err + } + } + + var prefix string + if segTyp == ntRegexp { + prefix = segRexpat + } + + // Look for the edge to attach to + parent = n + n = n.getEdge(segTyp, label, segTail, prefix) + + // No edge, create one + if n == nil { + child := &node{label: label, tail: segTail, prefix: search} + hn, err := parent.addChild(child, search) + if err != nil { + return nil, err + } + if err := hn.setEndpoint(method, handler, pattern); err != nil { + return nil, err + } + + return hn, nil + } + + // Found an edge to match the pattern + + if n.typ > ntStatic { + // We found a param node, trim the param from the search path and continue. + // This param/wild pattern segment would already be on the tree from a previous + // call to addChild when creating a new node. + search = search[segEndIdx:] + continue + } + + // Static nodes fall below here. + // Determine longest prefix of the search key on match. + commonPrefix := longestPrefix(search, n.prefix) + if commonPrefix == len(n.prefix) { + // the common prefix is as long as the current node's prefix we're attempting to insert. + // keep the search going. + search = search[commonPrefix:] + continue + } + + // Split the node + child := &node{ + typ: ntStatic, + prefix: search[:commonPrefix], + } + if err := parent.replaceChild(search[0], segTail, child); err != nil { + return nil, err + } + + // Restore the existing node + n.label = n.prefix[commonPrefix] + n.prefix = n.prefix[commonPrefix:] + if _, err := child.addChild(n, n.prefix); err != nil { + return nil, err + } + + // If the new key is a subset, set the method/handler on this node and finish. + search = search[commonPrefix:] + if len(search) == 0 { + if err := child.setEndpoint(method, handler, pattern); err != nil { + return nil, err + } + return child, nil + } + + // Create a new edge for the node + subchild := &node{ + typ: ntStatic, + label: search[0], + prefix: search, + } + hn, err := child.addChild(subchild, search) + if err != nil { + return nil, err + } + if err := hn.setEndpoint(method, handler, pattern); err != nil { + return nil, err + } + return hn, nil + } +} + +// addChild appends the new `child` node to the tree using the `pattern` as the trie key. +// For a URL router like chi's, we split the static, param, regexp and wildcard segments +// into different nodes. In addition, addChild will recursively call itself until every +// pattern segment is added to the url pattern tree as individual nodes, depending on type. +func (n *node) addChild(child *node, prefix string) (*node, error) { + search := prefix + + // handler leaf node added to the tree is the child. + // this may be overridden later down the flow + hn := child + + // Parse next segment + segTyp, _, segRexpat, segTail, segStartIdx, segEndIdx, err := patNextSegment(search) + if err != nil { + return nil, err + } + + // Add child depending on next up segment + switch segTyp { + + case ntStatic: + // Search prefix is all static (that is, has no params in path) + // noop + + default: + // Search prefix contains a param, regexp or wildcard + + if segTyp == ntRegexp { + rex, err := regexp.Compile(segRexpat) + if err != nil { + return nil, fmt.Errorf("invalid regexp pattern '%s' in route param", segRexpat) + } + child.prefix = segRexpat + child.rex = rex + } + + if segStartIdx == 0 { + // Route starts with a param + child.typ = segTyp + + if segTyp == ntCatchAll { + segStartIdx = -1 + } else { + segStartIdx = segEndIdx + } + if segStartIdx < 0 { + segStartIdx = len(search) + } + child.tail = segTail // for params, we set the tail + + if segStartIdx != len(search) { + // add static edge for the remaining part, split the end. + // its not possible to have adjacent param nodes, so its certainly + // going to be a static node next. + + search = search[segStartIdx:] // advance search position + + nn := &node{ + typ: ntStatic, + label: search[0], + prefix: search, + } + hn, err = child.addChild(nn, search) + if err != nil { + return nil, err + } + } + + } else if segStartIdx > 0 { + // Route has some param + + // starts with a static segment + child.typ = ntStatic + child.prefix = search[:segStartIdx] + child.rex = nil + + // add the param edge node + search = search[segStartIdx:] + + nn := &node{ + typ: segTyp, + label: search[0], + tail: segTail, + } + hn, err = child.addChild(nn, search) + if err != nil { + return nil, err + } + + } + } + + n.children[child.typ] = append(n.children[child.typ], child) + n.children[child.typ].Sort() + return hn, nil +} + +func (n *node) replaceChild(label, tail byte, child *node) error { + for i := 0; i < len(n.children[child.typ]); i++ { + if n.children[child.typ][i].label == label && n.children[child.typ][i].tail == tail { + n.children[child.typ][i] = child + n.children[child.typ][i].label = label + n.children[child.typ][i].tail = tail + return nil + } + } + return fmt.Errorf("replacing missing child") +} + +func (n *node) getEdge(ntyp nodeTyp, label, tail byte, prefix string) *node { + nds := n.children[ntyp] + for i := 0; i < len(nds); i++ { + if nds[i].label == label && nds[i].tail == tail { + if ntyp == ntRegexp && nds[i].prefix != prefix { + continue + } + return nds[i] + } + } + return nil +} + +func (n *node) setEndpoint(method methodTyp, handler web.Handler, pattern string) error { + // Set the handler for the method type on the node + if n.endpoints == nil { + n.endpoints = make(endpoints) + } + + paramKeys, err := patParamKeys(pattern) + if err != nil { + return err + } + + if method&mSTUB == mSTUB { + n.endpoints.Value(mSTUB).handler = handler + } + if method&mALL == mALL { + h := n.endpoints.Value(mALL) + h.handler = handler + h.pattern = pattern + h.paramKeys = paramKeys + for _, m := range methodMap { + h := n.endpoints.Value(m) + h.handler = handler + h.pattern = pattern + h.paramKeys = paramKeys + } + } else { + h := n.endpoints.Value(method) + h.handler = handler + h.pattern = pattern + h.paramKeys = paramKeys + } + return nil +} + +func (n *node) FindRoute(rctx *Context, method methodTyp, path string) (*node, endpoints, web.Handler) { + // Reset the context routing pattern and params + rctx.routePattern = "" + rctx.routeParams.Keys = rctx.routeParams.Keys[:0] + rctx.routeParams.Values = rctx.routeParams.Values[:0] + + // Find the routing handlers for the path + rn := n.findRoute(rctx, method, path) + if rn == nil { + return nil, nil, nil + } + + // Record the routing params in the request lifecycle + rctx.URLParams.Keys = append(rctx.URLParams.Keys, rctx.routeParams.Keys...) + rctx.URLParams.Values = append(rctx.URLParams.Values, rctx.routeParams.Values...) + + // Record the routing pattern in the request lifecycle + if rn.endpoints[method].pattern != "" { + rctx.routePattern = rn.endpoints[method].pattern + rctx.RoutePatterns = append(rctx.RoutePatterns, rctx.routePattern) + } + + return rn, rn.endpoints, rn.endpoints[method].handler +} + +// Recursive edge traversal by checking all nodeTyp groups along the way. +// It's like searching through a multi-dimensional radix trie. +func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node { + nn := n + search := path + + for t, nds := range nn.children { + ntyp := nodeTyp(t) + if len(nds) == 0 { + continue + } + + var xn *node + xsearch := search + + var label byte + if search != "" { + label = search[0] + } + + switch ntyp { + case ntStatic: + xn = nds.findEdge(label) + if xn == nil || !strings.HasPrefix(xsearch, xn.prefix) { + continue + } + xsearch = xsearch[len(xn.prefix):] + + case ntParam, ntRegexp: + // short-circuit and return no matching route for empty param values + if xsearch == "" { + continue + } + + // serially loop through each node grouped by the tail delimiter + for idx := 0; idx < len(nds); idx++ { + xn = nds[idx] + + // label for param nodes is the delimiter byte + p := strings.IndexByte(xsearch, xn.tail) + + if p < 0 { + if xn.tail == '/' { + p = len(xsearch) + } else { + continue + } + } else if ntyp == ntRegexp && p == 0 { + continue + } + + if ntyp == ntRegexp && xn.rex != nil { + if !xn.rex.MatchString(xsearch[:p]) { + continue + } + } else if strings.IndexByte(xsearch[:p], '/') != -1 { + // avoid a match across path segments + continue + } + + prevlen := len(rctx.routeParams.Values) + rctx.routeParams.Values = append(rctx.routeParams.Values, xsearch[:p]) + xsearch = xsearch[p:] + + if len(xsearch) == 0 { + if xn.isLeaf() { + h := xn.endpoints[method] + if h != nil && h.handler != nil { + rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...) + return xn + } + + for endpoints := range xn.endpoints { + if endpoints == mALL || endpoints == mSTUB { + continue + } + rctx.methodsAllowed = append(rctx.methodsAllowed, endpoints) + } + + // flag that the routing context found a route, but not a corresponding + // supported method + rctx.methodNotAllowed = true + } + } + + // recursively find the next node on this branch + fin := xn.findRoute(rctx, method, xsearch) + if fin != nil { + return fin + } + + // not found on this branch, reset vars + rctx.routeParams.Values = rctx.routeParams.Values[:prevlen] + xsearch = search + } + + rctx.routeParams.Values = append(rctx.routeParams.Values, "") + + default: + // catch-all nodes + rctx.routeParams.Values = append(rctx.routeParams.Values, search) + xn = nds[0] + xsearch = "" + } + + if xn == nil { + continue + } + + // did we find it yet? + if len(xsearch) == 0 { + if xn.isLeaf() { + h := xn.endpoints[method] + if h != nil && h.handler != nil { + rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...) + return xn + } + + for endpoints := range xn.endpoints { + if endpoints == mALL || endpoints == mSTUB { + continue + } + rctx.methodsAllowed = append(rctx.methodsAllowed, endpoints) + } + + // flag that the routing context found a route, but not a corresponding + // supported method + rctx.methodNotAllowed = true + } + } + + // recursively find the next node.. + fin := xn.findRoute(rctx, method, xsearch) + if fin != nil { + return fin + } + + // Did not find final handler, let's remove the param here if it was set + if xn.typ > ntStatic { + if len(rctx.routeParams.Values) > 0 { + rctx.routeParams.Values = rctx.routeParams.Values[:len(rctx.routeParams.Values)-1] + } + } + + } + + return nil +} + +func (n *node) findEdge(ntyp nodeTyp, label byte) *node { + nds := n.children[ntyp] + num := len(nds) + idx := 0 + + switch ntyp { + case ntStatic, ntParam, ntRegexp: + i, j := 0, num-1 + for i <= j { + idx = i + (j-i)/2 + if label > nds[idx].label { + i = idx + 1 + } else if label < nds[idx].label { + j = idx - 1 + } else { + i = num // breaks cond + } + } + if nds[idx].label != label { + return nil + } + return nds[idx] + + default: // catch all + return nds[idx] + } +} + +func (n *node) isLeaf() bool { + return n.endpoints != nil +} + +func (n *node) findPattern(pattern string) (bool, error) { + nn := n + for _, nds := range nn.children { + if len(nds) == 0 { + continue + } + + n = nn.findEdge(nds[0].typ, pattern[0]) + if n == nil { + continue + } + + var idx int + var xpattern string + + switch n.typ { + case ntStatic: + idx = longestPrefix(pattern, n.prefix) + if idx < len(n.prefix) { + continue + } + + case ntParam, ntRegexp: + idx = strings.IndexByte(pattern, '}') + 1 + + case ntCatchAll: + idx = longestPrefix(pattern, "*") + + default: + return false, fmt.Errorf("unknown node type") + } + + xpattern = pattern[idx:] + if len(xpattern) == 0 { + return true, fmt.Errorf("unknown node type") + } + + return n.findPattern(xpattern) + } + return false, fmt.Errorf("unknown node type") +} + +func (n *node) routes() []Route { + rts := []Route{} + + n.walk(func(eps endpoints, subroutes Routes) bool { + if eps[mSTUB] != nil && eps[mSTUB].handler != nil && subroutes == nil { + return false + } + + // Group methodHandlers by unique patterns + pats := make(map[string]endpoints) + + for mt, h := range eps { + if h.pattern == "" { + continue + } + p, ok := pats[h.pattern] + if !ok { + p = endpoints{} + pats[h.pattern] = p + } + p[mt] = h + } + + for p, mh := range pats { + hs := make(map[string]web.Handler) + if mh[mALL] != nil && mh[mALL].handler != nil { + hs["*"] = mh[mALL].handler + } + + for mt, h := range mh { + if h.handler == nil { + continue + } + m := methodTypString(mt) + if m == "" { + continue + } + hs[m] = h.handler + } + + rt := Route{subroutes, hs, p} + rts = append(rts, rt) + } + + return false + }) + + return rts +} + +func (n *node) walk(fn func(eps endpoints, subroutes Routes) bool) bool { + // Visit the leaf values if any + if (n.endpoints != nil || n.subroutes != nil) && fn(n.endpoints, n.subroutes) { + return true + } + + // Recurse on the children + for _, ns := range n.children { + for _, cn := range ns { + if cn.walk(fn) { + return true + } + } + } + return false +} + +// patNextSegment returns the next segment details from a pattern: +// node type, param key, regexp string, param tail byte, param starting index, param ending index +func patNextSegment(pattern string) (nodeTyp, string, string, byte, int, int, error) { + ps := strings.Index(pattern, "{") + ws := strings.Index(pattern, "*") + + if ps < 0 && ws < 0 { + return ntStatic, "", "", 0, 0, len(pattern), nil // we return the entire thing + } + + // Sanity check + if ps >= 0 && ws >= 0 && ws < ps { + return ntStatic, "", "", 0, 0, 0, fmt.Errorf("wildcard '*' must be the last pattern in a route, otherwise use a '{param}'") + } + + var tail byte = '/' // Default endpoint tail to / byte + + if ps >= 0 { + // Param/Regexp pattern is next + nt := ntParam + + // Read to closing } taking into account opens and closes in curl count (cc) + cc := 0 + pe := ps + for i, c := range pattern[ps:] { + if c == '{' { + cc++ + } else if c == '}' { + cc-- + if cc == 0 { + pe = ps + i + break + } + } + } + if pe == ps { + return ntStatic, "", "", 0, 0, 0, fmt.Errorf("route param closing delimiter '}' is missing") + } + + key := pattern[ps+1 : pe] + pe++ // set end to next position + + if pe < len(pattern) { + tail = pattern[pe] + } + + var rexpat string + if idx := strings.Index(key, ":"); idx >= 0 { + nt = ntRegexp + rexpat = key[idx+1:] + key = key[:idx] + } + + if len(rexpat) > 0 { + if rexpat[0] != '^' { + rexpat = "^" + rexpat + } + if rexpat[len(rexpat)-1] != '$' { + rexpat += "$" + } + } + + return nt, key, rexpat, tail, ps, pe, nil + } + + // Wildcard pattern as finale + if ws < len(pattern)-1 { + return ntStatic, "", "", 0, 0, 0, fmt.Errorf("wildcard '*' must be the last value in a route. trim trailing text or use a '{param}' instead") + } + return ntCatchAll, "*", "", 0, ws, len(pattern), nil +} + +func patParamKeys(pattern string) ([]string, error) { + pat := pattern + paramKeys := []string{} + for { + ptyp, paramKey, _, _, _, e, err := patNextSegment(pat) + if err != nil { + return nil, err + } + if ptyp == ntStatic { + return paramKeys, nil + } + for i := 0; i < len(paramKeys); i++ { + if paramKeys[i] == paramKey { + return nil, fmt.Errorf("routing pattern '%s' contains duplicate param key, '%s'", pattern, paramKey) + } + } + paramKeys = append(paramKeys, paramKey) + pat = pat[e:] + } +} + +// longestPrefix finds the length of the shared prefix +// of two strings +func longestPrefix(k1, k2 string) int { + maxLen := len(k1) + if l := len(k2); l < maxLen { + maxLen = l + } + var i int + for i = 0; i < maxLen; i++ { + if k1[i] != k2[i] { + break + } + } + return i +} + +func methodTypString(method methodTyp) string { + for s, t := range methodMap { + if method == t { + return s + } + } + return "" +} + +type nodes []*node + +// Sort the list of nodes by label +func (ns nodes) Sort() { sort.Sort(ns); ns.tailSort() } +func (ns nodes) Len() int { return len(ns) } +func (ns nodes) Swap(i, j int) { ns[i], ns[j] = ns[j], ns[i] } +func (ns nodes) Less(i, j int) bool { return ns[i].label < ns[j].label } + +// tailSort pushes nodes with '/' as the tail to the end of the list for param nodes. +// The list order determines the traversal order. +func (ns nodes) tailSort() { + for i := len(ns) - 1; i >= 0; i-- { + if ns[i].typ > ntStatic && ns[i].tail == '/' { + ns.Swap(i, len(ns)-1) + return + } + } +} + +func (ns nodes) findEdge(label byte) *node { + num := len(ns) + idx := 0 + i, j := 0, num-1 + for i <= j { + idx = i + (j-i)/2 + if label > ns[idx].label { + i = idx + 1 + } else if label < ns[idx].label { + j = idx - 1 + } else { + i = num // breaks cond + } + } + if ns[idx].label != label { + return nil + } + return ns[idx] +} + +// Route describes the details of a routing handler. +// Handlers map key is an HTTP method +type Route struct { + SubRoutes Routes + Handlers map[string]web.Handler + Pattern string +} + +// WalkFunc is the type of the function called for each method and route visited by Walk. +type WalkFunc func(method string, route string, handler web.Handler, middlewares ...func(web.Handler) web.Handler) error + +// Walk walks any router tree that implements Routes interface. +func Walk(r Routes, walkFn WalkFunc) error { + return walk(r, walkFn, "") +} + +func walk(r Routes, walkFn WalkFunc, parentRoute string, parentMw ...func(web.Handler) web.Handler) error { + for _, route := range r.Routes() { + mws := make([]func(web.Handler) web.Handler, len(parentMw)) + copy(mws, parentMw) + + if route.SubRoutes != nil { + if err := walk(route.SubRoutes, walkFn, parentRoute+route.Pattern, mws...); err != nil { + return err + } + continue + } + + for method, handler := range route.Handlers { + if method == "*" { + // Ignore a "catchAll" method, since we pass down all the specific methods for each route. + continue + } + + fullRoute := parentRoute + route.Pattern + fullRoute = strings.Replace(fullRoute, "/*/", "/", -1) + + if err := walkFn(method, fullRoute, handler, mws...); err != nil { + return err + } + } + } + + return nil +} diff --git a/internal/platform/chi/tree_test.go b/internal/platform/chi/tree_test.go new file mode 100644 index 0000000..53e02e9 --- /dev/null +++ b/internal/platform/chi/tree_test.go @@ -0,0 +1,643 @@ +package chi + +import ( + "fmt" + "log" + "testing" + + "github.com/valyala/fasthttp" + "github.com/wallarm/api-firewall/internal/platform/web" +) + +func TestTree(t *testing.T) { + hStub := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hIndex := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hFavicon := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hArticleList := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hArticleNear := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hArticleShow := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hArticleShowRelated := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hArticleShowOpts := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hArticleSlug := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hArticleByUser := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hUserList := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hUserShow := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hAdminCatchall := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hAdminAppShow := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hAdminAppShowCatchall := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hUserProfile := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hUserSuper := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hUserAll := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hHubView1 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hHubView2 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hHubView3 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + + tr := &node{} + + if _, err := tr.InsertRoute(mGET, "/", hIndex); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/favicon.ico", hFavicon); err != nil { + t.Fatal(err) + } + + if _, err := tr.InsertRoute(mGET, "/pages/*", hStub); err != nil { + t.Fatal(err) + } + + if _, err := tr.InsertRoute(mGET, "/article", hArticleList); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/article/", hArticleList); err != nil { + t.Fatal(err) + } + + if _, err := tr.InsertRoute(mGET, "/article/near", hArticleNear); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/article/{id}", hStub); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/article/{id}", hArticleShow); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/article/{id}", hArticleShow); err != nil { + t.Fatal(err) + } // duplicate will have no effect + + if _, err := tr.InsertRoute(mGET, "/article/@{user}", hArticleByUser); err != nil { + t.Fatal(err) + } + + if _, err := tr.InsertRoute(mGET, "/article/{sup}/{opts}", hArticleShowOpts); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/article/{id}/{opts}", hArticleShowOpts); err != nil { + t.Fatal(err) + } // overwrite above route, latest wins + + if _, err := tr.InsertRoute(mGET, "/article/{iffd}/edit", hStub); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/article/{id}//related", hArticleShowRelated); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/article/slug/{month}/-/{day}/{year}", hArticleSlug); err != nil { + t.Fatal(err) + } + + if _, err := tr.InsertRoute(mGET, "/admin/user", hUserList); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/admin/user/", hStub); err != nil { + t.Fatal(err) + } // will get replaced by next route + + if _, err := tr.InsertRoute(mGET, "/admin/user/", hUserList); err != nil { + t.Fatal(err) + } + + if _, err := tr.InsertRoute(mGET, "/admin/user//{id}", hUserShow); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/admin/user/{id}", hUserShow); err != nil { + t.Fatal(err) + } + + if _, err := tr.InsertRoute(mGET, "/admin/apps/{id}", hAdminAppShow); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/admin/apps/{id}/*", hAdminAppShowCatchall); err != nil { + t.Fatal(err) + } + + if _, err := tr.InsertRoute(mGET, "/admin/*", hStub); err != nil { + t.Fatal(err) + } // catchall segment will get replaced by next route + + if _, err := tr.InsertRoute(mGET, "/admin/*", hAdminCatchall); err != nil { + t.Fatal(err) + } + + if _, err := tr.InsertRoute(mGET, "/users/{userID}/profile", hUserProfile); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/users/super/*", hUserSuper); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/users/*", hUserAll); err != nil { + t.Fatal(err) + } + + if _, err := tr.InsertRoute(mGET, "/hubs/{hubID}/view", hHubView1); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/hubs/{hubID}/view/*", hHubView2); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/hubs/{hubID}/users", hHubView3); err != nil { + t.Fatal(err) + } + + tests := []struct { + r string // input request path + h web.Handler // output matched handler + k []string // output param keys + v []string // output param values + }{ + {r: "/", h: hIndex, k: []string{}, v: []string{}}, + {r: "/favicon.ico", h: hFavicon, k: []string{}, v: []string{}}, + + {r: "/pages", h: nil, k: []string{}, v: []string{}}, + {r: "/pages/", h: hStub, k: []string{"*"}, v: []string{""}}, + {r: "/pages/yes", h: hStub, k: []string{"*"}, v: []string{"yes"}}, + + {r: "/article", h: hArticleList, k: []string{}, v: []string{}}, + {r: "/article/", h: hArticleList, k: []string{}, v: []string{}}, + {r: "/article/near", h: hArticleNear, k: []string{}, v: []string{}}, + {r: "/article/neard", h: hArticleShow, k: []string{"id"}, v: []string{"neard"}}, + {r: "/article/123", h: hArticleShow, k: []string{"id"}, v: []string{"123"}}, + {r: "/article/123/456", h: hArticleShowOpts, k: []string{"id", "opts"}, v: []string{"123", "456"}}, + {r: "/article/@peter", h: hArticleByUser, k: []string{"user"}, v: []string{"peter"}}, + {r: "/article/22//related", h: hArticleShowRelated, k: []string{"id"}, v: []string{"22"}}, + {r: "/article/111/edit", h: hStub, k: []string{"iffd"}, v: []string{"111"}}, + {r: "/article/slug/sept/-/4/2015", h: hArticleSlug, k: []string{"month", "day", "year"}, v: []string{"sept", "4", "2015"}}, + {r: "/article/:id", h: hArticleShow, k: []string{"id"}, v: []string{":id"}}, + + {r: "/admin/user", h: hUserList, k: []string{}, v: []string{}}, + {r: "/admin/user/", h: hUserList, k: []string{}, v: []string{}}, + {r: "/admin/user/1", h: hUserShow, k: []string{"id"}, v: []string{"1"}}, + {r: "/admin/user//1", h: hUserShow, k: []string{"id"}, v: []string{"1"}}, + {r: "/admin/hi", h: hAdminCatchall, k: []string{"*"}, v: []string{"hi"}}, + {r: "/admin/lots/of/:fun", h: hAdminCatchall, k: []string{"*"}, v: []string{"lots/of/:fun"}}, + {r: "/admin/apps/333", h: hAdminAppShow, k: []string{"id"}, v: []string{"333"}}, + {r: "/admin/apps/333/woot", h: hAdminAppShowCatchall, k: []string{"id", "*"}, v: []string{"333", "woot"}}, + + {r: "/hubs/123/view", h: hHubView1, k: []string{"hubID"}, v: []string{"123"}}, + {r: "/hubs/123/view/index.html", h: hHubView2, k: []string{"hubID", "*"}, v: []string{"123", "index.html"}}, + {r: "/hubs/123/users", h: hHubView3, k: []string{"hubID"}, v: []string{"123"}}, + + {r: "/users/123/profile", h: hUserProfile, k: []string{"userID"}, v: []string{"123"}}, + {r: "/users/super/123/okay/yes", h: hUserSuper, k: []string{"*"}, v: []string{"123/okay/yes"}}, + {r: "/users/123/okay/yes", h: hUserAll, k: []string{"*"}, v: []string{"123/okay/yes"}}, + } + + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + // debugPrintTree(0, 0, tr, 0) + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + + for i, tt := range tests { + rctx := NewRouteContext() + + _, handlers, _ := tr.FindRoute(rctx, mGET, tt.r) + + var handler web.Handler + if methodHandler, ok := handlers[mGET]; ok { + handler = methodHandler.handler + } + + paramKeys := rctx.routeParams.Keys + paramValues := rctx.routeParams.Values + + if fmt.Sprintf("%v", tt.h) != fmt.Sprintf("%v", handler) { + t.Errorf("input [%d]: find '%s' expecting handler:%v , got:%v", i, tt.r, tt.h, handler) + } + if !stringSliceEqual(tt.k, paramKeys) { + t.Errorf("input [%d]: find '%s' expecting paramKeys:(%d)%v , got:(%d)%v", i, tt.r, len(tt.k), tt.k, len(paramKeys), paramKeys) + } + if !stringSliceEqual(tt.v, paramValues) { + t.Errorf("input [%d]: find '%s' expecting paramValues:(%d)%v , got:(%d)%v", i, tt.r, len(tt.v), tt.v, len(paramValues), paramValues) + } + } +} + +func TestTreeMoar(t *testing.T) { + hStub := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub1 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub2 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub3 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub4 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub5 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub6 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub7 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub8 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub9 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub10 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub11 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub12 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub13 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub14 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub15 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub16 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + + // TODO: panic if we see {id}{x} because we're missing a delimiter, its not possible. + // also {:id}* is not possible. + + tr := &node{} + + if _, err := tr.InsertRoute(mGET, "/articlefun", hStub5); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/{id}", hStub); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mDELETE, "/articles/{slug}", hStub8); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/search", hStub1); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/{id}:delete", hStub8); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/{iidd}!sup", hStub4); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/{id}:{op}", hStub3); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/{id}:{op}", hStub2); err != nil { + t.Fatal(err) // this route sets a new handler for the above route + } + if _, err := tr.InsertRoute(mGET, "/articles/{slug:^[a-z]+}/posts", hStub); err != nil { // up to tail '/' will only match if contents match the rex + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/{id}/posts/{pid}", hStub6); err != nil { // /articles/123/posts/1 + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/{id}/posts/{month}/{day}/{year}/{slug}", hStub7); err != nil { // /articles/123/posts/09/04/1984/juice + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/{id}.json", hStub10); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/{id}/data.json", hStub11); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/files/{file}.{ext}", hStub12); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mPUT, "/articles/me", hStub13); err != nil { + t.Fatal(err) + } + + // TODO: make a separate test case for this one.. + // tr.InsertRoute(mGET, "/articles/{id}/{id}", hStub1) // panic expected, we're duplicating param keys + + tr.InsertRoute(mGET, "/pages/*", hStub) + tr.InsertRoute(mGET, "/pages/*", hStub9) + + tr.InsertRoute(mGET, "/users/{id}", hStub14) + tr.InsertRoute(mGET, "/users/{id}/settings/{key}", hStub15) + tr.InsertRoute(mGET, "/users/{id}/settings/*", hStub16) + + tests := []struct { + h web.Handler + r string + k []string + v []string + m methodTyp + }{ + {m: mGET, r: "/articles/search", h: hStub1, k: []string{}, v: []string{}}, + {m: mGET, r: "/articlefun", h: hStub5, k: []string{}, v: []string{}}, + {m: mGET, r: "/articles/123", h: hStub, k: []string{"id"}, v: []string{"123"}}, + {m: mDELETE, r: "/articles/123mm", h: hStub8, k: []string{"slug"}, v: []string{"123mm"}}, + {m: mGET, r: "/articles/789:delete", h: hStub8, k: []string{"id"}, v: []string{"789"}}, + {m: mGET, r: "/articles/789!sup", h: hStub4, k: []string{"iidd"}, v: []string{"789"}}, + {m: mGET, r: "/articles/123:sync", h: hStub2, k: []string{"id", "op"}, v: []string{"123", "sync"}}, + {m: mGET, r: "/articles/456/posts/1", h: hStub6, k: []string{"id", "pid"}, v: []string{"456", "1"}}, + {m: mGET, r: "/articles/456/posts/09/04/1984/juice", h: hStub7, k: []string{"id", "month", "day", "year", "slug"}, v: []string{"456", "09", "04", "1984", "juice"}}, + {m: mGET, r: "/articles/456.json", h: hStub10, k: []string{"id"}, v: []string{"456"}}, + {m: mGET, r: "/articles/456/data.json", h: hStub11, k: []string{"id"}, v: []string{"456"}}, + + {m: mGET, r: "/articles/files/file.zip", h: hStub12, k: []string{"file", "ext"}, v: []string{"file", "zip"}}, + {m: mGET, r: "/articles/files/photos.tar.gz", h: hStub12, k: []string{"file", "ext"}, v: []string{"photos", "tar.gz"}}, + {m: mGET, r: "/articles/files/photos.tar.gz", h: hStub12, k: []string{"file", "ext"}, v: []string{"photos", "tar.gz"}}, + + {m: mPUT, r: "/articles/me", h: hStub13, k: []string{}, v: []string{}}, + {m: mGET, r: "/articles/me", h: hStub, k: []string{"id"}, v: []string{"me"}}, + {m: mGET, r: "/pages", h: nil, k: []string{}, v: []string{}}, + {m: mGET, r: "/pages/", h: hStub9, k: []string{"*"}, v: []string{""}}, + {m: mGET, r: "/pages/yes", h: hStub9, k: []string{"*"}, v: []string{"yes"}}, + + {m: mGET, r: "/users/1", h: hStub14, k: []string{"id"}, v: []string{"1"}}, + {m: mGET, r: "/users/", h: nil, k: []string{}, v: []string{}}, + {m: mGET, r: "/users/2/settings/password", h: hStub15, k: []string{"id", "key"}, v: []string{"2", "password"}}, + {m: mGET, r: "/users/2/settings/", h: hStub16, k: []string{"id", "*"}, v: []string{"2", ""}}, + } + + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + // debugPrintTree(0, 0, tr, 0) + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + + for i, tt := range tests { + rctx := NewRouteContext() + + _, handlers, _ := tr.FindRoute(rctx, tt.m, tt.r) + + var handler web.Handler + if methodHandler, ok := handlers[tt.m]; ok { + handler = methodHandler.handler + } + + paramKeys := rctx.routeParams.Keys + paramValues := rctx.routeParams.Values + + if fmt.Sprintf("%v", tt.h) != fmt.Sprintf("%v", handler) { + t.Errorf("input [%d]: find '%s' expecting handler:%v , got:%v", i, tt.r, tt.h, handler) + } + if !stringSliceEqual(tt.k, paramKeys) { + t.Errorf("input [%d]: find '%s' expecting paramKeys:(%d)%v , got:(%d)%v", i, tt.r, len(tt.k), tt.k, len(paramKeys), paramKeys) + } + if !stringSliceEqual(tt.v, paramValues) { + t.Errorf("input [%d]: find '%s' expecting paramValues:(%d)%v , got:(%d)%v", i, tt.r, len(tt.v), tt.v, len(paramValues), paramValues) + } + } +} + +func TestTreeRegexp(t *testing.T) { + hStub1 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub2 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub3 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub4 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub5 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub6 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub7 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + + tr := &node{} + if _, err := tr.InsertRoute(mGET, "/articles/{rid:^[0-9]{5,6}}", hStub7); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/{zid:^0[0-9]+}", hStub3); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/{name:^@[a-z]+}/posts", hStub4); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/{op:^[0-9]+}/run", hStub5); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/{id:^[0-9]+}", hStub1); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/{id:^[1-9]+}-{aux}", hStub6); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/{slug}", hStub2); err != nil { + t.Fatal(err) + } + + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + // debugPrintTree(0, 0, tr, 0) + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + + tests := []struct { + r string // input request path + h web.Handler // output matched handler + k []string // output param keys + v []string // output param values + }{ + {r: "/articles", h: nil, k: []string{}, v: []string{}}, + {r: "/articles/12345", h: hStub7, k: []string{"rid"}, v: []string{"12345"}}, + {r: "/articles/123", h: hStub1, k: []string{"id"}, v: []string{"123"}}, + {r: "/articles/how-to-build-a-router", h: hStub2, k: []string{"slug"}, v: []string{"how-to-build-a-router"}}, + {r: "/articles/0456", h: hStub3, k: []string{"zid"}, v: []string{"0456"}}, + {r: "/articles/@pk/posts", h: hStub4, k: []string{"name"}, v: []string{"@pk"}}, + {r: "/articles/1/run", h: hStub5, k: []string{"op"}, v: []string{"1"}}, + {r: "/articles/1122", h: hStub1, k: []string{"id"}, v: []string{"1122"}}, + {r: "/articles/1122-yes", h: hStub6, k: []string{"id", "aux"}, v: []string{"1122", "yes"}}, + } + + for i, tt := range tests { + rctx := NewRouteContext() + + _, handlers, _ := tr.FindRoute(rctx, mGET, tt.r) + + var handler web.Handler + if methodHandler, ok := handlers[mGET]; ok { + handler = methodHandler.handler + } + + paramKeys := rctx.routeParams.Keys + paramValues := rctx.routeParams.Values + + if fmt.Sprintf("%v", tt.h) != fmt.Sprintf("%v", handler) { + t.Errorf("input [%d]: find '%s' expecting handler:%v , got:%v", i, tt.r, tt.h, handler) + } + if !stringSliceEqual(tt.k, paramKeys) { + t.Errorf("input [%d]: find '%s' expecting paramKeys:(%d)%v , got:(%d)%v", i, tt.r, len(tt.k), tt.k, len(paramKeys), paramKeys) + } + if !stringSliceEqual(tt.v, paramValues) { + t.Errorf("input [%d]: find '%s' expecting paramValues:(%d)%v , got:(%d)%v", i, tt.r, len(tt.v), tt.v, len(paramValues), paramValues) + } + } +} + +func TestTreeRegexpRecursive(t *testing.T) { + hStub1 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub2 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + + tr := &node{} + if _, err := tr.InsertRoute(mGET, "/one/{firstId:[a-z0-9-]+}/{secondId:[a-z0-9-]+}/first", hStub1); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/one/{firstId:[a-z0-9-_]+}/{secondId:[a-z0-9-_]+}/second", hStub2); err != nil { + t.Fatal(err) + } + + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + // debugPrintTree(0, 0, tr, 0) + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + + tests := []struct { + r string // input request path + h web.Handler // output matched handler + k []string // output param keys + v []string // output param values + }{ + {r: "/one/hello/world/first", h: hStub1, k: []string{"firstId", "secondId"}, v: []string{"hello", "world"}}, + {r: "/one/hi_there/ok/second", h: hStub2, k: []string{"firstId", "secondId"}, v: []string{"hi_there", "ok"}}, + {r: "/one///first", h: nil, k: []string{}, v: []string{}}, + {r: "/one/hi/123/second", h: hStub2, k: []string{"firstId", "secondId"}, v: []string{"hi", "123"}}, + } + + for i, tt := range tests { + rctx := NewRouteContext() + + _, handlers, _ := tr.FindRoute(rctx, mGET, tt.r) + + var handler web.Handler + if methodHandler, ok := handlers[mGET]; ok { + handler = methodHandler.handler + } + + paramKeys := rctx.routeParams.Keys + paramValues := rctx.routeParams.Values + + if fmt.Sprintf("%v", tt.h) != fmt.Sprintf("%v", handler) { + t.Errorf("input [%d]: find '%s' expecting handler:%v , got:%v", i, tt.r, tt.h, handler) + } + if !stringSliceEqual(tt.k, paramKeys) { + t.Errorf("input [%d]: find '%s' expecting paramKeys:(%d)%v , got:(%d)%v", i, tt.r, len(tt.k), tt.k, len(paramKeys), paramKeys) + } + if !stringSliceEqual(tt.v, paramValues) { + t.Errorf("input [%d]: find '%s' expecting paramValues:(%d)%v , got:(%d)%v", i, tt.r, len(tt.v), tt.v, len(paramValues), paramValues) + } + } +} + +func TestTreeRegexMatchWholeParam(t *testing.T) { + hStub1 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + + rctx := NewRouteContext() + tr := &node{} + if _, err := tr.InsertRoute(mGET, "/{id:[0-9]+}", hStub1); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/{x:.+}/foo", hStub1); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/{param:[0-9]*}/test", hStub1); err != nil { + t.Fatal(err) + } + + tests := []struct { + expectedHandler web.Handler + url string + }{ + {url: "/13", expectedHandler: hStub1}, + {url: "/a13", expectedHandler: nil}, + {url: "/13.jpg", expectedHandler: nil}, + {url: "/a13.jpg", expectedHandler: nil}, + {url: "/a/foo", expectedHandler: hStub1}, + {url: "//foo", expectedHandler: nil}, + {url: "//test", expectedHandler: hStub1}, + } + + for _, tc := range tests { + _, _, handler := tr.FindRoute(rctx, mGET, tc.url) + if fmt.Sprintf("%v", tc.expectedHandler) != fmt.Sprintf("%v", handler) { + t.Errorf("url %v: expecting handler:%v , got:%v", tc.url, tc.expectedHandler, handler) + } + } +} + +func TestTreeFindPattern(t *testing.T) { + hStub1 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub2 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub3 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + + tr := &node{} + if _, err := tr.InsertRoute(mGET, "/pages/*", hStub1); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/{id}/*", hStub2); err != nil { + t.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/articles/{slug}/{uid}/*", hStub3); err != nil { + t.Fatal(err) + } + + if ok, err := tr.findPattern("/pages"); ok != false { + t.Errorf("find /pages failed: %v", err) + } + if ok, err := tr.findPattern("/pages*"); ok != false { + t.Errorf("find /pages* failed - should be nil: %v", err) + } + if ok, err := tr.findPattern("/pages/*"); ok == false { + t.Errorf("find /pages/* failed: %v", err) + } + if ok, err := tr.findPattern("/articles/{id}/*"); ok == false { + t.Errorf("find /articles/{id}/* failed: %v", err) + } + if ok, err := tr.findPattern("/articles/{something}/*"); ok == false { + t.Errorf("find /articles/{something}/* failed: %v", err) + } + if ok, err := tr.findPattern("/articles/{slug}/{uid}/*"); ok == false { + t.Errorf("find /articles/{slug}/{uid}/* failed: %v", err) + } +} + +func debugPrintTree(parent int, i int, n *node, label byte) bool { + numEdges := 0 + for _, nds := range n.children { + numEdges += len(nds) + } + + if n.endpoints != nil { + log.Printf("[node %d parent:%d] typ:%d prefix:%s label:%s tail:%s numEdges:%d isLeaf:%v handler:%v\n", i, parent, n.typ, n.prefix, string(label), string(n.tail), numEdges, n.isLeaf(), n.endpoints) + } else { + log.Printf("[node %d parent:%d] typ:%d prefix:%s label:%s tail:%s numEdges:%d isLeaf:%v\n", i, parent, n.typ, n.prefix, string(label), string(n.tail), numEdges, n.isLeaf()) + } + parent = i + for _, nds := range n.children { + for _, e := range nds { + i++ + if debugPrintTree(parent, i, e, e.label) { + return true + } + } + } + return false +} + +func stringSliceEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if b[i] != a[i] { + return false + } + } + return true +} + +func BenchmarkTreeGet(b *testing.B) { + h1 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + h2 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + + tr := &node{} + if _, err := tr.InsertRoute(mGET, "/", h1); err != nil { + b.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/ping", h2); err != nil { + b.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/pingall", h2); err != nil { + b.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/ping/{id}", h2); err != nil { + b.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/ping/{id}/woop", h2); err != nil { + b.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/ping/{id}/{opt}", h2); err != nil { + b.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/pinggggg", h2); err != nil { + b.Fatal(err) + } + if _, err := tr.InsertRoute(mGET, "/hello", h1); err != nil { + b.Fatal(err) + } + + mctx := NewRouteContext() + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + mctx.Reset() + tr.FindRoute(mctx, mGET, "/ping/123/456") + } +} diff --git a/internal/platform/database/v1.go b/internal/platform/database/v1.go index 5951fc8..702f984 100644 --- a/internal/platform/database/v1.go +++ b/internal/platform/database/v1.go @@ -116,7 +116,10 @@ func (s *SQLLiteV1) Load(dbStoragePath string) error { continue } - if err := parsedSpec.Validate(loader.Context); err != nil { + if err := parsedSpec.Validate( + loader.Context, + openapi3.DisableExamplesValidation(), + ); err != nil { s.Log.Errorf("error: validation of the OpenAPI specification %s (schema ID %d): %v", spec.SchemaVersion, schemaID, err) delete(s.RawSpecs, schemaID) continue diff --git a/internal/platform/database/v2.go b/internal/platform/database/v2.go index d0368dd..f681401 100644 --- a/internal/platform/database/v2.go +++ b/internal/platform/database/v2.go @@ -114,7 +114,10 @@ func (s *SQLLiteV2) Load(dbStoragePath string) error { continue } - if err := parsedSpec.Validate(loader.Context); err != nil { + if err := parsedSpec.Validate( + loader.Context, + openapi3.DisableExamplesValidation(), + ); err != nil { s.Log.Errorf("error: validation of the OpenAPI specification %s (schema ID %d): %v", spec.SchemaVersion, schemaID, err) delete(s.RawSpecs, schemaID) continue diff --git a/internal/platform/router/router.go b/internal/platform/router/router.go index 21150b3..ab981eb 100644 --- a/internal/platform/router/router.go +++ b/internal/platform/router/router.go @@ -28,9 +28,13 @@ type CustomRoute struct { // If the given Swagger has servers, router will use them. // All operations of the Swagger will be added to the router. func NewRouter(doc *openapi3.T) (*Router, error) { - if err := doc.Validate(context.Background()); err != nil { + if err := doc.Validate( + context.Background(), + openapi3.DisableExamplesValidation(), + ); err != nil { return nil, fmt.Errorf("validating OpenAPI failed: %v", err) } + var router Router for path, pathItem := range doc.Paths { diff --git a/internal/platform/web/apiMode.go b/internal/platform/web/apiMode.go new file mode 100644 index 0000000..abfbd17 --- /dev/null +++ b/internal/platform/web/apiMode.go @@ -0,0 +1,36 @@ +package web + +const ( + APIModePostfixStatusCode = "_status_code" + APIModePostfixValidationErrors = "_validation_errors" + + GlobalResponseStatusCodeKey = "global_response_status_code" + + RequestSchemaID = "__wallarm_apifw_request_schema_id" +) + +type FieldTypeError struct { + Name string `json:"name"` + ExpectedType string `json:"expected_type,omitempty"` + Pattern string `json:"pattern,omitempty"` + CurrentValue string `json:"current_value,omitempty"` +} + +type ValidationError struct { + Message string `json:"message"` + Code string `json:"code"` + SchemaVersion string `json:"schema_version,omitempty"` + SchemaID *int `json:"schema_id"` + Fields []string `json:"related_fields,omitempty"` + FieldsDetails []FieldTypeError `json:"related_fields_details,omitempty"` +} + +type APIModeResponseSummary struct { + SchemaID *int `json:"schema_id"` + StatusCode *int `json:"status_code"` +} + +type APIModeResponse struct { + Summary []*APIModeResponseSummary `json:"summary"` + Errors []*ValidationError `json:"errors,omitempty"` +} diff --git a/internal/platform/web/middleware.go b/internal/platform/web/middleware.go index 398337b..2490869 100644 --- a/internal/platform/web/middleware.go +++ b/internal/platform/web/middleware.go @@ -5,10 +5,10 @@ package web // direct to any given Handler. type Middleware func(Handler) Handler -// wrapMiddleware creates a new handler by wrapping middleware around a final +// WrapMiddleware creates a new handler by wrapping middleware around a final // handler. The middlewares' Handlers will be executed by requests in the order // they are provided. -func wrapMiddleware(mw []Middleware, handler Handler) Handler { +func WrapMiddleware(mw []Middleware, handler Handler) Handler { // Loop backwards through the middleware invoking each one. Replace the // handler with the new wrapped handler. Looping backwards ensures that the diff --git a/internal/platform/web/web.go b/internal/platform/web/web.go index 475d532..09d9a5e 100644 --- a/internal/platform/web/web.go +++ b/internal/platform/web/web.go @@ -64,10 +64,10 @@ type AppAdditionalOptions struct { func (a *App) SetDefaultBehavior(handler Handler, mw ...Middleware) { // First wrap handler specific middleware around this handler. - handler = wrapMiddleware(mw, handler) + handler = WrapMiddleware(mw, handler) // Add the application's general middleware to the handler chain. - handler = wrapMiddleware(a.mw, handler) + handler = WrapMiddleware(a.mw, handler) customHandler := func(ctx *fasthttp.RequestCtx) { @@ -126,10 +126,10 @@ func NewApp(options *AppAdditionalOptions, shutdown chan os.Signal, logger *logr func (a *App) Handle(method string, path string, handler Handler, mw ...Middleware) { // First wrap handler specific middleware around this handler. - handler = wrapMiddleware(mw, handler) + handler = WrapMiddleware(mw, handler) // Add the application's general middleware to the handler chain. - handler = wrapMiddleware(a.mw, handler) + handler = WrapMiddleware(a.mw, handler) // The function to execute for each request. h := func(ctx *fasthttp.RequestCtx) { From 751fdd77d037c1f802445ccaba3b1209add6dee6 Mon Sep 17 00:00:00 2001 From: Nikolay Tkachenko Date: Sat, 13 Apr 2024 02:17:33 +0300 Subject: [PATCH 04/12] Update Dockerfile. Refactor loader. --- Dockerfile | 2 +- cmd/api-firewall/internal/handlers/api/app.go | 12 +++--- .../internal/handlers/api/openapi.go | 10 ++--- .../internal/handlers/api/routes.go | 4 +- .../internal/handlers/proxy/openapi.go | 4 +- .../internal/handlers/proxy/routes.go | 4 +- .../internal/updater/wallarm_api2_update.db | Bin 98304 -> 98304 bytes cmd/api-firewall/main.go | 4 +- cmd/api-firewall/tests/main_json_test.go | 4 +- cmd/api-firewall/tests/main_modsec_test.go | 6 +-- cmd/api-firewall/tests/main_test.go | 6 +-- go.mod | 2 +- internal/platform/database/database.go | 1 - internal/platform/database/v1.go | 16 ++----- internal/platform/database/v2.go | 16 ++----- internal/platform/loader/loader.go | 39 ++++++++++++++++++ .../platform/{router => loader}/router.go | 24 +++++------ internal/platform/{chi => router}/LICENSE | 0 internal/platform/{chi => router}/chi.go | 2 +- internal/platform/{chi => router}/context.go | 2 +- .../platform/{chi => router}/context_test.go | 2 +- internal/platform/{chi => router}/mux.go | 2 +- internal/platform/{chi => router}/mux_test.go | 2 +- .../platform/{chi => router}/path_value.go | 2 +- .../{chi => router}/path_value_fallback.go | 2 +- .../{chi => router}/path_value_test.go | 2 +- internal/platform/{chi => router}/tree.go | 2 +- .../platform/{chi => router}/tree_test.go | 2 +- .../platform/validator/req_resp_decoder.go | 3 +- 29 files changed, 96 insertions(+), 81 deletions(-) create mode 100644 internal/platform/loader/loader.go rename internal/platform/{router => loader}/router.go (75%) rename internal/platform/{chi => router}/LICENSE (100%) rename internal/platform/{chi => router}/chi.go (98%) rename internal/platform/{chi => router}/context.go (99%) rename internal/platform/{chi => router}/context_test.go (99%) rename internal/platform/{chi => router}/mux.go (99%) rename internal/platform/{chi => router}/mux_test.go (99%) rename internal/platform/{chi => router}/path_value.go (97%) rename internal/platform/{chi => router}/path_value_fallback.go (97%) rename internal/platform/{chi => router}/path_value_test.go (99%) rename internal/platform/{chi => router}/tree.go (99%) rename internal/platform/{chi => router}/tree_test.go (99%) diff --git a/Dockerfile b/Dockerfile index 6d4c956..6391d14 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,7 +14,7 @@ COPY . . RUN go mod download -x && \ go build \ - -ldflags="-X main.build=${APIFIREWALL_VERSION} -s -w" \ + -ldflags="-X main.build=${APIFIREWALL_VERSION} -linkmode 'external' -extldflags '-static' -s -w" \ -buildvcs=false \ -o ./api-firewall \ ./cmd/api-firewall diff --git a/cmd/api-firewall/internal/handlers/api/app.go b/cmd/api-firewall/internal/handlers/api/app.go index c0fdcb9..f407b91 100644 --- a/cmd/api-firewall/internal/handlers/api/app.go +++ b/cmd/api-firewall/internal/handlers/api/app.go @@ -16,8 +16,8 @@ import ( "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttpadaptor" - "github.com/wallarm/api-firewall/internal/platform/chi" "github.com/wallarm/api-firewall/internal/platform/database" + "github.com/wallarm/api-firewall/internal/platform/router" "github.com/wallarm/api-firewall/internal/platform/web" ) @@ -30,7 +30,7 @@ var ( // object for each of our http handlers. Feel free to add any configuration // data/logic on this App struct type APIModeApp struct { - Routers map[int]*chi.Mux + Routers map[int]*router.Mux Log *logrus.Logger passOPTIONS bool shutdown chan os.Signal @@ -45,10 +45,10 @@ func NewAPIModeApp(lock *sync.RWMutex, passOPTIONS bool, storedSpecs database.DB schemaIDs := storedSpecs.SchemaIDs() // Init routers - routers := make(map[int]*chi.Mux) + routers := make(map[int]*router.Mux) for _, schemaID := range schemaIDs { //routers[schemaID] = make(map[string]*mux.Router) - routers[schemaID] = chi.NewRouter() + routers[schemaID] = router.NewRouter() //routers[schemaID].HandleOPTIONS = passOPTIONS } @@ -199,7 +199,7 @@ func (a *APIModeApp) APIModeRouteHandler(ctx *fasthttp.RequestCtx) { } // find the handler with the OAS information - rctx := chi.NewRouteContext() + rctx := router.NewRouteContext() handler := a.Routers[schemaID].Find(rctx, strconv.B2S(ctx.Method()), strconv.B2S(ctx.Request.URI().Path())) // handler not found in the OAS @@ -232,7 +232,7 @@ func (a *APIModeApp) APIModeRouteHandler(ctx *fasthttp.RequestCtx) { } // add router context to get URL params in the Handler - ctx.SetUserValue(chi.RouteCtxKey, rctx) + ctx.SetUserValue(router.RouteCtxKey, rctx) if err := handler(ctx); err != nil { a.Log.WithFields(logrus.Fields{ diff --git a/cmd/api-firewall/internal/handlers/api/openapi.go b/cmd/api-firewall/internal/handlers/api/openapi.go index c24e44d..6a8cc6f 100644 --- a/cmd/api-firewall/internal/handlers/api/openapi.go +++ b/cmd/api-firewall/internal/handlers/api/openapi.go @@ -3,7 +3,7 @@ package api import ( "context" "fmt" - "github.com/wallarm/api-firewall/internal/platform/chi" + "github.com/wallarm/api-firewall/internal/platform/router" "net/http" strconv2 "strconv" "strings" @@ -17,7 +17,7 @@ import ( "github.com/valyala/fasthttp/fasthttpadaptor" "github.com/valyala/fastjson" "github.com/wallarm/api-firewall/internal/config" - "github.com/wallarm/api-firewall/internal/platform/router" + "github.com/wallarm/api-firewall/internal/platform/loader" "github.com/wallarm/api-firewall/internal/platform/validator" "github.com/wallarm/api-firewall/internal/platform/web" ) @@ -76,8 +76,8 @@ var apiModeSecurityRequirementsOptions = &openapi3filter.Options{ } type APIMode struct { - CustomRoute *router.CustomRoute - OpenAPIRouter *router.Router + CustomRoute *loader.CustomRoute + OpenAPIRouter *loader.Router Log *logrus.Logger Cfg *config.APIMode ParserPool *fastjson.ParserPool @@ -107,7 +107,7 @@ func (s *APIMode) APIModeHandler(ctx *fasthttp.RequestCtx) error { var pathParams map[string]string if s.CustomRoute.ParametersNumberInPath > 0 { - pathParams = chi.AllURLParams(ctx) + pathParams = router.AllURLParams(ctx) } // Convert fasthttp request to net/http request diff --git a/cmd/api-firewall/internal/handlers/api/routes.go b/cmd/api-firewall/internal/handlers/api/routes.go index 598d5c7..8653077 100644 --- a/cmd/api-firewall/internal/handlers/api/routes.go +++ b/cmd/api-firewall/internal/handlers/api/routes.go @@ -14,7 +14,7 @@ import ( "github.com/wallarm/api-firewall/internal/mid" "github.com/wallarm/api-firewall/internal/platform/allowiplist" "github.com/wallarm/api-firewall/internal/platform/database" - "github.com/wallarm/api-firewall/internal/platform/router" + "github.com/wallarm/api-firewall/internal/platform/loader" "github.com/wallarm/api-firewall/internal/platform/web" ) @@ -70,7 +70,7 @@ func Handlers(lock *sync.RWMutex, cfg *config.APIMode, shutdown chan os.Signal, } // get new router - newSwagRouter, err := router.NewRouterDBLoader(schemaID, storedSpecs) + newSwagRouter, err := loader.NewRouterDBLoader(storedSpecs.SpecificationVersion(schemaID), storedSpecs.Specification(schemaID)) if err != nil { logger.WithFields(logrus.Fields{"error": err}).Error("New router creation failed") } diff --git a/cmd/api-firewall/internal/handlers/proxy/openapi.go b/cmd/api-firewall/internal/handlers/proxy/openapi.go index 0290969..49c00d5 100644 --- a/cmd/api-firewall/internal/handlers/proxy/openapi.go +++ b/cmd/api-firewall/internal/handlers/proxy/openapi.go @@ -16,15 +16,15 @@ import ( "github.com/valyala/fasthttp/fasthttpadaptor" "github.com/valyala/fastjson" "github.com/wallarm/api-firewall/internal/config" + "github.com/wallarm/api-firewall/internal/platform/loader" "github.com/wallarm/api-firewall/internal/platform/oauth2" "github.com/wallarm/api-firewall/internal/platform/proxy" - "github.com/wallarm/api-firewall/internal/platform/router" "github.com/wallarm/api-firewall/internal/platform/validator" "github.com/wallarm/api-firewall/internal/platform/web" ) type openapiWaf struct { - customRoute *router.CustomRoute + customRoute *loader.CustomRoute proxyPool proxy.Pool logger *logrus.Logger cfg *config.ProxyMode diff --git a/cmd/api-firewall/internal/handlers/proxy/routes.go b/cmd/api-firewall/internal/handlers/proxy/routes.go index 0cf19ea..a3b360b 100644 --- a/cmd/api-firewall/internal/handlers/proxy/routes.go +++ b/cmd/api-firewall/internal/handlers/proxy/routes.go @@ -18,13 +18,13 @@ import ( "github.com/wallarm/api-firewall/internal/mid" "github.com/wallarm/api-firewall/internal/platform/allowiplist" "github.com/wallarm/api-firewall/internal/platform/denylist" + "github.com/wallarm/api-firewall/internal/platform/loader" woauth2 "github.com/wallarm/api-firewall/internal/platform/oauth2" "github.com/wallarm/api-firewall/internal/platform/proxy" - "github.com/wallarm/api-firewall/internal/platform/router" "github.com/wallarm/api-firewall/internal/platform/web" ) -func Handlers(cfg *config.ProxyMode, serverURL *url.URL, shutdown chan os.Signal, logger *logrus.Logger, httpClientsPool proxy.Pool, swagRouter *router.Router, deniedTokens *denylist.DeniedTokens, AllowedIPCache *allowiplist.AllowedIPsType, waf coraza.WAF) fasthttp.RequestHandler { +func Handlers(cfg *config.ProxyMode, serverURL *url.URL, shutdown chan os.Signal, logger *logrus.Logger, httpClientsPool proxy.Pool, swagRouter *loader.Router, deniedTokens *denylist.DeniedTokens, AllowedIPCache *allowiplist.AllowedIPsType, waf coraza.WAF) fasthttp.RequestHandler { // define FastJSON parsers pool var parserPool fastjson.ParserPool diff --git a/cmd/api-firewall/internal/updater/wallarm_api2_update.db b/cmd/api-firewall/internal/updater/wallarm_api2_update.db index 603b995af0300b6f2561ae7b5fb24cd9a10ea460..b4918fe3d1cdf73348bce4fc954da336633d920b 100644 GIT binary patch delta 34 qcmZo@U~6b#n;^~TJyFJ)(R*V;lq{pm=H0R`LX1b643-%%1^@uM7Yd93 delta 34 qcmZo@U~6b#n;^|7I#I@%QFLQMlq{ps=H0R`LX4Z543-%%1^@uB9|`;b diff --git a/cmd/api-firewall/main.go b/cmd/api-firewall/main.go index 47f6bc4..01131eb 100644 --- a/cmd/api-firewall/main.go +++ b/cmd/api-firewall/main.go @@ -26,8 +26,8 @@ import ( "github.com/wallarm/api-firewall/internal/platform/allowiplist" "github.com/wallarm/api-firewall/internal/platform/database" "github.com/wallarm/api-firewall/internal/platform/denylist" + "github.com/wallarm/api-firewall/internal/platform/loader" "github.com/wallarm/api-firewall/internal/platform/proxy" - "github.com/wallarm/api-firewall/internal/platform/router" "github.com/wallarm/api-firewall/internal/platform/web" "github.com/wundergraph/graphql-go-tools/pkg/graphql" ) @@ -760,7 +760,7 @@ func runProxyMode(logger *logrus.Logger) error { } } - swagRouter, err := router.NewRouter(swagger) + swagRouter, err := loader.NewRouter(swagger, true) if err != nil { return errors.Wrap(err, "parsing swagwaf file") } diff --git a/cmd/api-firewall/tests/main_json_test.go b/cmd/api-firewall/tests/main_json_test.go index 1b61ef9..c1c5538 100644 --- a/cmd/api-firewall/tests/main_json_test.go +++ b/cmd/api-firewall/tests/main_json_test.go @@ -15,8 +15,8 @@ import ( "github.com/valyala/fasthttp" proxyHandler "github.com/wallarm/api-firewall/cmd/api-firewall/internal/handlers/proxy" "github.com/wallarm/api-firewall/internal/config" + "github.com/wallarm/api-firewall/internal/platform/loader" "github.com/wallarm/api-firewall/internal/platform/proxy" - "github.com/wallarm/api-firewall/internal/platform/router" ) const openAPIJSONSpecTest = ` @@ -105,7 +105,7 @@ func TestJSONBasic(t *testing.T) { t.Fatalf("loading swagwaf file: %s", err.Error()) } - swagRouter, err := router.NewRouter(swagger) + swagRouter, err := loader.NewRouter(swagger, true) if err != nil { t.Fatalf("parsing swagwaf file: %s", err.Error()) } diff --git a/cmd/api-firewall/tests/main_modsec_test.go b/cmd/api-firewall/tests/main_modsec_test.go index 7c6744a..ff72f53 100644 --- a/cmd/api-firewall/tests/main_modsec_test.go +++ b/cmd/api-firewall/tests/main_modsec_test.go @@ -17,8 +17,8 @@ import ( "github.com/valyala/fasthttp" proxy2 "github.com/wallarm/api-firewall/cmd/api-firewall/internal/handlers/proxy" "github.com/wallarm/api-firewall/internal/config" + "github.com/wallarm/api-firewall/internal/platform/loader" "github.com/wallarm/api-firewall/internal/platform/proxy" - "github.com/wallarm/api-firewall/internal/platform/router" ) const openAPISpecModSecTest = ` @@ -99,7 +99,7 @@ type ModSecIntegrationTests struct { logger *logrus.Logger proxy *proxy.MockPool client *proxy.MockHTTPClient - swagRouter *router.Router + swagRouter *loader.Router waf coraza.WAF loggerHook *test.Hook } @@ -125,7 +125,7 @@ func TestModSec(t *testing.T) { t.Fatalf("loading swagwaf file: %s", err.Error()) } - swagRouter, err := router.NewRouter(swagger) + swagRouter, err := loader.NewRouter(swagger, true) if err != nil { t.Fatalf("parsing swagwaf file: %s", err.Error()) } diff --git a/cmd/api-firewall/tests/main_test.go b/cmd/api-firewall/tests/main_test.go index ed80ba8..9a7f451 100644 --- a/cmd/api-firewall/tests/main_test.go +++ b/cmd/api-firewall/tests/main_test.go @@ -26,8 +26,8 @@ import ( "github.com/wallarm/api-firewall/internal/config" "github.com/wallarm/api-firewall/internal/platform/allowiplist" "github.com/wallarm/api-firewall/internal/platform/denylist" + "github.com/wallarm/api-firewall/internal/platform/loader" "github.com/wallarm/api-firewall/internal/platform/proxy" - "github.com/wallarm/api-firewall/internal/platform/router" ) const openAPISpecTest = ` @@ -317,7 +317,7 @@ type ServiceTests struct { logger *logrus.Logger proxy *proxy.MockPool client *proxy.MockHTTPClient - swagRouter *router.Router + swagRouter *loader.Router } func compressFlate(data []byte) ([]byte, error) { @@ -395,7 +395,7 @@ func TestBasic(t *testing.T) { t.Fatalf("loading swagwaf file: %s", err.Error()) } - swagRouter, err := router.NewRouter(swagger) + swagRouter, err := loader.NewRouter(swagger, true) if err != nil { t.Fatalf("parsing swagwaf file: %s", err.Error()) } diff --git a/go.mod b/go.mod index 35b62b5..57a0498 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang/mock v1.6.0 github.com/google/uuid v1.6.0 + github.com/invopop/yaml v0.2.0 github.com/karlseguin/ccache/v2 v2.0.8 github.com/klauspost/compress v1.17.7 github.com/mattn/go-sqlite3 v1.14.22 @@ -51,7 +52,6 @@ require ( github.com/hashicorp/golang-lru v0.5.4 // indirect github.com/huandu/xstrings v1.2.1 // indirect github.com/imdario/mergo v0.3.8 // indirect - github.com/invopop/yaml v0.2.0 // indirect github.com/jensneuse/abstractlogger v0.0.4 // indirect github.com/jensneuse/byte-template v0.0.0-20200214152254-4f3cf06e5c68 // indirect github.com/jensneuse/pipeline v0.0.0-20200117120358-9fb4de085cd6 // indirect diff --git a/internal/platform/database/database.go b/internal/platform/database/database.go index 9026339..c6d3f0a 100644 --- a/internal/platform/database/database.go +++ b/internal/platform/database/database.go @@ -2,7 +2,6 @@ package database import ( "bytes" - "github.com/getkin/kin-openapi/openapi3" _ "github.com/mattn/go-sqlite3" "github.com/sirupsen/logrus" diff --git a/internal/platform/database/v1.go b/internal/platform/database/v1.go index 702f984..f252722 100644 --- a/internal/platform/database/v1.go +++ b/internal/platform/database/v1.go @@ -13,6 +13,7 @@ import ( _ "github.com/mattn/go-sqlite3" "github.com/pkg/errors" "github.com/sirupsen/logrus" + "github.com/wallarm/api-firewall/internal/platform/loader" ) const currentSQLSchemaVersionV1 = 1 @@ -107,20 +108,9 @@ func (s *SQLLiteV1) Load(dbStoragePath string) error { for schemaID, spec := range s.RawSpecs { - // parse specification - loader := openapi3.NewLoader() - parsedSpec, err := loader.LoadFromData(getSpecBytes(spec.SchemaContent)) + parsedSpec, err := loader.ParseOAS(getSpecBytes(spec.SchemaContent), spec.SchemaVersion, schemaID) if err != nil { - s.Log.Errorf("error: parsing of the OpenAPI specification %s (schema ID %d): %v", spec.SchemaVersion, schemaID, err) - delete(s.RawSpecs, schemaID) - continue - } - - if err := parsedSpec.Validate( - loader.Context, - openapi3.DisableExamplesValidation(), - ); err != nil { - s.Log.Errorf("error: validation of the OpenAPI specification %s (schema ID %d): %v", spec.SchemaVersion, schemaID, err) + s.Log.Errorf("error: %v", err) delete(s.RawSpecs, schemaID) continue } diff --git a/internal/platform/database/v2.go b/internal/platform/database/v2.go index f681401..dfadebb 100644 --- a/internal/platform/database/v2.go +++ b/internal/platform/database/v2.go @@ -14,6 +14,7 @@ import ( _ "github.com/mattn/go-sqlite3" "github.com/pkg/errors" "github.com/sirupsen/logrus" + "github.com/wallarm/api-firewall/internal/platform/loader" ) const currentSQLSchemaVersionV2 = 2 @@ -105,20 +106,9 @@ func (s *SQLLiteV2) Load(dbStoragePath string) error { for schemaID, spec := range s.RawSpecs { - // parse specification - loader := openapi3.NewLoader() - parsedSpec, err := loader.LoadFromData(getSpecBytes(spec.SchemaContent)) + parsedSpec, err := loader.ParseOAS(getSpecBytes(spec.SchemaContent), spec.SchemaVersion, schemaID) if err != nil { - s.Log.Errorf("error: parsing of the OpenAPI specification %s (schema ID %d): %v", spec.SchemaVersion, schemaID, err) - delete(s.RawSpecs, schemaID) - continue - } - - if err := parsedSpec.Validate( - loader.Context, - openapi3.DisableExamplesValidation(), - ); err != nil { - s.Log.Errorf("error: validation of the OpenAPI specification %s (schema ID %d): %v", spec.SchemaVersion, schemaID, err) + s.Log.Errorf("error: %v", err) delete(s.RawSpecs, schemaID) continue } diff --git a/internal/platform/loader/loader.go b/internal/platform/loader/loader.go new file mode 100644 index 0000000..3b17183 --- /dev/null +++ b/internal/platform/loader/loader.go @@ -0,0 +1,39 @@ +package loader + +import ( + "context" + "fmt" + + "github.com/getkin/kin-openapi/openapi3" +) + +func validateOAS(spec *openapi3.T) error { + + if err := spec.Validate( + context.Background(), + openapi3.DisableExamplesValidation(), + openapi3.DisableSchemaFormatValidation(), + openapi3.DisableSchemaDefaultsValidation(), + openapi3.DisableSchemaPatternValidation(), + ); err != nil { + return err + } + + return nil +} + +func ParseOAS(schema []byte, SchemaVersion string, schemaID int) (*openapi3.T, error) { + + // parse specification + loader := openapi3.NewLoader() + parsedSpec, err := loader.LoadFromData(schema) + if err != nil { + return nil, fmt.Errorf("OpenAPI specification (version %s; schema ID %d) parsing failed: %v", SchemaVersion, schemaID, err) + } + + if err := validateOAS(parsedSpec); err != nil { + return nil, fmt.Errorf("OpenAPI specification (version %s; schema ID %d) validation failed: %v: ", SchemaVersion, schemaID, err) + } + + return parsedSpec, nil +} diff --git a/internal/platform/router/router.go b/internal/platform/loader/router.go similarity index 75% rename from internal/platform/router/router.go rename to internal/platform/loader/router.go index ab981eb..072a5f8 100644 --- a/internal/platform/router/router.go +++ b/internal/platform/loader/router.go @@ -1,13 +1,11 @@ -package router +package loader import ( - "context" "fmt" "strings" "github.com/getkin/kin-openapi/openapi3" "github.com/getkin/kin-openapi/routers" - "github.com/wallarm/api-firewall/internal/platform/database" ) // Router helps link http.Request.s and an OpenAPIv3 spec @@ -27,12 +25,11 @@ type CustomRoute struct { // // If the given Swagger has servers, router will use them. // All operations of the Swagger will be added to the router. -func NewRouter(doc *openapi3.T) (*Router, error) { - if err := doc.Validate( - context.Background(), - openapi3.DisableExamplesValidation(), - ); err != nil { - return nil, fmt.Errorf("validating OpenAPI failed: %v", err) +func NewRouter(doc *openapi3.T, validate bool) (*Router, error) { + if validate { + if err := validateOAS(doc); err != nil { + return nil, fmt.Errorf("OpenAPI specification validation failed: %v", err) + } } var router Router @@ -80,15 +77,14 @@ func NewRouter(doc *openapi3.T) (*Router, error) { } // NewRouterDBLoader creates a new router based on DB OpenAPI loader. -func NewRouterDBLoader(schemaID int, openAPISpec database.DBOpenAPILoader) (*Router, error) { - doc := openAPISpec.Specification(schemaID) +func NewRouterDBLoader(schemaVersion string, spec *openapi3.T) (*Router, error) { - router, err := NewRouter(doc) + newRouter, err := NewRouter(spec, false) if err != nil { return nil, err } - router.SchemaVersion = openAPISpec.SpecificationVersion(schemaID) + newRouter.SchemaVersion = schemaVersion - return router, nil + return newRouter, nil } diff --git a/internal/platform/chi/LICENSE b/internal/platform/router/LICENSE similarity index 100% rename from internal/platform/chi/LICENSE rename to internal/platform/router/LICENSE diff --git a/internal/platform/chi/chi.go b/internal/platform/router/chi.go similarity index 98% rename from internal/platform/chi/chi.go rename to internal/platform/router/chi.go index 0d7a5c2..53193f0 100644 --- a/internal/platform/chi/chi.go +++ b/internal/platform/router/chi.go @@ -1,4 +1,4 @@ -package chi +package router import "github.com/wallarm/api-firewall/internal/platform/web" diff --git a/internal/platform/chi/context.go b/internal/platform/router/context.go similarity index 99% rename from internal/platform/chi/context.go rename to internal/platform/router/context.go index f5c506a..7a6dba7 100644 --- a/internal/platform/chi/context.go +++ b/internal/platform/router/context.go @@ -1,4 +1,4 @@ -package chi +package router import ( "strings" diff --git a/internal/platform/chi/context_test.go b/internal/platform/router/context_test.go similarity index 99% rename from internal/platform/chi/context_test.go rename to internal/platform/router/context_test.go index 4731c70..1a4fad6 100644 --- a/internal/platform/chi/context_test.go +++ b/internal/platform/router/context_test.go @@ -1,4 +1,4 @@ -package chi +package router import "testing" diff --git a/internal/platform/chi/mux.go b/internal/platform/router/mux.go similarity index 99% rename from internal/platform/chi/mux.go rename to internal/platform/router/mux.go index c4d3fd9..6d1e5f5 100644 --- a/internal/platform/chi/mux.go +++ b/internal/platform/router/mux.go @@ -1,4 +1,4 @@ -package chi +package router import ( "fmt" diff --git a/internal/platform/chi/mux_test.go b/internal/platform/router/mux_test.go similarity index 99% rename from internal/platform/chi/mux_test.go rename to internal/platform/router/mux_test.go index 9a4789e..5f0917b 100644 --- a/internal/platform/chi/mux_test.go +++ b/internal/platform/router/mux_test.go @@ -1,4 +1,4 @@ -package chi +package router import ( "bytes" diff --git a/internal/platform/chi/path_value.go b/internal/platform/router/path_value.go similarity index 97% rename from internal/platform/chi/path_value.go rename to internal/platform/router/path_value.go index 8ab89cd..cd985e9 100644 --- a/internal/platform/chi/path_value.go +++ b/internal/platform/router/path_value.go @@ -1,7 +1,7 @@ //go:build go1.22 // +build go1.22 -package chi +package router import ( "github.com/valyala/fasthttp" diff --git a/internal/platform/chi/path_value_fallback.go b/internal/platform/router/path_value_fallback.go similarity index 97% rename from internal/platform/chi/path_value_fallback.go rename to internal/platform/router/path_value_fallback.go index 9f0288b..bd5249f 100644 --- a/internal/platform/chi/path_value_fallback.go +++ b/internal/platform/router/path_value_fallback.go @@ -1,7 +1,7 @@ //go:build !go1.22 // +build !go1.22 -package chi +package router import ( "github.com/valyala/fasthttp" diff --git a/internal/platform/chi/path_value_test.go b/internal/platform/router/path_value_test.go similarity index 99% rename from internal/platform/chi/path_value_test.go rename to internal/platform/router/path_value_test.go index 5a48698..7b41311 100644 --- a/internal/platform/chi/path_value_test.go +++ b/internal/platform/router/path_value_test.go @@ -1,7 +1,7 @@ //go:build go1.22 // +build go1.22 -package chi +package router import ( "net/http" diff --git a/internal/platform/chi/tree.go b/internal/platform/router/tree.go similarity index 99% rename from internal/platform/chi/tree.go rename to internal/platform/router/tree.go index 762dded..93b662f 100644 --- a/internal/platform/chi/tree.go +++ b/internal/platform/router/tree.go @@ -1,4 +1,4 @@ -package chi +package router // Radix tree implementation below is a based on the original work by // Armon Dadgar in https://github.com/armon/go-radix/blob/master/radix.go diff --git a/internal/platform/chi/tree_test.go b/internal/platform/router/tree_test.go similarity index 99% rename from internal/platform/chi/tree_test.go rename to internal/platform/router/tree_test.go index 53e02e9..6ef0e0d 100644 --- a/internal/platform/chi/tree_test.go +++ b/internal/platform/router/tree_test.go @@ -1,4 +1,4 @@ -package chi +package router import ( "fmt" diff --git a/internal/platform/validator/req_resp_decoder.go b/internal/platform/validator/req_resp_decoder.go index 3604286..7452865 100644 --- a/internal/platform/validator/req_resp_decoder.go +++ b/internal/platform/validator/req_resp_decoder.go @@ -895,6 +895,7 @@ func parsePrimitive(raw string, schema *openapi3.SchemaRef) (interface{}, error) if raw == "" { return nil, nil } + switch schema.Value.Type { case "integer": if len(schema.Value.Enum) > 0 { @@ -932,7 +933,7 @@ func parsePrimitive(raw string, schema *openapi3.SchemaRef) (interface{}, error) case "string": return raw, nil default: - panic(fmt.Sprintf("schema has non primitive type %q", schema.Value.Type)) + return nil, &ParseError{Kind: KindOther, Value: raw, Reason: "schema has non primitive type " + schema.Value.Type} } } From a3eddb37d76e4129fca366f8fa0ccdae9c2fb6cc Mon Sep 17 00:00:00 2001 From: Nikolay Tkachenko Date: Sat, 13 Apr 2024 02:18:10 +0300 Subject: [PATCH 05/12] Update Go version to 1.21.9 --- .github/workflows/binaries.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/binaries.yml b/.github/workflows/binaries.yml index 94c7a2e..94f6809 100644 --- a/.github/workflows/binaries.yml +++ b/.github/workflows/binaries.yml @@ -51,7 +51,7 @@ jobs: needs: - draft-release env: - X_GO_DISTRIBUTION: "https://go.dev/dl/go1.21.8.linux-amd64.tar.gz" + X_GO_DISTRIBUTION: "https://go.dev/dl/go1.21.9.linux-amd64.tar.gz" strategy: matrix: include: @@ -160,7 +160,7 @@ jobs: needs: - draft-release env: - X_GO_VERSION: "1.21.8" + X_GO_VERSION: "1.21.9" strategy: matrix: include: @@ -267,11 +267,11 @@ jobs: include: - arch: armv6 distro: bullseye - go_distribution: https://go.dev/dl/go1.21.8.linux-armv6l.tar.gz + go_distribution: https://go.dev/dl/go1.21.9.linux-armv6l.tar.gz artifact: armv6-libc - arch: aarch64 distro: bullseye - go_distribution: https://go.dev/dl/go1.21.8.linux-arm64.tar.gz + go_distribution: https://go.dev/dl/go1.21.9.linux-arm64.tar.gz artifact: arm64-libc - arch: armv6 distro: alpine_latest From 9528f1953a3b55a072f0fd4ae108ca98595d70d0 Mon Sep 17 00:00:00 2001 From: Nikolay Tkachenko Date: Sat, 13 Apr 2024 02:43:21 +0300 Subject: [PATCH 06/12] Small refactoring --- cmd/api-firewall/internal/handlers/api/app.go | 19 ++++++++----------- .../internal/handlers/api/openapi.go | 8 ++++---- .../internal/handlers/api/routes.go | 6 +++--- internal/platform/database/database.go | 2 ++ 4 files changed, 17 insertions(+), 18 deletions(-) diff --git a/cmd/api-firewall/internal/handlers/api/app.go b/cmd/api-firewall/internal/handlers/api/app.go index f407b91..676f280 100644 --- a/cmd/api-firewall/internal/handlers/api/app.go +++ b/cmd/api-firewall/internal/handlers/api/app.go @@ -26,10 +26,10 @@ var ( statusInternalError = fasthttp.StatusInternalServerError ) -// APIModeApp is the entrypoint into our application and what configures our context +// App is the entrypoint into our application and what configures our context // object for each of our http handlers. Feel free to add any configuration // data/logic on this App struct -type APIModeApp struct { +type App struct { Routers map[int]*router.Mux Log *logrus.Logger passOPTIONS bool @@ -39,20 +39,18 @@ type APIModeApp struct { lock *sync.RWMutex } -// NewAPIModeApp creates an APIModeApp value that handle a set of routes for the set of application. -func NewAPIModeApp(lock *sync.RWMutex, passOPTIONS bool, storedSpecs database.DBOpenAPILoader, shutdown chan os.Signal, logger *logrus.Logger, mw ...web.Middleware) *APIModeApp { +// NewApp creates an App value that handle a set of routes for the set of application. +func NewApp(lock *sync.RWMutex, passOPTIONS bool, storedSpecs database.DBOpenAPILoader, shutdown chan os.Signal, logger *logrus.Logger, mw ...web.Middleware) *App { schemaIDs := storedSpecs.SchemaIDs() // Init routers routers := make(map[int]*router.Mux) for _, schemaID := range schemaIDs { - //routers[schemaID] = make(map[string]*mux.Router) routers[schemaID] = router.NewRouter() - //routers[schemaID].HandleOPTIONS = passOPTIONS } - app := APIModeApp{ + app := App{ Routers: routers, shutdown: shutdown, mw: mw, @@ -67,7 +65,7 @@ func NewAPIModeApp(lock *sync.RWMutex, passOPTIONS bool, storedSpecs database.DB // Handle is our mechanism for mounting Handlers for a given HTTP verb and path // pair, this makes for really easy, convenient routing. -func (a *APIModeApp) Handle(schemaID int, method string, path string, handler web.Handler, mw ...web.Middleware) error { +func (a *App) Handle(schemaID int, method string, path string, handler web.Handler, mw ...web.Middleware) error { // First wrap handler specific middleware around this handler. handler = web.WrapMiddleware(mw, handler) @@ -133,7 +131,7 @@ func getWallarmSchemaID(ctx *fasthttp.RequestCtx, storedSpecs database.DBOpenAPI } // APIModeRouteHandler routes request to the appropriate handler according to the OpenAPI specification schema ID -func (a *APIModeApp) APIModeRouteHandler(ctx *fasthttp.RequestCtx) { +func (a *App) APIModeRouteHandler(ctx *fasthttp.RequestCtx) { // handle panic defer func() { @@ -179,7 +177,6 @@ func (a *APIModeApp) APIModeRouteHandler(ctx *fasthttp.RequestCtx) { a.lock.RLock() defer a.lock.RUnlock() - //w := NewFastHTTPResponseAdapter(ctx) // Validate requests against list of schemas for _, sID := range schemaIDs { @@ -312,6 +309,6 @@ func (a *APIModeApp) APIModeRouteHandler(ctx *fasthttp.RequestCtx) { // SignalShutdown is used to gracefully shutdown the app when an integrity // issue is identified. -func (a *APIModeApp) SignalShutdown() { +func (a *App) SignalShutdown() { a.shutdown <- syscall.SIGTERM } diff --git a/cmd/api-firewall/internal/handlers/api/openapi.go b/cmd/api-firewall/internal/handlers/api/openapi.go index 6a8cc6f..2ec7b16 100644 --- a/cmd/api-firewall/internal/handlers/api/openapi.go +++ b/cmd/api-firewall/internal/handlers/api/openapi.go @@ -3,7 +3,6 @@ package api import ( "context" "fmt" - "github.com/wallarm/api-firewall/internal/platform/router" "net/http" strconv2 "strconv" "strings" @@ -18,6 +17,7 @@ import ( "github.com/valyala/fastjson" "github.com/wallarm/api-firewall/internal/config" "github.com/wallarm/api-firewall/internal/platform/loader" + "github.com/wallarm/api-firewall/internal/platform/router" "github.com/wallarm/api-firewall/internal/platform/validator" "github.com/wallarm/api-firewall/internal/platform/web" ) @@ -75,7 +75,7 @@ var apiModeSecurityRequirementsOptions = &openapi3filter.Options{ }, } -type APIMode struct { +type RequestValidator struct { CustomRoute *loader.CustomRoute OpenAPIRouter *loader.Router Log *logrus.Logger @@ -84,8 +84,8 @@ type APIMode struct { SchemaID int } -// APIModeHandler validates request and respond with 200, 403 (with error) or 500 status code -func (s *APIMode) APIModeHandler(ctx *fasthttp.RequestCtx) error { +// Handler validates request and respond with 200, 403 (with error) or 500 status code +func (s *RequestValidator) Handler(ctx *fasthttp.RequestCtx) error { keyValidationErrors := strconv2.Itoa(s.SchemaID) + web.APIModePostfixValidationErrors keyStatusCode := strconv2.Itoa(s.SchemaID) + web.APIModePostfixStatusCode diff --git a/cmd/api-firewall/internal/handlers/api/routes.go b/cmd/api-firewall/internal/handlers/api/routes.go index 8653077..9390868 100644 --- a/cmd/api-firewall/internal/handlers/api/routes.go +++ b/cmd/api-firewall/internal/handlers/api/routes.go @@ -50,7 +50,7 @@ func Handlers(lock *sync.RWMutex, cfg *config.APIMode, shutdown chan os.Signal, } // Construct the web.App which holds all routes as well as common Middleware. - apps := NewAPIModeApp(lock, cfg.PassOptionsRequests, storedSpecs, shutdown, logger, mid.IPAllowlist(&ipAllowlistOptions), mid.WAFModSecurity(&modSecOptions), mid.Logger(logger), mid.MIMETypeIdentifier(logger), mid.Errors(logger), mid.Panics(logger)) + apps := NewApp(lock, cfg.PassOptionsRequests, storedSpecs, shutdown, logger, mid.IPAllowlist(&ipAllowlistOptions), mid.WAFModSecurity(&modSecOptions), mid.Logger(logger), mid.MIMETypeIdentifier(logger), mid.Errors(logger), mid.Panics(logger)) for _, schemaID := range schemaIDs { @@ -77,7 +77,7 @@ func Handlers(lock *sync.RWMutex, cfg *config.APIMode, shutdown chan os.Signal, for i := 0; i < len(newSwagRouter.Routes); i++ { - s := APIMode{ + s := RequestValidator{ CustomRoute: &newSwagRouter.Routes[i], Log: logger, Cfg: cfg, @@ -99,7 +99,7 @@ func Handlers(lock *sync.RWMutex, cfg *config.APIMode, shutdown chan os.Signal, s.Log.Debugf("handler: Schema ID %d: OpenAPI version %s: Loaded path %s - %s", schemaID, storedSpecs.SpecificationVersion(schemaID), newSwagRouter.Routes[i].Method, updRoutePath) - if err := apps.Handle(schemaID, newSwagRouter.Routes[i].Method, updRoutePath, s.APIModeHandler); err != nil { + if err := apps.Handle(schemaID, newSwagRouter.Routes[i].Method, updRoutePath, s.Handler); err != nil { logger.WithFields(logrus.Fields{"error": err, "schema_id": schemaID}).Error("Registration of the OAS failed") } } diff --git a/internal/platform/database/database.go b/internal/platform/database/database.go index c6d3f0a..96f4807 100644 --- a/internal/platform/database/database.go +++ b/internal/platform/database/database.go @@ -2,6 +2,7 @@ package database import ( "bytes" + "github.com/getkin/kin-openapi/openapi3" _ "github.com/mattn/go-sqlite3" "github.com/sirupsen/logrus" @@ -25,6 +26,7 @@ func getSpecBytes(spec string) []byte { return bytes.NewBufferString(spec).Bytes() } +// NewOpenAPIDB loads OAS specs from the database and returns the struct with the parsed specs func NewOpenAPIDB(log *logrus.Logger, dbStoragePath string, version int) (DBOpenAPILoader, error) { switch version { From 92e7c493c053b3c74c664c25c7314312ff0a5ccd Mon Sep 17 00:00:00 2001 From: Nikolay Tkachenko Date: Sun, 14 Apr 2024 16:36:25 +0300 Subject: [PATCH 07/12] Update router in Proxy and GraphQL modes. Add tests --- cmd/api-firewall/internal/handlers/api/app.go | 2 +- .../internal/handlers/api/errors.go | 4 + .../internal/handlers/api/openapi.go | 2 +- .../internal/handlers/api/routes.go | 2 +- .../internal/handlers/graphql/httpHandler.go | 3 +- .../internal/handlers/graphql/routes.go | 19 +- .../internal/handlers/proxy/openapi.go | 17 +- .../internal/handlers/proxy/routes.go | 26 +- .../internal/updater/wallarm_api2_update.db | Bin 98304 -> 98304 bytes cmd/api-firewall/tests/main_api_mode_test.go | 76 +++++ cmd/api-firewall/tests/main_graphql_test.go | 85 +++++ cmd/api-firewall/tests/main_test.go | 299 +++++++++++++++++- go.mod | 3 +- go.sum | 2 - internal/mid/allowiplist.go | 3 +- internal/mid/denylist.go | 3 +- internal/mid/errors.go | 3 +- internal/mid/logger.go | 3 +- internal/mid/mimetype.go | 3 +- internal/mid/modsec.go | 3 +- internal/mid/panics.go | 3 +- internal/mid/proxy.go | 3 +- internal/mid/shadowAPI.go | 4 +- internal/platform/router/chi.go | 6 +- internal/platform/router/handler.go | 7 + internal/platform/router/mux.go | 8 +- internal/platform/router/mux_test.go | 6 +- internal/platform/router/tree.go | 19 +- internal/platform/router/tree_test.go | 143 +++++---- internal/platform/web/adaptor.go | 3 +- internal/platform/web/middleware.go | 6 +- internal/platform/web/web.go | 172 ++++++---- 32 files changed, 731 insertions(+), 207 deletions(-) create mode 100644 internal/platform/router/handler.go diff --git a/cmd/api-firewall/internal/handlers/api/app.go b/cmd/api-firewall/internal/handlers/api/app.go index 676f280..b340a4a 100644 --- a/cmd/api-firewall/internal/handlers/api/app.go +++ b/cmd/api-firewall/internal/handlers/api/app.go @@ -65,7 +65,7 @@ func NewApp(lock *sync.RWMutex, passOPTIONS bool, storedSpecs database.DBOpenAPI // Handle is our mechanism for mounting Handlers for a given HTTP verb and path // pair, this makes for really easy, convenient routing. -func (a *App) Handle(schemaID int, method string, path string, handler web.Handler, mw ...web.Middleware) error { +func (a *App) Handle(schemaID int, method string, path string, handler router.Handler, mw ...web.Middleware) error { // First wrap handler specific middleware around this handler. handler = web.WrapMiddleware(mw, handler) diff --git a/cmd/api-firewall/internal/handlers/api/errors.go b/cmd/api-firewall/internal/handlers/api/errors.go index d3a6d58..1cc4cd9 100644 --- a/cmd/api-firewall/internal/handlers/api/errors.go +++ b/cmd/api-firewall/internal/handlers/api/errors.go @@ -71,6 +71,8 @@ func checkRequiredFields(reqErr *openapi3filter.RequestError, schemaError *opena response.Code = ErrCodeRequiredQueryParameterMissed case "cookie": response.Code = ErrCodeRequiredCookieParameterMissed + case "path": + response.Code = ErrCodeRequiredPathParameterMissed case "header": response.Code = ErrCodeRequiredHeaderMissed } @@ -91,6 +93,8 @@ func checkRequiredFields(reqErr *openapi3filter.RequestError, schemaError *opena response.Code = ErrCodeRequiredQueryParameterInvalidValue case "cookie": response.Code = ErrCodeRequiredCookieParameterInvalidValue + case "path": + response.Code = ErrCodeRequiredPathParameterInvalidValue case "header": response.Code = ErrCodeRequiredHeaderInvalidValue } diff --git a/cmd/api-firewall/internal/handlers/api/openapi.go b/cmd/api-firewall/internal/handlers/api/openapi.go index 2ec7b16..416990e 100644 --- a/cmd/api-firewall/internal/handlers/api/openapi.go +++ b/cmd/api-firewall/internal/handlers/api/openapi.go @@ -3,6 +3,7 @@ package api import ( "context" "fmt" + "github.com/wallarm/api-firewall/internal/platform/router" "net/http" strconv2 "strconv" "strings" @@ -17,7 +18,6 @@ import ( "github.com/valyala/fastjson" "github.com/wallarm/api-firewall/internal/config" "github.com/wallarm/api-firewall/internal/platform/loader" - "github.com/wallarm/api-firewall/internal/platform/router" "github.com/wallarm/api-firewall/internal/platform/validator" "github.com/wallarm/api-firewall/internal/platform/web" ) diff --git a/cmd/api-firewall/internal/handlers/api/routes.go b/cmd/api-firewall/internal/handlers/api/routes.go index 9390868..c0592ee 100644 --- a/cmd/api-firewall/internal/handlers/api/routes.go +++ b/cmd/api-firewall/internal/handlers/api/routes.go @@ -100,7 +100,7 @@ func Handlers(lock *sync.RWMutex, cfg *config.APIMode, shutdown chan os.Signal, s.Log.Debugf("handler: Schema ID %d: OpenAPI version %s: Loaded path %s - %s", schemaID, storedSpecs.SpecificationVersion(schemaID), newSwagRouter.Routes[i].Method, updRoutePath) if err := apps.Handle(schemaID, newSwagRouter.Routes[i].Method, updRoutePath, s.Handler); err != nil { - logger.WithFields(logrus.Fields{"error": err, "schema_id": schemaID}).Error("Registration of the OAS failed") + logger.WithFields(logrus.Fields{"error": err, "schema_id": schemaID}).Errorf("The OAS endpoint registration failed: method %s, path %s", newSwagRouter.Routes[i].Method, updRoutePath) } } diff --git a/cmd/api-firewall/internal/handlers/graphql/httpHandler.go b/cmd/api-firewall/internal/handlers/graphql/httpHandler.go index e1dabbb..6c68e02 100644 --- a/cmd/api-firewall/internal/handlers/graphql/httpHandler.go +++ b/cmd/api-firewall/internal/handlers/graphql/httpHandler.go @@ -7,6 +7,8 @@ import ( "strings" "sync" + "golang.org/x/sync/errgroup" + "github.com/fasthttp/websocket" "github.com/savsgio/gotils/strconv" "github.com/sirupsen/logrus" @@ -17,7 +19,6 @@ import ( "github.com/wallarm/api-firewall/internal/platform/validator" "github.com/wallarm/api-firewall/internal/platform/web" "github.com/wundergraph/graphql-go-tools/pkg/graphql" - "golang.org/x/sync/errgroup" ) type Handler struct { diff --git a/cmd/api-firewall/internal/handlers/graphql/routes.go b/cmd/api-firewall/internal/handlers/graphql/routes.go index b4a75b1..ac63d4e 100644 --- a/cmd/api-firewall/internal/handlers/graphql/routes.go +++ b/cmd/api-firewall/internal/handlers/graphql/routes.go @@ -5,10 +5,9 @@ import ( "os" "sync" + "github.com/fasthttp/websocket" "github.com/savsgio/gotils/strconv" "github.com/savsgio/gotils/strings" - - "github.com/fasthttp/websocket" "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" "github.com/valyala/fastjson" @@ -26,7 +25,7 @@ func Handlers(cfg *config.GraphQLMode, schema *graphql.Schema, serverURL *url.UR // Construct the web.App which holds all routes as well as common Middleware. appOptions := web.AppAdditionalOptions{ - Mode: cfg.Mode, + Mode: web.GraphQLMode, PassOptions: false, } @@ -85,8 +84,12 @@ func Handlers(cfg *config.GraphQLMode, schema *graphql.Schema, serverURL *url.UR graphqlPath = "/" } - app.Handle(fasthttp.MethodGet, graphqlPath, s.GraphQLHandle) - app.Handle(fasthttp.MethodPost, graphqlPath, s.GraphQLHandle) + if err := app.Handle(fasthttp.MethodGet, graphqlPath, s.GraphQLHandle); err != nil { + logger.WithFields(logrus.Fields{"error": err}).Error("GraphQL GET endpoint registration failed") + } + if err := app.Handle(fasthttp.MethodPost, graphqlPath, s.GraphQLHandle); err != nil { + logger.WithFields(logrus.Fields{"error": err}).Error("GraphQL POST endpoint registration failed") + } // enable playground if cfg.Graphql.Playground { @@ -104,9 +107,11 @@ func Handlers(cfg *config.GraphQLMode, schema *graphql.Schema, serverURL *url.UR } for i := range handlers { - app.Handle(fasthttp.MethodGet, handlers[i].Path, web.NewFastHTTPHandler(handlers[i].Handler, true)) + if err := app.Handle(fasthttp.MethodGet, handlers[i].Path, web.NewFastHTTPHandler(handlers[i].Handler, true)); err != nil { + logger.WithFields(logrus.Fields{"error": err}).Error("GraphQL Playground endpoint registration failed") + } } } - return app.Router.Handler + return app.MainHandler } diff --git a/cmd/api-firewall/internal/handlers/proxy/openapi.go b/cmd/api-firewall/internal/handlers/proxy/openapi.go index 49c00d5..93e1061 100644 --- a/cmd/api-firewall/internal/handlers/proxy/openapi.go +++ b/cmd/api-firewall/internal/handlers/proxy/openapi.go @@ -19,6 +19,7 @@ import ( "github.com/wallarm/api-firewall/internal/platform/loader" "github.com/wallarm/api-firewall/internal/platform/oauth2" "github.com/wallarm/api-firewall/internal/platform/proxy" + "github.com/wallarm/api-firewall/internal/platform/router" "github.com/wallarm/api-firewall/internal/platform/validator" "github.com/wallarm/api-firewall/internal/platform/web" ) @@ -97,8 +98,15 @@ func getValidationHeader(ctx *fasthttp.RequestCtx, err error) *string { func (s *openapiWaf) openapiWafHandler(ctx *fasthttp.RequestCtx) error { + // Pass OPTIONS if the feature is enabled + var isOptionsReq, ok bool + if isOptionsReq, ok = ctx.UserValue(web.PassRequestOPTIONS).(bool); !ok { + isOptionsReq = false + } + // Proxy request if APIFW is disabled - if strings.EqualFold(s.cfg.RequestValidation, web.ValidationDisable) && strings.EqualFold(s.cfg.ResponseValidation, web.ValidationDisable) { + if isOptionsReq == true || + strings.EqualFold(s.cfg.RequestValidation, web.ValidationDisable) && strings.EqualFold(s.cfg.ResponseValidation, web.ValidationDisable) { if err := proxy.Perform(ctx, s.proxyPool); err != nil { s.logger.WithFields(logrus.Fields{ "error": err, @@ -139,12 +147,7 @@ func (s *openapiWaf) openapiWafHandler(ctx *fasthttp.RequestCtx) error { var pathParams map[string]string if s.customRoute.ParametersNumberInPath > 0 { - pathParams = make(map[string]string) - - ctx.VisitUserValues(func(key []byte, value interface{}) { - keyStr := strconv.B2S(key) - pathParams[keyStr] = value.(string) - }) + pathParams = router.AllURLParams(ctx) } // Convert fasthttp request to net/http request diff --git a/cmd/api-firewall/internal/handlers/proxy/routes.go b/cmd/api-firewall/internal/handlers/proxy/routes.go index a3b360b..ab8c6e7 100644 --- a/cmd/api-firewall/internal/handlers/proxy/routes.go +++ b/cmd/api-firewall/internal/handlers/proxy/routes.go @@ -93,6 +93,15 @@ func Handlers(cfg *config.ProxyMode, serverURL *url.URL, shutdown chan os.Signal } } + // set handler for default behavior (404, 405) + defaultOpenAPIWaf := openapiWaf{ + customRoute: nil, + proxyPool: httpClientsPool, + logger: logger, + cfg: cfg, + parserPool: &parserPool, + } + // Construct the web.App which holds all routes as well as common Middleware. options := web.AppAdditionalOptions{ Mode: cfg.Mode, @@ -101,6 +110,7 @@ func Handlers(cfg *config.ProxyMode, serverURL *url.URL, shutdown chan os.Signal ResponseValidation: cfg.ResponseValidation, CustomBlockStatusCode: cfg.CustomBlockStatusCode, OptionsHandler: optionsHandler, + DefaultHandler: defaultOpenAPIWaf.openapiWafHandler, } proxyOptions := mid.ProxyOptions{ @@ -166,18 +176,10 @@ func Handlers(cfg *config.ProxyMode, serverURL *url.URL, shutdown chan os.Signal s.logger.Debugf("handler: Loaded path %s - %s", swagRouter.Routes[i].Method, updRoutePath) - app.Handle(swagRouter.Routes[i].Method, updRoutePath, s.openapiWafHandler) - } - - // set handler for default behavior (404, 405) - s := openapiWaf{ - customRoute: nil, - proxyPool: httpClientsPool, - logger: logger, - cfg: cfg, - parserPool: &parserPool, + if err := app.Handle(swagRouter.Routes[i].Method, updRoutePath, s.openapiWafHandler); err != nil { + logger.WithFields(logrus.Fields{"error": err}).Errorf("The OAS endpoint registration failed: method %s, path %s", swagRouter.Routes[i].Method, updRoutePath) + } } - app.SetDefaultBehavior(s.openapiWafHandler) - return app.Router.Handler + return app.MainHandler } diff --git a/cmd/api-firewall/internal/updater/wallarm_api2_update.db b/cmd/api-firewall/internal/updater/wallarm_api2_update.db index b4918fe3d1cdf73348bce4fc954da336633d920b..d524d1e3d82d849a6c01456d923f8d94983e44bb 100644 GIT binary patch delta 34 qcmZo@U~6b#n;^~DI8nx#v2kNUlq_S|=H0R`LX6j%43-%%1^@uX4+|Us delta 34 qcmZo@U~6b#n;^~TJyFJ)(R*V;lq{pm=H0R`LX1b643-%%1^@uM7Yd93 diff --git a/cmd/api-firewall/tests/main_api_mode_test.go b/cmd/api-firewall/tests/main_api_mode_test.go index 739ed92..b19a444 100644 --- a/cmd/api-firewall/tests/main_api_mode_test.go +++ b/cmd/api-firewall/tests/main_api_mode_test.go @@ -368,6 +368,38 @@ paths: 200: description: Ok content: { } + /path/{test}: + get: + parameters: + - name: test + in: path + required: true + schema: + type: string + enum: + - testValue1 + - testValue1 + summary: Get Test Info + responses: + 200: + description: Ok + content: { } + /path/{test}.php: + get: + parameters: + - name: test + in: path + required: true + schema: + type: string + enum: + - value1 + - value2 + summary: Get Test Info + responses: + 200: + description: Ok + content: { } /test/body/request: post: summary: Post Request to test Request Body presence @@ -552,6 +584,8 @@ func TestAPIModeBasic(t *testing.T) { // check all supported methods: GET POST PUT PATCH DELETE TRACE OPTIONS HEAD t.Run("testAPIModeAllMethods", apifwTests.testAPIModeAllMethods) + // check conflicts in the Path + t.Run("testConflictsInThePath", apifwTests.testConflictsInThePath) } func createForm(form map[string]string) (string, io.Reader, error) { @@ -2752,3 +2786,45 @@ func (s *APIModeServiceTests) testAPIModeAllMethods(t *testing.T) { // check response status code and response body checkResponseForbiddenStatusCode(t, &reqCtx, DefaultSchemaID, []string{handlersAPI.ErrCodeMethodAndPathNotFound}) } + +func (s *APIModeServiceTests) testConflictsInThePath(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec, nil, nil) + + // check all supported methods: GET POST PUT PATCH DELETE TRACE OPTIONS HEAD + for _, path := range []string{"/path/testValue1", "/path/value1.php"} { + req := fasthttp.AcquireRequest() + req.SetRequestURI(path) + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + t.Logf("Name of the test: %s; request method: %s; request uri: %s; request body: %s", t.Name(), string(reqCtx.Request.Header.Method()), string(reqCtx.Request.RequestURI()), string(reqCtx.Request.Body())) + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + // check response status code and response body + checkResponseOkStatusCode(t, &reqCtx, DefaultSchemaID) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/path/valueNotExist.php") + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + t.Logf("Name of the test: %s; request method: %s; request uri: %s; request body: %s", t.Name(), string(reqCtx.Request.Header.Method()), string(reqCtx.Request.RequestURI()), string(reqCtx.Request.Body())) + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + // check response status code and response body + checkResponseForbiddenStatusCode(t, &reqCtx, DefaultSchemaID, []string{handlersAPI.ErrCodeRequiredPathParameterInvalidValue}) +} diff --git a/cmd/api-firewall/tests/main_graphql_test.go b/cmd/api-firewall/tests/main_graphql_test.go index 41e1357..f2a81e6 100644 --- a/cmd/api-firewall/tests/main_graphql_test.go +++ b/cmd/api-firewall/tests/main_graphql_test.go @@ -115,6 +115,8 @@ func TestGraphQLBasic(t *testing.T) { // basic test t.Run("basicGraphQLQuerySuccess", apifwTests.testGQLSuccess) + t.Run("basicGraphQLEndpointNotExists", apifwTests.testGQLEndpointNotExists) + t.Run("basicGraphQLGETQuerySuccess", apifwTests.testGQLGETSuccess) t.Run("basicGraphQLGETQueryMutationFailed", apifwTests.testGQLGETMutationFailed) t.Run("basicGraphQLQueryValidationFailed", apifwTests.testGQLValidationFailed) @@ -228,6 +230,89 @@ func (s *ServiceGraphQLTests) testGQLSuccess(t *testing.T) { } +func (s *ServiceGraphQLTests) testGQLEndpointNotExists(t *testing.T) { + + gqlCfg := config.GraphQL{ + MaxQueryComplexity: 0, + MaxQueryDepth: 0, + NodeCountLimit: 0, + Playground: false, + Introspection: false, + Schema: "", + RequestValidation: "BLOCK", + } + var cfg = config.GraphQLMode{ + Graphql: gqlCfg, + } + + // parse the GraphQL schema + schema, err := graphql.NewSchemaFromString(testSchema) + if err != nil { + t.Fatalf("Loading GraphQL Schema error: %v", err) + } + + handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil, nil) + + // Construct GraphQL request payload + query := ` + query { + room(name: "GeneralChat") { + name + messages { + id + text + createdBy + createdAt + } + } +} + ` + var requestBody = map[string]interface{}{ + "query": query, + } + + responseBody := `{ + "data": { + "room": { + "name": "GeneralChat", + "messages": [ + { + "id": "TrsXJcKa", + "text": "Hello, world!", + "createdBy": "TestUser", + "createdAt": "2023-01-01T00:00:00+00:00" + } + ] + } + } +}` + + jsonValue, _ := json.Marshal(requestBody) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/endpointNotExists") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(jsonValue), -1) + req.Header.SetContentType("application/json") + + resp := fasthttp.AcquireResponse() + resp.SetStatusCode(fasthttp.StatusOK) + resp.Header.SetContentType("application/json") + resp.SetBody([]byte(responseBody)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + +} + func (s *ServiceGraphQLTests) testGQLGETSuccess(t *testing.T) { gqlCfg := config.GraphQL{ diff --git a/cmd/api-firewall/tests/main_test.go b/cmd/api-firewall/tests/main_test.go index 9a7f451..1f691ce 100644 --- a/cmd/api-firewall/tests/main_test.go +++ b/cmd/api-firewall/tests/main_test.go @@ -217,6 +217,38 @@ paths: 200: description: Ok content: { } + /path/{test}: + get: + parameters: + - name: test + in: path + required: true + schema: + type: string + enum: + - testValue1 + - testValue1 + summary: Get Test Info + responses: + 200: + description: Ok + content: { } + /path/{test}.php: + get: + parameters: + - name: test + in: path + required: true + schema: + type: string + enum: + - value1 + - value2 + summary: Get Test Info + responses: + 200: + description: Ok + content: { } /user: get: summary: Get User Info @@ -413,6 +445,8 @@ func TestBasic(t *testing.T) { } // basic test + t.Run("basicCustomBlockStatusCode", apifwTests.testCustomBlockStatusCode) + t.Run("basicPathNotExists", apifwTests.testPathNotExists) t.Run("basicBlockBlockMode", apifwTests.testBlockMode) t.Run("basicLogOnlyLogOnlyMode", apifwTests.testLogOnlyMode) t.Run("basicDisableDisableMode", apifwTests.testDisableMode) @@ -450,6 +484,205 @@ func TestBasic(t *testing.T) { t.Run("unknownParamPostBody", apifwTests.unknownParamPostBody) t.Run("unknownParamJSONParam", apifwTests.unknownParamJSONParam) t.Run("unknownParamInvalidMimeType", apifwTests.unknownParamUnsupportedMimeType) + + t.Run("testConflictPaths", apifwTests.testConflictPaths) +} + +func (s *ServiceTests) testCustomBlockStatusCode(t *testing.T) { + + var cfg = config.ProxyMode{ + RequestValidation: "BLOCK", + ResponseValidation: "BLOCK", + CustomBlockStatusCode: 403, + AddValidationStatusHeader: false, + ShadowAPI: config.ShadowAPI{ + ExcludeList: []int{404, 401}, + }, + } + + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, nil, nil) + + p, err := json.Marshal(map[string]interface{}{ + "firstname": "test", + "lastname": "test", + "job": "test", + "email": "test@wallarm.com", + "url": "http://wallarm.com", + }) + + if err != nil { + t.Fatal(err) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/signupNotExist") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + + resp := fasthttp.AcquireResponse() + resp.SetStatusCode(fasthttp.StatusOK) + resp.Header.SetContentType("application/json") + resp.SetBody([]byte("{\"status\":\"success\"}")) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != cfg.CustomBlockStatusCode { + t.Errorf("Incorrect response status code. Expected: %d and got %d", + cfg.CustomBlockStatusCode, reqCtx.Response.StatusCode()) + } + + // Repeat request with new Custom block status code + cfg = config.ProxyMode{ + RequestValidation: "BLOCK", + ResponseValidation: "BLOCK", + CustomBlockStatusCode: 401, + AddValidationStatusHeader: false, + ShadowAPI: config.ShadowAPI{ + ExcludeList: []int{404, 401}, + }, + } + + handler = proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, nil, nil) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != cfg.CustomBlockStatusCode { + t.Errorf("Incorrect response status code. Expected: %d and got %d", + cfg.CustomBlockStatusCode, reqCtx.Response.StatusCode()) + } + +} + +func (s *ServiceTests) testPathNotExists(t *testing.T) { + + var cfg = config.ProxyMode{ + RequestValidation: "BLOCK", + ResponseValidation: "BLOCK", + CustomBlockStatusCode: 403, + AddValidationStatusHeader: false, + ShadowAPI: config.ShadowAPI{ + ExcludeList: []int{404, 401}, + }, + } + + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, nil, nil) + + p, err := json.Marshal(map[string]interface{}{ + "firstname": "test", + "lastname": "test", + "job": "test", + "email": "test@wallarm.com", + "url": "http://wallarm.com", + }) + + if err != nil { + t.Fatal(err) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/signupNotExist") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + + resp := fasthttp.AcquireResponse() + resp.SetStatusCode(fasthttp.StatusOK) + resp.Header.SetContentType("application/json") + resp.SetBody([]byte("{\"status\":\"success\"}")) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != cfg.CustomBlockStatusCode { + t.Errorf("Incorrect response status code. Expected: %d and got %d", + cfg.CustomBlockStatusCode, reqCtx.Response.StatusCode()) + } + + req = fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod("TRACE") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != cfg.CustomBlockStatusCode { + t.Errorf("Incorrect response status code. Expected: %d and got %d", + cfg.CustomBlockStatusCode, reqCtx.Response.StatusCode()) + } + + // Repeat request with new Custom block status code + cfg = config.ProxyMode{ + RequestValidation: "LOG_ONLY", + ResponseValidation: "LOG_ONLY", + CustomBlockStatusCode: 403, + AddValidationStatusHeader: false, + ShadowAPI: config.ShadowAPI{ + ExcludeList: []int{404, 401}, + }, + } + + handler = proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, nil, nil) + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + // Repeat request with new Custom block status code + cfg = config.ProxyMode{ + RequestValidation: "DISABLE", + ResponseValidation: "DISABLE", + CustomBlockStatusCode: 403, + AddValidationStatusHeader: false, + ShadowAPI: config.ShadowAPI{ + ExcludeList: []int{404, 401}, + }, + } + + handler = proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, nil, nil) + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + } func (s *ServiceTests) testBlockMode(t *testing.T) { @@ -847,12 +1080,7 @@ func (s *ServiceTests) testShadowAPI(t *testing.T) { }{Tokens: tokensCfg}, } - deniedTokens, err := denylist.New(&cfg.Denylist, s.logger) - if err != nil { - t.Fatal(err) - } - - handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, deniedTokens, nil, nil) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, nil, nil) p, err := json.Marshal(map[string]interface{}{ "firstname": "test", @@ -2471,3 +2699,62 @@ func (s *ServiceTests) unknownParamUnsupportedMimeType(t *testing.T) { } } + +func (s *ServiceTests) testConflictPaths(t *testing.T) { + + var cfg = config.ProxyMode{ + RequestValidation: "BLOCK", + ResponseValidation: "BLOCK", + CustomBlockStatusCode: 403, + AddValidationStatusHeader: false, + ShadowAPI: config.ShadowAPI{ + ExcludeList: []int{404, 401}, + }, + } + + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, nil, nil) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/path/testValue1") + req.Header.SetMethod("GET") + + resp := fasthttp.AcquireResponse() + resp.SetStatusCode(fasthttp.StatusOK) + resp.Header.SetContentType("application/json") + resp.SetBody([]byte("{\"status\":\"success\"}")) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + req = fasthttp.AcquireRequest() + req.SetRequestURI("/path/value1.php") + req.Header.SetMethod("GET") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + +} diff --git a/go.mod b/go.mod index 57a0498..ad2bd34 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,6 @@ require ( github.com/clbanning/mxj/v2 v2.7.0 github.com/corazawaf/coraza/v3 v3.1.0 github.com/dgraph-io/ristretto v0.1.1 - github.com/fasthttp/router v1.5.0 github.com/fasthttp/websocket v1.5.8 github.com/gabriel-vasile/mimetype v1.4.3 github.com/getkin/kin-openapi v0.118.0 @@ -16,7 +15,6 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang/mock v1.6.0 github.com/google/uuid v1.6.0 - github.com/invopop/yaml v0.2.0 github.com/karlseguin/ccache/v2 v2.0.8 github.com/klauspost/compress v1.17.7 github.com/mattn/go-sqlite3 v1.14.22 @@ -52,6 +50,7 @@ require ( github.com/hashicorp/golang-lru v0.5.4 // indirect github.com/huandu/xstrings v1.2.1 // indirect github.com/imdario/mergo v0.3.8 // indirect + github.com/invopop/yaml v0.2.0 // indirect github.com/jensneuse/abstractlogger v0.0.4 // indirect github.com/jensneuse/byte-template v0.0.0-20200214152254-4f3cf06e5c68 // indirect github.com/jensneuse/pipeline v0.0.0-20200117120358-9fb4de085cd6 // indirect diff --git a/go.sum b/go.sum index 20e8922..81f4cfc 100644 --- a/go.sum +++ b/go.sum @@ -40,8 +40,6 @@ github.com/eclipse/paho.mqtt.golang v1.2.0 h1:1F8mhG9+aO5/xpdtFkW4SxOJB67ukuDC3t github.com/eclipse/paho.mqtt.golang v1.2.0/go.mod h1:H9keYFcgq3Qr5OUJm/JZI/i6U7joQ8SYLhZwfeOo6Ts= github.com/evanphx/json-patch/v5 v5.1.0 h1:B0aXl1o/1cP8NbviYiBMkcHBtUjIJ1/Ccg6b+SwCLQg= github.com/evanphx/json-patch/v5 v5.1.0/go.mod h1:G79N1coSVB93tBe7j6PhzjmR3/2VvlbKOFpnXhI9Bw4= -github.com/fasthttp/router v1.5.0 h1:3Qbbo27HAPzwbpRzgiV5V9+2faPkPt3eNuRaDV6LYDA= -github.com/fasthttp/router v1.5.0/go.mod h1:FddcKNXFZg1imHcy+uKB0oo/o6yE9zD3wNguqlhWDak= github.com/fasthttp/websocket v1.5.8 h1:k5DpirKkftIF/w1R8ZzjSgARJrs54Je9YJK37DL/Ah8= github.com/fasthttp/websocket v1.5.8/go.mod h1:d08g8WaT6nnyvg9uMm8K9zMYyDjfKyj3170AtPRuVU0= github.com/foxcpp/go-mockdns v1.1.0 h1:jI0rD8M0wuYAxL7r/ynTrCQQq0BVqfB99Vgk7DlmewI= diff --git a/internal/mid/allowiplist.go b/internal/mid/allowiplist.go index 01dd802..fcefae2 100644 --- a/internal/mid/allowiplist.go +++ b/internal/mid/allowiplist.go @@ -10,6 +10,7 @@ import ( "github.com/valyala/fasthttp" "github.com/wallarm/api-firewall/internal/config" "github.com/wallarm/api-firewall/internal/platform/allowiplist" + "github.com/wallarm/api-firewall/internal/platform/router" "github.com/wallarm/api-firewall/internal/platform/web" ) @@ -27,7 +28,7 @@ var errAccessDeniedIP = errors.New("access denied to this IP") func IPAllowlist(options *IPAllowListOptions) web.Middleware { // This is the actual middleware function to be executed. - m := func(before web.Handler) web.Handler { + m := func(before router.Handler) router.Handler { // Create the handler that will be attached in the middleware chain. h := func(ctx *fasthttp.RequestCtx) error { diff --git a/internal/mid/denylist.go b/internal/mid/denylist.go index b06fa25..5d99e4a 100644 --- a/internal/mid/denylist.go +++ b/internal/mid/denylist.go @@ -8,6 +8,7 @@ import ( "github.com/valyala/fasthttp" "github.com/wallarm/api-firewall/internal/config" "github.com/wallarm/api-firewall/internal/platform/denylist" + "github.com/wallarm/api-firewall/internal/platform/router" "github.com/wallarm/api-firewall/internal/platform/web" ) @@ -25,7 +26,7 @@ var errAccessDenied = errors.New("access denied") func Denylist(options *DenylistOptions) web.Middleware { // This is the actual middleware function to be executed. - m := func(before web.Handler) web.Handler { + m := func(before router.Handler) router.Handler { // Create the handler that will be attached in the middleware chain. h := func(ctx *fasthttp.RequestCtx) error { diff --git a/internal/mid/errors.go b/internal/mid/errors.go index 32a965e..1b39b62 100644 --- a/internal/mid/errors.go +++ b/internal/mid/errors.go @@ -3,6 +3,7 @@ package mid import ( "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" + "github.com/wallarm/api-firewall/internal/platform/router" "github.com/wallarm/api-firewall/internal/platform/web" ) @@ -12,7 +13,7 @@ import ( func Errors(logger *logrus.Logger) web.Middleware { // This is the actual middleware function to be executed. - m := func(before web.Handler) web.Handler { + m := func(before router.Handler) router.Handler { // Create the handler that will be attached in the middleware chain. h := func(ctx *fasthttp.RequestCtx) error { diff --git a/internal/mid/logger.go b/internal/mid/logger.go index 1d22243..85197c7 100644 --- a/internal/mid/logger.go +++ b/internal/mid/logger.go @@ -6,6 +6,7 @@ import ( "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" + "github.com/wallarm/api-firewall/internal/platform/router" "github.com/wallarm/api-firewall/internal/platform/web" ) @@ -14,7 +15,7 @@ import ( func Logger(logger *logrus.Logger) web.Middleware { // This is the actual middleware function to be executed. - m := func(before web.Handler) web.Handler { + m := func(before router.Handler) router.Handler { // Create the handler that will be attached in the middleware chain. h := func(ctx *fasthttp.RequestCtx) error { diff --git a/internal/mid/mimetype.go b/internal/mid/mimetype.go index bcb6fca..74cddf4 100644 --- a/internal/mid/mimetype.go +++ b/internal/mid/mimetype.go @@ -4,6 +4,7 @@ import ( "github.com/gabriel-vasile/mimetype" "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" + "github.com/wallarm/api-firewall/internal/platform/router" "github.com/wallarm/api-firewall/internal/platform/web" ) @@ -11,7 +12,7 @@ import ( func MIMETypeIdentifier(logger *logrus.Logger) web.Middleware { // This is the actual middleware function to be executed. - m := func(before web.Handler) web.Handler { + m := func(before router.Handler) router.Handler { // Create the handler that will be attached in the middleware chain. h := func(ctx *fasthttp.RequestCtx) error { diff --git a/internal/mid/modsec.go b/internal/mid/modsec.go index 2ec029b..4e5636c 100644 --- a/internal/mid/modsec.go +++ b/internal/mid/modsec.go @@ -15,6 +15,7 @@ import ( utils "github.com/savsgio/gotils/strconv" "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" + "github.com/wallarm/api-firewall/internal/platform/router" "github.com/wallarm/api-firewall/internal/platform/web" ) @@ -100,7 +101,7 @@ func processRequest(tx types.Transaction, ctx *fasthttp.RequestCtx) (*types.Inte func WAFModSecurity(options *ModSecurityOptions) web.Middleware { // This is the actual middleware function to be executed. - m := func(before web.Handler) web.Handler { + m := func(before router.Handler) router.Handler { // Create the handler that will be attached in the middleware chain. h := func(ctx *fasthttp.RequestCtx) error { diff --git a/internal/mid/panics.go b/internal/mid/panics.go index c67f778..57f21a4 100644 --- a/internal/mid/panics.go +++ b/internal/mid/panics.go @@ -6,6 +6,7 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" + "github.com/wallarm/api-firewall/internal/platform/router" "github.com/wallarm/api-firewall/internal/platform/web" ) @@ -14,7 +15,7 @@ import ( func Panics(logger *logrus.Logger) web.Middleware { // This is the actual middleware function to be executed. - m := func(after web.Handler) web.Handler { + m := func(after router.Handler) router.Handler { // Create the handler that will be attached in the middleware chain. h := func(ctx *fasthttp.RequestCtx) (err error) { diff --git a/internal/mid/proxy.go b/internal/mid/proxy.go index 353f633..c1b25e1 100644 --- a/internal/mid/proxy.go +++ b/internal/mid/proxy.go @@ -9,6 +9,7 @@ import ( "github.com/savsgio/gotils/strconv" "github.com/valyala/fasthttp" + "github.com/wallarm/api-firewall/internal/platform/router" "github.com/wallarm/api-firewall/internal/platform/web" ) @@ -45,7 +46,7 @@ type ProxyOptions struct { func Proxy(options *ProxyOptions) web.Middleware { // This is the actual middleware function to be executed. - m := func(before web.Handler) web.Handler { + m := func(before router.Handler) router.Handler { // Create the handler that will be attached in the middleware chain. h := func(ctx *fasthttp.RequestCtx) error { diff --git a/internal/mid/shadowAPI.go b/internal/mid/shadowAPI.go index 871a10a..a8c7dc2 100644 --- a/internal/mid/shadowAPI.go +++ b/internal/mid/shadowAPI.go @@ -6,6 +6,7 @@ import ( "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" "github.com/wallarm/api-firewall/internal/config" + "github.com/wallarm/api-firewall/internal/platform/router" "github.com/wallarm/api-firewall/internal/platform/web" "golang.org/x/exp/slices" ) @@ -15,7 +16,7 @@ import ( func ShadowAPIMonitor(logger *logrus.Logger, cfg *config.ShadowAPI) web.Middleware { // This is the actual middleware function to be executed. - m := func(before web.Handler) web.Handler { + m := func(before router.Handler) router.Handler { // Create the handler that will be attached in the middleware chain. h := func(ctx *fasthttp.RequestCtx) error { @@ -48,6 +49,7 @@ func ShadowAPIMonitor(logger *logrus.Logger, cfg *config.ShadowAPI) web.Middlewa // check response status code statusCode := ctx.Response.StatusCode() idx := slices.IndexFunc(cfg.ExcludeList, func(c int) bool { return c == statusCode }) + // if response status code not found in the OpenAPI spec AND the code not in the exclude list if isProxyStatusCodeNotFound && idx < 0 { logger.WithFields(logrus.Fields{ diff --git a/internal/platform/router/chi.go b/internal/platform/router/chi.go index 53193f0..71e1a87 100644 --- a/internal/platform/router/chi.go +++ b/internal/platform/router/chi.go @@ -1,7 +1,5 @@ package router -import "github.com/wallarm/api-firewall/internal/platform/web" - // NewRouter returns a new Mux object that implements the Router interface. func NewRouter() *Mux { return NewMux() @@ -14,7 +12,7 @@ type Router interface { // AddEndpoint adds routes for `pattern` that matches // the `method` HTTP method. - AddEndpoint(method, pattern string, handler web.Handler) error + AddEndpoint(method, pattern string, handler Handler) error } // Routes interface adds two methods for router traversal, which is also @@ -26,5 +24,5 @@ type Routes interface { // Find searches the routing tree for a handler that matches // the method/path - similar to routing a http request, but without // executing the handler thereafter. - Find(rctx *Context, method, path string) web.Handler + Find(rctx *Context, method, path string) Handler } diff --git a/internal/platform/router/handler.go b/internal/platform/router/handler.go new file mode 100644 index 0000000..eb63535 --- /dev/null +++ b/internal/platform/router/handler.go @@ -0,0 +1,7 @@ +package router + +import "github.com/valyala/fasthttp" + +// A Handler is a type that handles an http request within our own little mini +// framework. +type Handler func(ctx *fasthttp.RequestCtx) error diff --git a/internal/platform/router/mux.go b/internal/platform/router/mux.go index 6d1e5f5..0e9d47a 100644 --- a/internal/platform/router/mux.go +++ b/internal/platform/router/mux.go @@ -3,8 +3,6 @@ package router import ( "fmt" "strings" - - "github.com/wallarm/api-firewall/internal/platform/web" ) var _ Router = &Mux{} @@ -26,7 +24,7 @@ func NewMux() *Mux { // AddEndpoint adds the route `pattern` that matches `method` http method to // execute the `handler` web.Handler. -func (mx *Mux) AddEndpoint(method, pattern string, handler web.Handler) error { +func (mx *Mux) AddEndpoint(method, pattern string, handler Handler) error { m, ok := methodMap[strings.ToUpper(method)] if !ok { return fmt.Errorf("'%s' http method is not supported", method) @@ -45,7 +43,7 @@ func (mx *Mux) Routes() []Route { return mx.tree.routes() } -func (mx *Mux) Find(rctx *Context, method, path string) web.Handler { +func (mx *Mux) Find(rctx *Context, method, path string) Handler { m, ok := methodMap[method] if !ok { return nil @@ -63,7 +61,7 @@ func (mx *Mux) Find(rctx *Context, method, path string) web.Handler { // handle registers a web.Handler in the routing tree for a particular http method // and routing pattern. -func (mx *Mux) handle(method methodTyp, pattern string, handler web.Handler) (*node, error) { +func (mx *Mux) handle(method methodTyp, pattern string, handler Handler) (*node, error) { if len(pattern) == 0 || pattern[0] != '/' { return nil, fmt.Errorf("routing pattern must begin with '/' in '%s'", pattern) } diff --git a/internal/platform/router/mux_test.go b/internal/platform/router/mux_test.go index 5f0917b..bbb886a 100644 --- a/internal/platform/router/mux_test.go +++ b/internal/platform/router/mux_test.go @@ -3,11 +3,11 @@ package router import ( "bytes" "fmt" - "github.com/valyala/fasthttp" - "github.com/wallarm/api-firewall/internal/platform/web" "io" "net/http" "testing" + + "github.com/valyala/fasthttp" ) func TestMuxBasic(t *testing.T) { @@ -314,7 +314,7 @@ func TestMuxRegexp3(t *testing.T) { } func TestMuxSubrouterWildcardParam(t *testing.T) { - h := web.Handler(func(ctx *fasthttp.RequestCtx) error { + h := Handler(func(ctx *fasthttp.RequestCtx) error { ctx.SetBody([]byte(fmt.Sprintf("param:%v *:%v", URLParam(ctx, "param"), URLParam(ctx, "*")))) return nil }) diff --git a/internal/platform/router/tree.go b/internal/platform/router/tree.go index 93b662f..3bfaf5a 100644 --- a/internal/platform/router/tree.go +++ b/internal/platform/router/tree.go @@ -12,7 +12,6 @@ import ( "strings" "github.com/valyala/fasthttp" - "github.com/wallarm/api-firewall/internal/platform/web" ) type methodTyp uint @@ -108,7 +107,7 @@ type endpoints map[methodTyp]*endpoint type endpoint struct { // endpoint handler - handler web.Handler + handler Handler // pattern is the routing pattern for handler nodes pattern string @@ -126,7 +125,7 @@ func (s endpoints) Value(method methodTyp) *endpoint { return mh } -func (n *node) InsertRoute(method methodTyp, pattern string, handler web.Handler) (*node, error) { +func (n *node) InsertRoute(method methodTyp, pattern string, handler Handler) (*node, error) { var parent *node search := pattern @@ -362,7 +361,7 @@ func (n *node) getEdge(ntyp nodeTyp, label, tail byte, prefix string) *node { return nil } -func (n *node) setEndpoint(method methodTyp, handler web.Handler, pattern string) error { +func (n *node) setEndpoint(method methodTyp, handler Handler, pattern string) error { // Set the handler for the method type on the node if n.endpoints == nil { n.endpoints = make(endpoints) @@ -396,7 +395,7 @@ func (n *node) setEndpoint(method methodTyp, handler web.Handler, pattern string return nil } -func (n *node) FindRoute(rctx *Context, method methodTyp, path string) (*node, endpoints, web.Handler) { +func (n *node) FindRoute(rctx *Context, method methodTyp, path string) (*node, endpoints, Handler) { // Reset the context routing pattern and params rctx.routePattern = "" rctx.routeParams.Keys = rctx.routeParams.Keys[:0] @@ -668,7 +667,7 @@ func (n *node) routes() []Route { } for p, mh := range pats { - hs := make(map[string]web.Handler) + hs := make(map[string]Handler) if mh[mALL] != nil && mh[mALL].handler != nil { hs["*"] = mh[mALL].handler } @@ -872,21 +871,21 @@ func (ns nodes) findEdge(label byte) *node { // Handlers map key is an HTTP method type Route struct { SubRoutes Routes - Handlers map[string]web.Handler + Handlers map[string]Handler Pattern string } // WalkFunc is the type of the function called for each method and route visited by Walk. -type WalkFunc func(method string, route string, handler web.Handler, middlewares ...func(web.Handler) web.Handler) error +type WalkFunc func(method string, route string, handler Handler, middlewares ...func(Handler) Handler) error // Walk walks any router tree that implements Routes interface. func Walk(r Routes, walkFn WalkFunc) error { return walk(r, walkFn, "") } -func walk(r Routes, walkFn WalkFunc, parentRoute string, parentMw ...func(web.Handler) web.Handler) error { +func walk(r Routes, walkFn WalkFunc, parentRoute string, parentMw ...func(Handler) Handler) error { for _, route := range r.Routes() { - mws := make([]func(web.Handler) web.Handler, len(parentMw)) + mws := make([]func(Handler) Handler, len(parentMw)) copy(mws, parentMw) if route.SubRoutes != nil { diff --git a/internal/platform/router/tree_test.go b/internal/platform/router/tree_test.go index 6ef0e0d..671e033 100644 --- a/internal/platform/router/tree_test.go +++ b/internal/platform/router/tree_test.go @@ -6,31 +6,30 @@ import ( "testing" "github.com/valyala/fasthttp" - "github.com/wallarm/api-firewall/internal/platform/web" ) func TestTree(t *testing.T) { - hStub := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hIndex := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hFavicon := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hArticleList := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hArticleNear := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hArticleShow := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hArticleShowRelated := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hArticleShowOpts := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hArticleSlug := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hArticleByUser := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hUserList := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hUserShow := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hAdminCatchall := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hAdminAppShow := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hAdminAppShowCatchall := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hUserProfile := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hUserSuper := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hUserAll := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hHubView1 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hHubView2 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hHubView3 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hIndex := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hFavicon := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hArticleList := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hArticleNear := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hArticleShow := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hArticleShowRelated := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hArticleShowOpts := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hArticleSlug := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hArticleByUser := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hUserList := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hUserShow := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hAdminCatchall := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hAdminAppShow := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hAdminAppShowCatchall := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hUserProfile := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hUserSuper := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hUserAll := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hHubView1 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hHubView2 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hHubView3 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) tr := &node{} @@ -140,10 +139,10 @@ func TestTree(t *testing.T) { } tests := []struct { - r string // input request path - h web.Handler // output matched handler - k []string // output param keys - v []string // output param values + r string // input request path + h Handler // output matched handler + k []string // output param keys + v []string // output param values }{ {r: "/", h: hIndex, k: []string{}, v: []string{}}, {r: "/favicon.ico", h: hFavicon, k: []string{}, v: []string{}}, @@ -193,7 +192,7 @@ func TestTree(t *testing.T) { _, handlers, _ := tr.FindRoute(rctx, mGET, tt.r) - var handler web.Handler + var handler Handler if methodHandler, ok := handlers[mGET]; ok { handler = methodHandler.handler } @@ -214,23 +213,23 @@ func TestTree(t *testing.T) { } func TestTreeMoar(t *testing.T) { - hStub := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub1 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub2 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub3 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub4 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub5 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub6 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub7 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub8 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub9 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub10 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub11 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub12 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub13 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub14 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub15 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub16 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub1 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub2 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub3 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub4 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub5 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub6 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub7 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub8 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub9 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub10 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub11 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub12 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub13 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub14 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub15 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub16 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) // TODO: panic if we see {id}{x} because we're missing a delimiter, its not possible. // also {:id}* is not possible. @@ -294,7 +293,7 @@ func TestTreeMoar(t *testing.T) { tr.InsertRoute(mGET, "/users/{id}/settings/*", hStub16) tests := []struct { - h web.Handler + h Handler r string k []string v []string @@ -339,7 +338,7 @@ func TestTreeMoar(t *testing.T) { _, handlers, _ := tr.FindRoute(rctx, tt.m, tt.r) - var handler web.Handler + var handler Handler if methodHandler, ok := handlers[tt.m]; ok { handler = methodHandler.handler } @@ -360,13 +359,13 @@ func TestTreeMoar(t *testing.T) { } func TestTreeRegexp(t *testing.T) { - hStub1 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub2 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub3 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub4 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub5 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub6 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub7 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub1 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub2 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub3 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub4 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub5 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub6 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub7 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) tr := &node{} if _, err := tr.InsertRoute(mGET, "/articles/{rid:^[0-9]{5,6}}", hStub7); err != nil { @@ -398,10 +397,10 @@ func TestTreeRegexp(t *testing.T) { // log.Println("~~~~~~~~~") tests := []struct { - r string // input request path - h web.Handler // output matched handler - k []string // output param keys - v []string // output param values + r string // input request path + h Handler // output matched handler + k []string // output param keys + v []string // output param values }{ {r: "/articles", h: nil, k: []string{}, v: []string{}}, {r: "/articles/12345", h: hStub7, k: []string{"rid"}, v: []string{"12345"}}, @@ -419,7 +418,7 @@ func TestTreeRegexp(t *testing.T) { _, handlers, _ := tr.FindRoute(rctx, mGET, tt.r) - var handler web.Handler + var handler Handler if methodHandler, ok := handlers[mGET]; ok { handler = methodHandler.handler } @@ -440,8 +439,8 @@ func TestTreeRegexp(t *testing.T) { } func TestTreeRegexpRecursive(t *testing.T) { - hStub1 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub2 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub1 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub2 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) tr := &node{} if _, err := tr.InsertRoute(mGET, "/one/{firstId:[a-z0-9-]+}/{secondId:[a-z0-9-]+}/first", hStub1); err != nil { @@ -458,10 +457,10 @@ func TestTreeRegexpRecursive(t *testing.T) { // log.Println("~~~~~~~~~") tests := []struct { - r string // input request path - h web.Handler // output matched handler - k []string // output param keys - v []string // output param values + r string // input request path + h Handler // output matched handler + k []string // output param keys + v []string // output param values }{ {r: "/one/hello/world/first", h: hStub1, k: []string{"firstId", "secondId"}, v: []string{"hello", "world"}}, {r: "/one/hi_there/ok/second", h: hStub2, k: []string{"firstId", "secondId"}, v: []string{"hi_there", "ok"}}, @@ -474,7 +473,7 @@ func TestTreeRegexpRecursive(t *testing.T) { _, handlers, _ := tr.FindRoute(rctx, mGET, tt.r) - var handler web.Handler + var handler Handler if methodHandler, ok := handlers[mGET]; ok { handler = methodHandler.handler } @@ -495,7 +494,7 @@ func TestTreeRegexpRecursive(t *testing.T) { } func TestTreeRegexMatchWholeParam(t *testing.T) { - hStub1 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub1 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) rctx := NewRouteContext() tr := &node{} @@ -510,7 +509,7 @@ func TestTreeRegexMatchWholeParam(t *testing.T) { } tests := []struct { - expectedHandler web.Handler + expectedHandler Handler url string }{ {url: "/13", expectedHandler: hStub1}, @@ -531,9 +530,9 @@ func TestTreeRegexMatchWholeParam(t *testing.T) { } func TestTreeFindPattern(t *testing.T) { - hStub1 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub2 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - hStub3 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub1 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub2 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + hStub3 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) tr := &node{} if _, err := tr.InsertRoute(mGET, "/pages/*", hStub1); err != nil { @@ -602,8 +601,8 @@ func stringSliceEqual(a, b []string) bool { } func BenchmarkTreeGet(b *testing.B) { - h1 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) - h2 := web.Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + h1 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) + h2 := Handler(func(ctx *fasthttp.RequestCtx) error { return nil }) tr := &node{} if _, err := tr.InsertRoute(mGET, "/", h1); err != nil { diff --git a/internal/platform/web/adaptor.go b/internal/platform/web/adaptor.go index 3eb5188..ccc2271 100644 --- a/internal/platform/web/adaptor.go +++ b/internal/platform/web/adaptor.go @@ -6,6 +6,7 @@ import ( "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttpadaptor" + "github.com/wallarm/api-firewall/internal/platform/router" ) // NewFastHTTPHandler wraps net/http handler to fasthttp request handler, @@ -24,7 +25,7 @@ import ( // So it is advisable using this function only for quick net/http -> fasthttp // switching. Then manually convert net/http handlers to fasthttp handlers // according to https://github.com/valyala/fasthttp#switching-from-nethttp-to-fasthttp . -func NewFastHTTPHandler(h http.Handler, isPlayground bool) Handler { +func NewFastHTTPHandler(h http.Handler, isPlayground bool) router.Handler { return func(ctx *fasthttp.RequestCtx) error { var r http.Request if err := fasthttpadaptor.ConvertRequest(ctx, &r, true); err != nil { diff --git a/internal/platform/web/middleware.go b/internal/platform/web/middleware.go index 2490869..f8df3e9 100644 --- a/internal/platform/web/middleware.go +++ b/internal/platform/web/middleware.go @@ -1,14 +1,16 @@ package web +import "github.com/wallarm/api-firewall/internal/platform/router" + // Middleware is a function designed to run some code before and/or after // another Handler. It is designed to remove boilerplate or other concerns not // direct to any given Handler. -type Middleware func(Handler) Handler +type Middleware func(router.Handler) router.Handler // WrapMiddleware creates a new handler by wrapping middleware around a final // handler. The middlewares' Handlers will be executed by requests in the order // they are provided. -func WrapMiddleware(mw []Middleware, handler Handler) Handler { +func WrapMiddleware(mw []Middleware, handler router.Handler) router.Handler { // Loop backwards through the middleware invoking each one. Replace the // handler with the new wrapped handler. Looping backwards ensures that the diff --git a/internal/platform/web/web.go b/internal/platform/web/web.go index 09d9a5e..95e764a 100644 --- a/internal/platform/web/web.go +++ b/internal/platform/web/web.go @@ -2,12 +2,13 @@ package web import ( "bytes" + "github.com/wallarm/api-firewall/internal/platform/router" "os" - "strings" + "runtime/debug" "syscall" - "github.com/fasthttp/router" "github.com/google/uuid" + "github.com/savsgio/gotils/strconv" "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" ) @@ -23,7 +24,7 @@ const ( ValidationBlock = "block" ValidationLog = "log_only" - RequestProxyNoChecks = "proxy_request_no_checks" + PassRequestOPTIONS = "proxy_request_with_options_method" RequestProxyFailed = "proxy_failed" RequestProxyNoRoute = "proxy_no_route" RequestBlocked = "request_blocked" @@ -38,15 +39,11 @@ const ( RequestID = "__wallarm_apifw_request_id" ) -// A Handler is a type that handles an http request within our own little mini -// framework. -type Handler func(ctx *fasthttp.RequestCtx) error - // App is the entrypoint into our application and what configures our context // object for each of our http handlers. Feel free to add any configuration // data/logic on this App struct type App struct { - Router *router.Router + Router *router.Mux Log *logrus.Logger shutdown chan os.Signal mw []Middleware @@ -60,70 +57,26 @@ type AppAdditionalOptions struct { ResponseValidation string CustomBlockStatusCode int OptionsHandler fasthttp.RequestHandler -} - -func (a *App) SetDefaultBehavior(handler Handler, mw ...Middleware) { - // First wrap handler specific middleware around this handler. - handler = WrapMiddleware(mw, handler) - - // Add the application's general middleware to the handler chain. - handler = WrapMiddleware(a.mw, handler) - - customHandler := func(ctx *fasthttp.RequestCtx) { - - // Add request ID - ctx.SetUserValue(RequestID, uuid.NewString()) - - // Block request if it's not found in the route. Not for API mode. - if strings.EqualFold(a.Options.Mode, ProxyMode) { - if strings.EqualFold(a.Options.RequestValidation, ValidationBlock) || strings.EqualFold(a.Options.ResponseValidation, ValidationBlock) { - - ctx.Error("", a.Options.CustomBlockStatusCode) - - a.Log.WithFields(logrus.Fields{ - "request_id": ctx.UserValue(RequestID), - "method": bytes.NewBuffer(ctx.Request.Header.Method()).String(), - "host": string(ctx.Request.Header.Host()), - "path": string(ctx.Path()), - "client_address": ctx.RemoteAddr(), - }).Info("Path or method not found: request blocked") - } - } - - if err := handler(ctx); err != nil { - a.SignalShutdown() - return - } - - } - - //Set NOT FOUND behavior - a.Router.NotFound = customHandler - - // Set Method Not Allowed behavior - a.Router.MethodNotAllowed = customHandler + DefaultHandler router.Handler } // NewApp creates an App value that handle a set of routes for the application. func NewApp(options *AppAdditionalOptions, shutdown chan os.Signal, logger *logrus.Logger, mw ...Middleware) *App { app := App{ - Router: router.New(), + Router: router.NewRouter(), shutdown: shutdown, mw: mw, Log: logger, Options: options, } - app.Router.HandleOPTIONS = options.PassOptions - app.Router.GlobalOPTIONS = options.OptionsHandler - return &app } // Handle is our mechanism for mounting Handlers for a given HTTP verb and path // pair, this makes for really easy, convenient routing. -func (a *App) Handle(method string, path string, handler Handler, mw ...Middleware) { +func (a *App) Handle(method string, path string, handler router.Handler, mw ...Middleware) error { // First wrap handler specific middleware around this handler. handler = WrapMiddleware(mw, handler) @@ -132,24 +85,121 @@ func (a *App) Handle(method string, path string, handler Handler, mw ...Middlewa handler = WrapMiddleware(a.mw, handler) // The function to execute for each request. - h := func(ctx *fasthttp.RequestCtx) { + h := func(ctx *fasthttp.RequestCtx) error { // Add request ID ctx.SetUserValue(RequestID, uuid.NewString()) if err := handler(ctx); err != nil { a.SignalShutdown() - return + return err } + + return nil } - if method == AnyMethod { - a.Router.ANY(path, h) + // Add this handler for the specified verb and route. + //a.Router.Handle(method, path, h) + if err := a.Router.AddEndpoint(method, path, h); err != nil { + return err + } + + return nil +} + +// MainHandler routes request to the OpenAPI validator (handler) +func (a *App) MainHandler(ctx *fasthttp.RequestCtx) { + + // handle panic + defer func() { + if r := recover(); r != nil { + a.Log.Errorf("panic: %v", r) + + // Log the Go stack trace for this panic'd goroutine. + a.Log.Debugf("%s", debug.Stack()) + return + } + }() + + // Add request ID + ctx.SetUserValue(RequestID, uuid.NewString()) + + // find the handler with the OAS information + rctx := router.NewRouteContext() + handler := a.Router.Find(rctx, strconv.B2S(ctx.Method()), strconv.B2S(ctx.Request.URI().Path())) + + // handler not found in the OAS + if handler == nil { + + // OPTIONS methods are passed if the passOPTIONS is set to true + if a.Options.PassOptions == true && strconv.B2S(ctx.Method()) == fasthttp.MethodOptions { + + ctx.SetUserValue(PassRequestOPTIONS, true) + + a.Log.WithFields(logrus.Fields{ + "host": strconv.B2S(ctx.Request.Header.Host()), + "path": strconv.B2S(ctx.Path()), + "method": strconv.B2S(ctx.Request.Header.Method()), + "request_id": ctx.UserValue(RequestID), + }).Debug("Pass request with OPTIONS method") + + // proxy request if passOptions flag is set to true and request method is OPTIONS + if err := a.Options.DefaultHandler(ctx); err != nil { + a.Log.WithFields(logrus.Fields{ + "error": err, + "host": strconv.B2S(ctx.Request.Header.Host()), + "path": strconv.B2S(ctx.Path()), + "method": strconv.B2S(ctx.Request.Header.Method()), + "request_id": ctx.UserValue(RequestID), + }).Error("Error in the request handler") + } + return + } + + a.Log.WithFields(logrus.Fields{ + "request_id": ctx.UserValue(RequestID), + "method": bytes.NewBuffer(ctx.Request.Header.Method()).String(), + "host": string(ctx.Request.Header.Host()), + "path": string(ctx.Path()), + "client_address": ctx.RemoteAddr(), + }).Info("Path or method not found") + + // block request if the GraphQL endpoint not found + if a.Options.Mode == GraphQLMode { + RespondError(ctx, fasthttp.StatusForbidden, "") + return + } + + // handle request by default handler in the endpoint not found in Proxy mode + // Default handler is used to handle request and response validation logic + if err := a.Options.DefaultHandler(ctx); err != nil { + a.Log.WithFields(logrus.Fields{ + "error": err, + "host": strconv.B2S(ctx.Request.Header.Host()), + "path": strconv.B2S(ctx.Path()), + "method": strconv.B2S(ctx.Request.Header.Method()), + "request_id": ctx.UserValue(RequestID), + }).Error("Error in the request handler") + } + return } - // Add this handler for the specified verb and route. - a.Router.Handle(method, path, h) + // add router context to get URL params in the Handler + ctx.SetUserValue(router.RouteCtxKey, rctx) + + if err := handler(ctx); err != nil { + a.Log.WithFields(logrus.Fields{ + "error": err, + "host": strconv.B2S(ctx.Request.Header.Host()), + "path": strconv.B2S(ctx.Path()), + "method": strconv.B2S(ctx.Request.Header.Method()), + "request_id": ctx.UserValue(RequestID), + }).Error("Error in the request handler") + } + + // delete Allow header which is set by the router + ctx.Response.Header.Del(fasthttp.HeaderAllow) } // SignalShutdown is used to gracefully shutdown the app when an integrity From 15cf0e9c8cef784f840a0b8b0090f4daec33d4bc Mon Sep 17 00:00:00 2001 From: Nikolay Tkachenko Date: Sun, 14 Apr 2024 20:44:49 +0300 Subject: [PATCH 08/12] Fix query issues --- .../internal/handlers/api/openapi.go | 9 +- .../internal/handlers/proxy/openapi.go | 7 +- .../internal/updater/wallarm_api2_update.db | Bin 98304 -> 98304 bytes go.mod | 7 +- go.sum | 27 +- internal/platform/loader/router.go | 2 +- internal/platform/validator/internal.go | 8 +- internal/platform/validator/issue201_test.go | 145 +++++++++ internal/platform/validator/issue436_test.go | 135 +++++++++ internal/platform/validator/issue624_test.go | 73 +++++ internal/platform/validator/issue625_test.go | 127 ++++++++ internal/platform/validator/issue639_test.go | 105 +++++++ internal/platform/validator/issue641_test.go | 111 +++++++ internal/platform/validator/issue707_test.go | 93 ++++++ internal/platform/validator/issue722_test.go | 136 +++++++++ internal/platform/validator/issue733_test.go | 111 +++++++ internal/platform/validator/issue789_test.go | 128 ++++++++ internal/platform/validator/issue884_test.go | 116 ++++++++ .../platform/validator/req_resp_decoder.go | 277 +++++++++++++++++- .../validator/req_resp_decoder_test.go | 2 +- .../platform/validator/validate_request.go | 43 +-- .../validator/validate_request_test.go | 38 ++- .../platform/validator/validate_response.go | 4 +- internal/platform/web/web.go | 3 +- 24 files changed, 1631 insertions(+), 76 deletions(-) create mode 100644 internal/platform/validator/issue201_test.go create mode 100644 internal/platform/validator/issue436_test.go create mode 100644 internal/platform/validator/issue624_test.go create mode 100644 internal/platform/validator/issue625_test.go create mode 100644 internal/platform/validator/issue639_test.go create mode 100644 internal/platform/validator/issue641_test.go create mode 100644 internal/platform/validator/issue707_test.go create mode 100644 internal/platform/validator/issue722_test.go create mode 100644 internal/platform/validator/issue733_test.go create mode 100644 internal/platform/validator/issue789_test.go create mode 100644 internal/platform/validator/issue884_test.go diff --git a/cmd/api-firewall/internal/handlers/api/openapi.go b/cmd/api-firewall/internal/handlers/api/openapi.go index 416990e..334526f 100644 --- a/cmd/api-firewall/internal/handlers/api/openapi.go +++ b/cmd/api-firewall/internal/handlers/api/openapi.go @@ -143,10 +143,11 @@ func (s *RequestValidator) Handler(ctx *fasthttp.RequestCtx) error { // Validate request requestValidationInput := &openapi3filter.RequestValidationInput{ - Request: &req, - PathParams: pathParams, - Route: s.CustomRoute.Route, - Options: apiModeSecurityRequirementsOptions, + Request: &req, + PathParams: pathParams, + Route: s.CustomRoute.Route, + QueryParams: req.URL.Query(), + Options: apiModeSecurityRequirementsOptions, } var wg sync.WaitGroup diff --git a/cmd/api-firewall/internal/handlers/proxy/openapi.go b/cmd/api-firewall/internal/handlers/proxy/openapi.go index 93e1061..a20ad3c 100644 --- a/cmd/api-firewall/internal/handlers/proxy/openapi.go +++ b/cmd/api-firewall/internal/handlers/proxy/openapi.go @@ -182,9 +182,10 @@ func (s *openapiWaf) openapiWafHandler(ctx *fasthttp.RequestCtx) error { // Validate request requestValidationInput := &openapi3filter.RequestValidationInput{ - Request: &req, - PathParams: pathParams, - Route: s.customRoute.Route, + Request: &req, + PathParams: pathParams, + Route: s.customRoute.Route, + QueryParams: req.URL.Query(), Options: &openapi3filter.Options{ AuthenticationFunc: func(ctx context.Context, input *openapi3filter.AuthenticationInput) error { switch input.SecurityScheme.Type { diff --git a/cmd/api-firewall/internal/updater/wallarm_api2_update.db b/cmd/api-firewall/internal/updater/wallarm_api2_update.db index d524d1e3d82d849a6c01456d923f8d94983e44bb..c4bfe885d8cb04730960a1e156f5450576f592de 100644 GIT binary patch delta 34 qcmZo@U~6b#n;^|Nd7_LnCNn|I5)2r=GiGFWE77ytmsw+l=F delta 34 qcmZo@U~6b#n;^~DI8nx#v2kNUlq_S|=H0R`LX6j%43-%%1^@uX4+|Us diff --git a/go.mod b/go.mod index ad2bd34..07e1cca 100644 --- a/go.mod +++ b/go.mod @@ -10,13 +10,13 @@ require ( github.com/dgraph-io/ristretto v0.1.1 github.com/fasthttp/websocket v1.5.8 github.com/gabriel-vasile/mimetype v1.4.3 - github.com/getkin/kin-openapi v0.118.0 + github.com/getkin/kin-openapi v0.123.0 github.com/go-playground/validator v9.31.0+incompatible github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang/mock v1.6.0 github.com/google/uuid v1.6.0 github.com/karlseguin/ccache/v2 v2.0.8 - github.com/klauspost/compress v1.17.7 + github.com/klauspost/compress v1.17.8 github.com/mattn/go-sqlite3 v1.14.22 github.com/pkg/errors v0.9.1 github.com/savsgio/gotils v0.0.0-20240303185622-093b76447511 @@ -26,7 +26,7 @@ require ( github.com/valyala/fastjson v1.6.4 github.com/wundergraph/graphql-go-tools v1.67.2 golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 - golang.org/x/sync v0.6.0 + golang.org/x/sync v0.7.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -70,7 +70,6 @@ require ( github.com/petar-dambovaliev/aho-corasick v0.0.0-20230725210150-fb29fc3c913e // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/r3labs/sse/v2 v2.8.1 // indirect - github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/santhosh-tekuri/jsonschema/v5 v5.3.0 // indirect github.com/tidwall/gjson v1.17.1 // indirect github.com/tidwall/match v1.1.1 // indirect diff --git a/go.sum b/go.sum index 81f4cfc..58feb97 100644 --- a/go.sum +++ b/go.sum @@ -46,16 +46,14 @@ github.com/foxcpp/go-mockdns v1.1.0 h1:jI0rD8M0wuYAxL7r/ynTrCQQq0BVqfB99Vgk7Dlme github.com/foxcpp/go-mockdns v1.1.0/go.mod h1:IhLeSFGed3mJIAXPH2aiRQB+kqz7oqu8ld2qVbOu7Wk= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= -github.com/getkin/kin-openapi v0.118.0 h1:z43njxPmJ7TaPpMSCQb7PN0dEYno4tyBPQcrFdHoLuM= -github.com/getkin/kin-openapi v0.118.0/go.mod h1:l5e9PaFUo9fyLJCPGQeXI2ML8c3P8BHOEV2VaAVf/pc= +github.com/getkin/kin-openapi v0.123.0 h1:zIik0mRwFNLyvtXK274Q6ut+dPh6nlxBp0x7mNrPhs8= +github.com/getkin/kin-openapi v0.123.0/go.mod h1:wb1aSZA/iWmorQP9KTAS/phLj/t17B5jT7+fS8ed9NM= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14= github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= -github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/jsonpointer v0.20.2 h1:mQc3nmndL8ZBzStEo3JYF8wzmeWffDH4VbXz58sAx6Q= github.com/go-openapi/jsonpointer v0.20.2/go.mod h1:bHen+N0u1KEO3YlmqOjTT9Adn1RfD91Ar825/PuiRVs= -github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= github.com/go-openapi/swag v0.22.8 h1:/9RjDSQ0vbFR+NyjGMkFTsA1IA0fmhKSThmfGZjicbw= github.com/go-openapi/swag v0.22.8/go.mod h1:6QT22icPLEqAM/z/TChgb4WAveCHF92+2gF0CNjHpPI= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= @@ -96,7 +94,6 @@ github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm4 github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= @@ -108,7 +105,6 @@ github.com/huandu/xstrings v1.2.1 h1:v6IdmkCnDhJG/S0ivr58PeIfg+tyhqQYy4YsCsQ0Pdc github.com/huandu/xstrings v1.2.1/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/imdario/mergo v0.3.8 h1:CGgOkSJeqMRmt0D9XLWExdT4m4F1vd3FV3VPt+0VxkQ= github.com/imdario/mergo v0.3.8/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA= -github.com/invopop/yaml v0.1.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q= github.com/invopop/yaml v0.2.0 h1:7zky/qH+O0DwAyoobXUqvVBwgBFRxKoQ/3FjcVpjTMY= github.com/invopop/yaml v0.2.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q= github.com/jensneuse/abstractlogger v0.0.4 h1:sa4EH8fhWk3zlTDbSncaWKfwxYM8tYSlQ054ETLyyQY= @@ -129,8 +125,8 @@ github.com/karlseguin/expect v1.0.2-0.20190806010014-778a5f0c6003 h1:vJ0Snvo+SLM github.com/karlseguin/expect v1.0.2-0.20190806010014-778a5f0c6003/go.mod h1:zNBxMY8P21owkeogJELCLeHIt+voOSduHYTFUbwRAV8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= -github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg= -github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= +github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -145,8 +141,6 @@ github.com/logrusorgru/aurora/v3 v3.0.0 h1:R6zcoZZbvVcGMvDCKo45A9U/lzYyzl5NfYIvz github.com/logrusorgru/aurora/v3 v3.0.0/go.mod h1:vsR12bk5grlLvLXAYrBsb5Oc/N+LxAlxggSjiwMnCUc= github.com/magefile/mage v1.15.0 h1:BvGheCMAsG3bWUDbZ8AyXXpCNwU9u5CB6sM+HNb9HYg= github.com/magefile/mage v1.15.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= -github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= -github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= @@ -183,7 +177,6 @@ github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI= github.com/nats-io/nkeys v0.4.7/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDmGD0nc= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= -github.com/perimeterx/marshmallow v1.1.4/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= github.com/petar-dambovaliev/aho-corasick v0.0.0-20230725210150-fb29fc3c913e h1:POJco99aNgosh92lGqmx7L1ei+kCymivB/419SD15PQ= @@ -211,15 +204,10 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= @@ -231,9 +219,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.0.4 h1:UcdIRXff12Lpnu3OLtZvnc03g4vH2suXDXhBwBqmzYg= github.com/tidwall/sjson v1.0.4/go.mod h1:bURseu1nuBkFpIES5cz6zBtjmYeOQmEESshn7VpF15Y= +github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= -github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= -github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= @@ -289,8 +276,8 @@ golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190130150945-aca44879d564/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/internal/platform/loader/router.go b/internal/platform/loader/router.go index 072a5f8..e204a6b 100644 --- a/internal/platform/loader/router.go +++ b/internal/platform/loader/router.go @@ -34,7 +34,7 @@ func NewRouter(doc *openapi3.T, validate bool) (*Router, error) { var router Router - for path, pathItem := range doc.Paths { + for path, pathItem := range doc.Paths.Map() { for method, operation := range pathItem.Operations() { method = strings.ToUpper(method) route := routers.Route{ diff --git a/internal/platform/validator/internal.go b/internal/platform/validator/internal.go index 588b5c1..6344bae 100644 --- a/internal/platform/validator/internal.go +++ b/internal/platform/validator/internal.go @@ -1,6 +1,7 @@ package validator import ( + "encoding/json" "reflect" "strings" @@ -54,12 +55,7 @@ func convertToMap(v *fastjson.Value) interface{} { } return a case fastjson.TypeNumber: - valueInt := v.GetInt64() - valueFloat := v.GetFloat64() - if valueFloat == float64(int(valueFloat)) { - return valueInt - } - return valueFloat + return json.Number(v.String()) case fastjson.TypeString: return string(v.GetStringBytes()) case fastjson.TypeTrue, fastjson.TypeFalse: diff --git a/internal/platform/validator/issue201_test.go b/internal/platform/validator/issue201_test.go new file mode 100644 index 0000000..c5d2983 --- /dev/null +++ b/internal/platform/validator/issue201_test.go @@ -0,0 +1,145 @@ +package validator + +import ( + "context" + "io" + "net/http" + "strings" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers/gorillamux" + "github.com/stretchr/testify/require" + "github.com/valyala/fastjson" +) + +func TestIssue201(t *testing.T) { + loader := openapi3.NewLoader() + ctx := loader.Context + spec := ` +openapi: '3.0.3' +info: + version: 1.0.0 + title: Sample API +paths: + /_: + get: + description: '' + responses: + default: + description: '' + content: + application/json: + schema: + type: object + headers: + X-Blip: + description: '' + required: true + schema: + type: string + pattern: '^blip$' + x-blop: + description: '' + schema: + type: string + pattern: '^blop$' + X-Blap: + description: '' + required: true + schema: + type: string + pattern: '^blap$' + X-Blup: + description: '' + required: true + schema: + type: string + pattern: '^blup$' +`[1:] + + doc, err := loader.LoadFromData([]byte(spec)) + require.NoError(t, err) + + err = doc.Validate(ctx) + require.NoError(t, err) + + for name, testcase := range map[string]struct { + headers map[string]string + err string + }{ + + "no error": { + headers: map[string]string{ + "X-Blip": "blip", + "x-blop": "blop", + "X-Blap": "blap", + "X-Blup": "blup", + }, + }, + + "missing non-required header": { + headers: map[string]string{ + "X-Blip": "blip", + // "x-blop": "blop", + "X-Blap": "blap", + "X-Blup": "blup", + }, + }, + + "missing required header": { + err: `response header "X-Blip" missing`, + headers: map[string]string{ + // "X-Blip": "blip", + "x-blop": "blop", + "X-Blap": "blap", + "X-Blup": "blup", + }, + }, + + "invalid required header": { + err: `response header "X-Blup" doesn't match schema: string doesn't match the regular expression "^blup$"`, + headers: map[string]string{ + "X-Blip": "blip", + "x-blop": "blop", + "X-Blap": "blap", + "X-Blup": "bluuuuuup", + }, + }, + } { + t.Run(name, func(t *testing.T) { + router, err := gorillamux.NewRouter(doc) + require.NoError(t, err) + + r, err := http.NewRequest(http.MethodGet, `/_`, nil) + require.NoError(t, err) + + r.Header.Add(headerCT, "application/json") + for k, v := range testcase.headers { + r.Header.Add(k, v) + } + + route, pathParams, err := router.FindRoute(r) + require.NoError(t, err) + + jsonParser := &fastjson.Parser{} + + err = ValidateResponse(context.Background(), &openapi3filter.ResponseValidationInput{ + RequestValidationInput: &openapi3filter.RequestValidationInput{ + Request: r, + PathParams: pathParams, + Route: route, + }, + Status: 200, + Header: r.Header, + Body: io.NopCloser(strings.NewReader(`{}`)), + }, jsonParser) + if e := testcase.err; e != "" { + require.ErrorContains(t, err, e) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/internal/platform/validator/issue436_test.go b/internal/platform/validator/issue436_test.go new file mode 100644 index 0000000..952e6a0 --- /dev/null +++ b/internal/platform/validator/issue436_test.go @@ -0,0 +1,135 @@ +package validator + +import ( + "bytes" + "context" + "io" + "mime/multipart" + "net/http" + "net/textproto" + "strings" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers/gorillamux" +) + +func Example_validateMultipartFormData() { + const spec = ` +openapi: 3.0.0 +info: + title: 'Validator' + version: 0.0.1 +paths: + /test: + post: + requestBody: + required: true + content: + multipart/form-data: + schema: + type: object + required: + - file + properties: + file: + type: string + format: binary + categories: + type: array + items: + $ref: "#/components/schemas/Category" + responses: + '200': + description: Created + +components: + schemas: + Category: + type: object + properties: + name: + type: string + required: + - name +` + + loader := openapi3.NewLoader() + doc, err := loader.LoadFromData([]byte(spec)) + if err != nil { + panic(err) + } + if err = doc.Validate(loader.Context); err != nil { + panic(err) + } + + router, err := gorillamux.NewRouter(doc) + if err != nil { + panic(err) + } + + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + { // Add a single "categories" item as part data + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", `form-data; name="categories"`) + h.Set("Content-Type", "application/json") + fw, err := writer.CreatePart(h) + if err != nil { + panic(err) + } + if _, err = io.Copy(fw, strings.NewReader(`{"name": "foo"}`)); err != nil { + panic(err) + } + } + + { // Add a single "categories" item as part data, again + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", `form-data; name="categories"`) + h.Set("Content-Type", "application/json") + fw, err := writer.CreatePart(h) + if err != nil { + panic(err) + } + if _, err = io.Copy(fw, strings.NewReader(`{"name": "bar"}`)); err != nil { + panic(err) + } + } + + { // Add file data + fw, err := writer.CreateFormFile("file", "hello.txt") + if err != nil { + panic(err) + } + if _, err = io.Copy(fw, strings.NewReader("hello")); err != nil { + panic(err) + } + } + + writer.Close() + + req, err := http.NewRequest(http.MethodPost, "/test", bytes.NewReader(body.Bytes())) + if err != nil { + panic(err) + } + req.Header.Set("Content-Type", writer.FormDataContentType()) + + route, pathParams, err := router.FindRoute(req) + if err != nil { + panic(err) + } + + if err = openapi3filter.ValidateRequestBody( + context.Background(), + &openapi3filter.RequestValidationInput{ + Request: req, + PathParams: pathParams, + Route: route, + }, + route.Operation.RequestBody.Value, + ); err != nil { + panic(err) + } + // Output: +} diff --git a/internal/platform/validator/issue624_test.go b/internal/platform/validator/issue624_test.go new file mode 100644 index 0000000..8f0ed32 --- /dev/null +++ b/internal/platform/validator/issue624_test.go @@ -0,0 +1,73 @@ +package validator + +import ( + "net/http" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers/gorillamux" + "github.com/stretchr/testify/require" + "github.com/valyala/fastjson" +) + +func TestIssue624(t *testing.T) { + loader := openapi3.NewLoader() + ctx := loader.Context + spec := ` +openapi: 3.0.0 +info: + version: 1.0.0 + title: Sample API +paths: + /items: + get: + description: Returns a list of stuff + parameters: + - description: "test non object" + explode: true + style: form + in: query + name: test + required: false + content: + application/json: + schema: + anyOf: + - type: string + - type: integer + responses: + '200': + description: Successful response +`[1:] + + doc, err := loader.LoadFromData([]byte(spec)) + require.NoError(t, err) + + err = doc.Validate(ctx) + require.NoError(t, err) + + router, err := gorillamux.NewRouter(doc) + require.NoError(t, err) + + for _, testcase := range []string{`test1`, `test[1`} { + t.Run(testcase, func(t *testing.T) { + httpReq, err := http.NewRequest(http.MethodGet, `/items?test=`+testcase, nil) + require.NoError(t, err) + + route, pathParams, err := router.FindRoute(httpReq) + require.NoError(t, err) + + requestValidationInput := &openapi3filter.RequestValidationInput{ + Request: httpReq, + PathParams: pathParams, + Route: route, + } + + jsonParser := &fastjson.Parser{} + + err = ValidateRequest(ctx, requestValidationInput, jsonParser) + require.NoError(t, err) + }) + } +} diff --git a/internal/platform/validator/issue625_test.go b/internal/platform/validator/issue625_test.go new file mode 100644 index 0000000..fcf4e26 --- /dev/null +++ b/internal/platform/validator/issue625_test.go @@ -0,0 +1,127 @@ +package validator + +import ( + "github.com/valyala/fastjson" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers/gorillamux" +) + +func TestIssue625(t *testing.T) { + + anyOfArraySpec := ` +openapi: 3.0.0 +info: + version: 1.0.0 + title: Sample API +paths: + /items: + get: + description: Returns a list of stuff + parameters: + - description: test object + explode: false + in: query + name: test + required: false + schema: + type: array + items: + anyOf: + - type: integer + - type: boolean + responses: + '200': + description: Successful response +`[1:] + + oneOfArraySpec := strings.ReplaceAll(anyOfArraySpec, "anyOf", "oneOf") + + allOfArraySpec := strings.ReplaceAll(strings.ReplaceAll(anyOfArraySpec, "anyOf", "allOf"), + "type: boolean", "type: number") + + tests := []struct { + name string + spec string + req string + errStr string + }{ + { + name: "success anyof object array", + spec: anyOfArraySpec, + req: "/items?test=3,7", + }, + { + name: "failed anyof object array", + spec: anyOfArraySpec, + req: "/items?test=s1,s2", + errStr: `parameter "test" in query has an error: path 0: value s1: an invalid boolean: invalid syntax`, + }, + + { + name: "success allof object array", + spec: allOfArraySpec, + req: `/items?test=1,3`, + }, + { + name: "failed allof object array", + spec: allOfArraySpec, + req: `/items?test=1.2,3.1`, + errStr: `parameter "test" in query has an error: path 0: value 1.2: an invalid integer: invalid syntax`, + }, + { + name: "success oneof object array", + spec: oneOfArraySpec, + req: `/items?test=true,3`, + }, + { + name: "failed oneof object array", + spec: oneOfArraySpec, + req: `/items?test="val1","val2"`, + errStr: `parameter "test" in query has an error: item 0: decoding oneOf failed: 0 schemas matched`, + }, + } + + for _, testcase := range tests { + t.Run(testcase.name, func(t *testing.T) { + loader := openapi3.NewLoader() + ctx := loader.Context + + doc, err := loader.LoadFromData([]byte(testcase.spec)) + require.NoError(t, err) + + err = doc.Validate(ctx) + require.NoError(t, err) + + router, err := gorillamux.NewRouter(doc) + require.NoError(t, err) + httpReq, err := http.NewRequest(http.MethodGet, testcase.req, nil) + require.NoError(t, err) + + route, pathParams, err := router.FindRoute(httpReq) + require.NoError(t, err) + + requestValidationInput := &openapi3filter.RequestValidationInput{ + Request: httpReq, + PathParams: pathParams, + Route: route, + } + + jsonParser := &fastjson.Parser{} + + err = ValidateRequest(ctx, requestValidationInput, jsonParser) + if testcase.errStr == "" { + require.NoError(t, err) + } else { + require.ErrorContains(t, err, testcase.errStr) + } + }, + ) + } +} diff --git a/internal/platform/validator/issue639_test.go b/internal/platform/validator/issue639_test.go new file mode 100644 index 0000000..5ec7cf0 --- /dev/null +++ b/internal/platform/validator/issue639_test.go @@ -0,0 +1,105 @@ +package validator + +import ( + "encoding/json" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/valyala/fastjson" + "io" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/routers/gorillamux" +) + +func TestIssue639(t *testing.T) { + loader := openapi3.NewLoader() + ctx := loader.Context + spec := ` + openapi: 3.0.0 + info: + version: 1.0.0 + title: Sample API + paths: + /items: + put: + requestBody: + content: + application/json: + schema: + properties: + testWithdefault: + default: false + type: boolean + testNoDefault: + type: boolean + type: object + responses: + '200': + description: Successful response +`[1:] + + doc, err := loader.LoadFromData([]byte(spec)) + require.NoError(t, err) + + err = doc.Validate(ctx) + require.NoError(t, err) + + router, err := gorillamux.NewRouter(doc) + require.NoError(t, err) + + tests := []struct { + name string + options *openapi3filter.Options + expectedDefaultVal interface{} + }{ + { + name: "no defaults are added to requests", + options: &openapi3filter.Options{ + SkipSettingDefaults: true, + }, + expectedDefaultVal: nil, + }, + + { + name: "defaults are added to requests", + expectedDefaultVal: false, + }, + } + + for _, testcase := range tests { + t.Run(testcase.name, func(t *testing.T) { + body := "{\"testNoDefault\": true}" + httpReq, err := http.NewRequest(http.MethodPut, "/items", strings.NewReader(body)) + require.NoError(t, err) + httpReq.Header.Set("Content-Type", "application/json") + require.NoError(t, err) + + route, pathParams, err := router.FindRoute(httpReq) + require.NoError(t, err) + + requestValidationInput := &openapi3filter.RequestValidationInput{ + Request: httpReq, + PathParams: pathParams, + Route: route, + Options: testcase.options, + } + + jsonParser := &fastjson.Parser{} + + err = ValidateRequest(ctx, requestValidationInput, jsonParser) + require.NoError(t, err) + bodyAfterValidation, err := io.ReadAll(httpReq.Body) + require.NoError(t, err) + + raw := map[string]interface{}{} + err = json.Unmarshal(bodyAfterValidation, &raw) + require.NoError(t, err) + require.Equal(t, testcase.expectedDefaultVal, + raw["testWithdefault"], "default value must not be included") + }) + } +} diff --git a/internal/platform/validator/issue641_test.go b/internal/platform/validator/issue641_test.go new file mode 100644 index 0000000..90ac3ea --- /dev/null +++ b/internal/platform/validator/issue641_test.go @@ -0,0 +1,111 @@ +package validator + +import ( + "net/http" + "strings" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers/gorillamux" + "github.com/stretchr/testify/require" + "github.com/valyala/fastjson" +) + +func TestIssue641(t *testing.T) { + + anyOfSpec := ` +openapi: 3.0.0 +info: + version: 1.0.0 + title: Sample API +paths: + /items: + get: + description: Returns a list of stuff + parameters: + - description: test object + explode: false + in: query + name: test + required: false + schema: + anyOf: + - pattern: "^[0-9]{1,4}$" + - pattern: "^[0-9]{1,4}$" + type: string + responses: + '200': + description: Successful response +`[1:] + + allOfSpec := strings.ReplaceAll(anyOfSpec, "anyOf", "allOf") + + tests := []struct { + name string + spec string + req string + errStr string + }{ + + { + name: "success anyof pattern", + spec: anyOfSpec, + req: "/items?test=51", + }, + { + name: "failed anyof pattern", + spec: anyOfSpec, + req: "/items?test=999999", + errStr: `parameter "test" in query has an error: doesn't match any schema from "anyOf"`, + }, + + { + name: "success allof pattern", + spec: allOfSpec, + req: `/items?test=51`, + }, + { + name: "failed allof pattern", + spec: allOfSpec, + req: `/items?test=999999`, + errStr: `parameter "test" in query has an error: string doesn't match the regular expression "^[0-9]{1,4}$"`, + }, + } + + for _, testcase := range tests { + t.Run(testcase.name, func(t *testing.T) { + loader := openapi3.NewLoader() + ctx := loader.Context + + doc, err := loader.LoadFromData([]byte(testcase.spec)) + require.NoError(t, err) + + err = doc.Validate(ctx) + require.NoError(t, err) + + router, err := gorillamux.NewRouter(doc) + require.NoError(t, err) + httpReq, err := http.NewRequest(http.MethodGet, testcase.req, nil) + require.NoError(t, err) + + route, pathParams, err := router.FindRoute(httpReq) + require.NoError(t, err) + + requestValidationInput := &openapi3filter.RequestValidationInput{ + Request: httpReq, + PathParams: pathParams, + Route: route, + } + + jsonParser := &fastjson.Parser{} + err = ValidateRequest(ctx, requestValidationInput, jsonParser) + if testcase.errStr == "" { + require.NoError(t, err) + } else { + require.ErrorContains(t, err, testcase.errStr) + } + }, + ) + } +} diff --git a/internal/platform/validator/issue707_test.go b/internal/platform/validator/issue707_test.go new file mode 100644 index 0000000..61d1bc1 --- /dev/null +++ b/internal/platform/validator/issue707_test.go @@ -0,0 +1,93 @@ +package validator + +import ( + "net/http" + "strings" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers/gorillamux" + "github.com/stretchr/testify/require" + "github.com/valyala/fastjson" +) + +func TestIssue707(t *testing.T) { + loader := openapi3.NewLoader() + ctx := loader.Context + spec := ` +openapi: 3.0.0 +info: + version: 1.0.0 + title: Sample API +paths: + /items: + get: + description: Returns a list of stuff + parameters: + - description: parameter with a default value + explode: true + in: query + name: param-with-default + schema: + default: 124 + type: integer + required: false + responses: + '200': + description: Successful response +`[1:] + + doc, err := loader.LoadFromData([]byte(spec)) + require.NoError(t, err) + + err = doc.Validate(ctx) + require.NoError(t, err) + + router, err := gorillamux.NewRouter(doc) + require.NoError(t, err) + + tests := []struct { + name string + options *openapi3filter.Options + expectedQuery string + }{ + { + name: "no defaults are added to requests parameters", + options: &openapi3filter.Options{ + SkipSettingDefaults: true, + }, + expectedQuery: "", + }, + + { + name: "defaults are added to requests", + expectedQuery: "param-with-default=124", + }, + } + + for _, testcase := range tests { + t.Run(testcase.name, func(t *testing.T) { + httpReq, err := http.NewRequest(http.MethodGet, "/items", strings.NewReader("")) + require.NoError(t, err) + + route, pathParams, err := router.FindRoute(httpReq) + require.NoError(t, err) + + requestValidationInput := &openapi3filter.RequestValidationInput{ + Request: httpReq, + PathParams: pathParams, + Route: route, + Options: testcase.options, + } + + jsonParser := &fastjson.Parser{} + err = ValidateRequest(ctx, requestValidationInput, jsonParser) + require.NoError(t, err) + + require.NoError(t, err) + require.Equal(t, testcase.expectedQuery, + httpReq.URL.RawQuery, "default value must not be included") + }) + } +} diff --git a/internal/platform/validator/issue722_test.go b/internal/platform/validator/issue722_test.go new file mode 100644 index 0000000..db009a6 --- /dev/null +++ b/internal/platform/validator/issue722_test.go @@ -0,0 +1,136 @@ +package validator + +import ( + "bytes" + "context" + "io" + "mime/multipart" + "net/http" + "net/textproto" + "strings" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers/gorillamux" + "github.com/valyala/fastjson" +) + +func TestValidateMultipartFormDataContainingAllOf(t *testing.T) { + const spec = ` +openapi: 3.0.0 +info: + title: 'Validator' + version: 0.0.1 +paths: + /test: + post: + requestBody: + required: true + content: + multipart/form-data: + schema: + type: object + required: + - file + allOf: + - $ref: '#/components/schemas/Category' + - properties: + file: + type: string + format: binary + description: + type: string + responses: + '200': + description: Created + +components: + schemas: + Category: + type: object + properties: + name: + type: string + required: + - name +` + + loader := openapi3.NewLoader() + doc, err := loader.LoadFromData([]byte(spec)) + if err != nil { + t.Fatal(err) + } + if err = doc.Validate(loader.Context); err != nil { + t.Fatal(err) + } + + router, err := gorillamux.NewRouter(doc) + if err != nil { + t.Fatal(err) + } + + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + { // Add file data + fw, err := writer.CreateFormFile("file", "hello.txt") + if err != nil { + t.Fatal(err) + } + if _, err = io.Copy(fw, strings.NewReader("hello")); err != nil { + t.Fatal(err) + } + } + + { // Add a single "name" item as part data + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", `form-data; name="name"`) + fw, err := writer.CreatePart(h) + if err != nil { + t.Fatal(err) + } + if _, err = io.Copy(fw, strings.NewReader(`foo`)); err != nil { + t.Fatal(err) + } + } + + { // Add a single "description" item as part data + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", `form-data; name="description"`) + fw, err := writer.CreatePart(h) + if err != nil { + t.Fatal(err) + } + if _, err = io.Copy(fw, strings.NewReader(`description note`)); err != nil { + t.Fatal(err) + } + } + + writer.Close() + + req, err := http.NewRequest(http.MethodPost, "/test", bytes.NewReader(body.Bytes())) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", writer.FormDataContentType()) + + route, pathParams, err := router.FindRoute(req) + if err != nil { + t.Fatal(err) + } + + jsonParser := &fastjson.Parser{} + if err = ValidateRequestBody( + context.Background(), + &openapi3filter.RequestValidationInput{ + Request: req, + PathParams: pathParams, + Route: route, + }, + route.Operation.RequestBody.Value, + jsonParser, + ); err != nil { + t.Error(err) + } +} diff --git a/internal/platform/validator/issue733_test.go b/internal/platform/validator/issue733_test.go new file mode 100644 index 0000000..e2a28d9 --- /dev/null +++ b/internal/platform/validator/issue733_test.go @@ -0,0 +1,111 @@ +package validator + +import ( + "bytes" + "context" + "encoding/json" + "math" + "math/big" + "net/http" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers/gorillamux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/valyala/fastjson" +) + +func TestIntMax(t *testing.T) { + spec := ` +openapi: 3.0.0 +info: + version: 1.0.0 + title: test large integer value +paths: + /test: + post: + requestBody: + content: + application/json: + schema: + type: object + properties: + testInteger: + type: integer + format: int64 + testDefault: + type: boolean + default: false + responses: + '200': + description: Successful response +`[1:] + + loader := openapi3.NewLoader() + + doc, err := loader.LoadFromData([]byte(spec)) + require.NoError(t, err) + + err = doc.Validate(loader.Context) + require.NoError(t, err) + + router, err := gorillamux.NewRouter(doc) + require.NoError(t, err) + + testOne := func(value *big.Int, pass bool) { + valueString := value.String() + + req, err := http.NewRequest(http.MethodPost, "/test", bytes.NewReader([]byte(`{"testInteger":`+valueString+`}`))) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + route, pathParams, err := router.FindRoute(req) + require.NoError(t, err) + + jsonParser := &fastjson.Parser{} + err = ValidateRequest( + context.Background(), + &openapi3filter.RequestValidationInput{ + Request: req, + PathParams: pathParams, + Route: route, + }, + jsonParser) + if pass { + require.NoError(t, err) + + dec := json.NewDecoder(req.Body) + dec.UseNumber() + var jsonAfter map[string]interface{} + err = dec.Decode(&jsonAfter) + require.NoError(t, err) + + valueAfter := jsonAfter["testInteger"] + require.IsType(t, json.Number(""), valueAfter) + assert.Equal(t, valueString, string(valueAfter.(json.Number))) + } else { + if assert.Error(t, err) { + var serr *openapi3.SchemaError + if assert.ErrorAs(t, err, &serr) { + assert.Equal(t, "number must be an int64", serr.Reason) + } + } + } + } + + bigMaxInt64 := big.NewInt(math.MaxInt64) + bigMaxInt64Plus1 := new(big.Int).Add(bigMaxInt64, big.NewInt(1)) + bigMinInt64 := big.NewInt(math.MinInt64) + bigMinInt64Minus1 := new(big.Int).Sub(bigMinInt64, big.NewInt(1)) + + testOne(bigMaxInt64, true) + // XXX not yet fixed + // testOne(bigMaxInt64Plus1, false) + testOne(bigMaxInt64Plus1, true) + testOne(bigMinInt64, true) + // XXX not yet fixed + // testOne(bigMinInt64Minus1, false) + testOne(bigMinInt64Minus1, true) +} diff --git a/internal/platform/validator/issue789_test.go b/internal/platform/validator/issue789_test.go new file mode 100644 index 0000000..fac7524 --- /dev/null +++ b/internal/platform/validator/issue789_test.go @@ -0,0 +1,128 @@ +package validator + +import ( + "net/http" + "strings" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers/gorillamux" + "github.com/stretchr/testify/require" + "github.com/valyala/fastjson" +) + +func TestIssue789(t *testing.T) { + anyOfArraySpec := ` +openapi: 3.0.0 +info: + version: 1.0.0 + title: Sample API +paths: + /items: + get: + description: Returns a list of stuff + parameters: + - description: test object + explode: false + in: query + name: test + required: true + schema: + type: string + anyOf: + - pattern: '\babc\b' + - pattern: '\bfoo\b' + - pattern: '\bbar\b' + responses: + '200': + description: Successful response +`[1:] + + oneOfArraySpec := strings.ReplaceAll(anyOfArraySpec, "anyOf", "oneOf") + + allOfArraySpec := strings.ReplaceAll(anyOfArraySpec, "anyOf", "allOf") + + tests := []struct { + name string + spec string + req string + errStr string + }{ + { + name: "success anyof string pattern match", + spec: anyOfArraySpec, + req: "/items?test=abc", + }, + { + name: "failed anyof string pattern match", + spec: anyOfArraySpec, + req: "/items?test=def", + errStr: `parameter "test" in query has an error: doesn't match any schema from "anyOf"`, + }, + { + name: "success allof object array", + spec: allOfArraySpec, + req: `/items?test=abc foo bar`, + }, + { + name: "failed allof object array", + spec: allOfArraySpec, + req: `/items?test=foo`, + errStr: `parameter "test" in query has an error: string doesn't match the regular expression`, + }, + { + name: "success oneof string pattern match", + spec: oneOfArraySpec, + req: `/items?test=foo`, + }, + { + name: "failed oneof string pattern match", + spec: oneOfArraySpec, + req: `/items?test=def`, + errStr: `parameter "test" in query has an error: doesn't match schema due to: string doesn't match the regular expression`, + }, + { + name: "failed oneof string pattern match", + spec: oneOfArraySpec, + req: `/items?test=foo bar`, + errStr: `parameter "test" in query has an error: input matches more than one oneOf schemas`, + }, + } + + for _, testcase := range tests { + t.Run(testcase.name, func(t *testing.T) { + loader := openapi3.NewLoader() + ctx := loader.Context + + doc, err := loader.LoadFromData([]byte(testcase.spec)) + require.NoError(t, err) + + err = doc.Validate(ctx) + require.NoError(t, err) + + router, err := gorillamux.NewRouter(doc) + require.NoError(t, err) + httpReq, err := http.NewRequest(http.MethodGet, testcase.req, nil) + require.NoError(t, err) + + route, pathParams, err := router.FindRoute(httpReq) + require.NoError(t, err) + + requestValidationInput := &openapi3filter.RequestValidationInput{ + Request: httpReq, + PathParams: pathParams, + Route: route, + } + + jsonParser := &fastjson.Parser{} + err = ValidateRequest(ctx, requestValidationInput, jsonParser) + if testcase.errStr == "" { + require.NoError(t, err) + } else { + require.ErrorContains(t, err, testcase.errStr) + } + }, + ) + } +} diff --git a/internal/platform/validator/issue884_test.go b/internal/platform/validator/issue884_test.go new file mode 100644 index 0000000..7201d5d --- /dev/null +++ b/internal/platform/validator/issue884_test.go @@ -0,0 +1,116 @@ +package validator + +import ( + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/valyala/fastjson" + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/routers/gorillamux" +) + +func TestIssue884(t *testing.T) { + loader := openapi3.NewLoader() + ctx := loader.Context + spec := ` + openapi: 3.0.0 + info: + version: 1.0.0 + title: Sample API + components: + schemas: + TaskSortEnum: + enum: + - createdAt + - -createdAt + - updatedAt + - -updatedAt + paths: + /tasks: + get: + operationId: ListTask + parameters: + - in: query + name: withDefault + schema: + allOf: + - $ref: '#/components/schemas/TaskSortEnum' + - default: -createdAt + - in: query + name: withoutDefault + schema: + allOf: + - $ref: '#/components/schemas/TaskSortEnum' + - in: query + name: withManyDefaults + schema: + allOf: + - default: -updatedAt + - $ref: '#/components/schemas/TaskSortEnum' + - default: -createdAt + responses: + '200': + description: Successful response + `[1:] + + doc, err := loader.LoadFromData([]byte(spec)) + require.NoError(t, err) + + err = doc.Validate(ctx) + require.NoError(t, err) + + router, err := gorillamux.NewRouter(doc) + require.NoError(t, err) + + tests := []struct { + name string + options *openapi3filter.Options + expectedQuery url.Values + }{ + { + name: "no defaults are added to requests", + options: &openapi3filter.Options{ + SkipSettingDefaults: true, + }, + expectedQuery: url.Values{}, + }, + + { + name: "defaults are added to requests", + expectedQuery: url.Values{ + "withDefault": []string{"-createdAt"}, + "withManyDefaults": []string{"-updatedAt"}, // first default is win + }, + }, + } + + for _, testcase := range tests { + t.Run(testcase.name, func(t *testing.T) { + httpReq, err := http.NewRequest(http.MethodGet, "/tasks", nil) + require.NoError(t, err) + httpReq.Header.Set("Content-Type", "application/json") + require.NoError(t, err) + + route, pathParams, err := router.FindRoute(httpReq) + require.NoError(t, err) + + jsonParser := &fastjson.Parser{} + requestValidationInput := &openapi3filter.RequestValidationInput{ + Request: httpReq, + PathParams: pathParams, + Route: route, + Options: testcase.options, + } + err = ValidateRequest(ctx, requestValidationInput, jsonParser) + require.NoError(t, err) + + q := httpReq.URL.Query() + assert.Equal(t, testcase.expectedQuery, q) + }) + } +} diff --git a/internal/platform/validator/req_resp_decoder.go b/internal/platform/validator/req_resp_decoder.go index 7452865..5083e9a 100644 --- a/internal/platform/validator/req_resp_decoder.go +++ b/internal/platform/validator/req_resp_decoder.go @@ -300,10 +300,8 @@ func decodeValue(dec valueDecoder, param string, sm *openapi3.SerializationMetho isMatched++ } } - if isMatched == 1 { + if isMatched >= 1 { return value, found, nil - } else if isMatched > 1 { - return nil, found, fmt.Errorf("decoding oneOf failed: %d schemas matched", isMatched) } if required { return nil, found, fmt.Errorf("decoding oneOf failed: %q is required", param) @@ -620,6 +618,10 @@ func (d *urlValuesDecoder) parseValue(v string, schema *openapi3.SchemaRef) (int } +const ( + urlDecoderDelimiter = "\x1F" // should not conflict with URL characters +) + func (d *urlValuesDecoder) DecodeObject(param string, sm *openapi3.SerializationMethod, schema *openapi3.SchemaRef) (map[string]interface{}, bool, error) { var propsFn func(url.Values) (map[string]string, error) switch sm.Style { @@ -672,6 +674,11 @@ func (d *urlValuesDecoder) DecodeObject(param string, sm *openapi3.Serialization return nil, false, nil } + val, err := makeObject(props, schema) + if err != nil { + return nil, false, err + } + // check the props found := false for propName := range schema.Value.Properties { @@ -679,8 +686,18 @@ func (d *urlValuesDecoder) DecodeObject(param string, sm *openapi3.Serialization found = true break } + + if schema.Value.Type == "array" || schema.Value.Type == "object" { + for k := range props { + path := strings.Split(k, urlDecoderDelimiter) + if _, ok := deepGet(val, path...); ok { + found = true + break + } + } + } } - val, err := makeObject(props, schema) + return val, found, err } @@ -846,22 +863,258 @@ func propsFromString(src, propDelim, valueDelim string) (map[string]string, erro return props, nil } +func deepGet(m map[string]interface{}, keys ...string) (interface{}, bool) { + for _, key := range keys { + val, ok := m[key] + if !ok { + return nil, false + } + if m, ok = val.(map[string]interface{}); !ok { + return val, true + } + } + return m, true +} + +func deepSet(m map[string]interface{}, keys []string, value interface{}) { + for i := 0; i < len(keys)-1; i++ { + key := keys[i] + if _, ok := m[key]; !ok { + m[key] = make(map[string]interface{}) + } + m = m[key].(map[string]interface{}) + } + m[keys[len(keys)-1]] = value +} + +//func findNestedSchema(parentSchema *openapi3.SchemaRef, keys []string) (*openapi3.SchemaRef, error) { +// currentSchema := parentSchema +// for _, key := range keys { +// if currentSchema.Value.Type.Includes(openapi3.TypeArray) { +// currentSchema = currentSchema.Value.Items +// } else { +// propertySchema, ok := currentSchema.Value.Properties[key] +// if !ok { +// if currentSchema.Value.AdditionalProperties.Schema == nil { +// return nil, fmt.Errorf("nested schema for key %q not found", key) +// } +// currentSchema = currentSchema.Value.AdditionalProperties.Schema +// continue +// } +// currentSchema = propertySchema +// } +// } +// return currentSchema, nil +//} + // makeObject returns an object that contains properties from props. // A value of every property is parsed as a primitive value. // The function returns an error when an error happened while parse object's properties. +//func makeObject(props map[string]string, schema *openapi3.SchemaRef) (map[string]interface{}, error) { +// obj := make(map[string]interface{}) +// for propName, propSchema := range schema.Value.Properties { +// value, err := parsePrimitive(props[propName], propSchema) +// if err != nil { +// if v, ok := err.(*ParseError); ok { +// return nil, &ParseError{path: []interface{}{propName}, Cause: v} +// } +// return nil, fmt.Errorf("property %q: %w", propName, err) +// } +// obj[propName] = value +// } +// return obj, nil +//} + +// makeObject returns an object that contains properties from props. func makeObject(props map[string]string, schema *openapi3.SchemaRef) (map[string]interface{}, error) { - obj := make(map[string]interface{}) - for propName, propSchema := range schema.Value.Properties { - value, err := parsePrimitive(props[propName], propSchema) + mobj := make(map[string]interface{}) + + for kk, value := range props { + keys := strings.Split(kk, urlDecoderDelimiter) + if strings.Contains(value, urlDecoderDelimiter) { + // don't support implicit array indexes anymore + p := pathFromKeys(keys) + return nil, &ParseError{path: p, Kind: KindInvalidFormat, Reason: "array items must be set with indexes"} + } + deepSet(mobj, keys, value) + } + r, err := buildResObj(mobj, nil, "", schema) + if err != nil { + return nil, err + } + result, ok := r.(map[string]interface{}) + if !ok { + return nil, &ParseError{Kind: KindOther, Reason: "invalid param object", Value: result} + } + + return result, nil +} + +// example: map[0:map[key:true] 1:map[key:false]] -> [map[key:true] map[key:false]] +func sliceMapToSlice(m map[string]interface{}) ([]interface{}, error) { + var result []interface{} + + keys := make([]int, 0, len(m)) + for k := range m { + key, err := strconv.Atoi(k) if err != nil { - if v, ok := err.(*ParseError); ok { - return nil, &ParseError{path: []interface{}{propName}, Cause: v} + return nil, fmt.Errorf("array indexes must be integers: %w", err) + } + keys = append(keys, key) + } + max := -1 + for _, k := range keys { + if k > max { + max = k + } + } + for i := 0; i <= max; i++ { + val, ok := m[strconv.Itoa(i)] + if !ok { + result = append(result, nil) + continue + } + result = append(result, val) + } + return result, nil +} + +// buildResObj constructs an object based on a given schema and param values +func buildResObj(params map[string]interface{}, parentKeys []string, key string, schema *openapi3.SchemaRef) (interface{}, error) { + mapKeys := parentKeys + if key != "" { + mapKeys = append(mapKeys, key) + } + + switch { + case schema.Value.Type == "array": + paramArr, ok := deepGet(params, mapKeys...) + if !ok { + return nil, nil + } + t, isMap := paramArr.(map[string]interface{}) + if !isMap { + return nil, &ParseError{path: pathFromKeys(mapKeys), Kind: KindInvalidFormat, Reason: "array items must be set with indexes"} + } + // intermediate arrays have to be instantiated + arr, err := sliceMapToSlice(t) + if err != nil { + return nil, &ParseError{path: pathFromKeys(mapKeys), Kind: KindInvalidFormat, Reason: fmt.Sprintf("could not convert value map to array: %v", err)} + } + resultArr := make([]interface{} /*not 0,*/, len(arr)) + for i := range arr { + r, err := buildResObj(params, mapKeys, strconv.Itoa(i), schema.Value.Items) + if err != nil { + return nil, err + } + if r != nil { + resultArr[i] = r + } + } + return resultArr, nil + case schema.Value.Type == "object": + resultMap := make(map[string]interface{}) + additPropsSchema := schema.Value.AdditionalProperties.Schema + pp, _ := deepGet(params, mapKeys...) + objectParams, ok := pp.(map[string]interface{}) + if !ok { + // not the expected type, but return it either way and leave validation up to ValidateParameter + return pp, nil + } + for k, propSchema := range schema.Value.Properties { + r, err := buildResObj(params, mapKeys, k, propSchema) + if err != nil { + return nil, err } - return nil, fmt.Errorf("property %q: %w", propName, err) + if r != nil { + resultMap[k] = r + } + } + if additPropsSchema != nil { + // dynamic creation of possibly nested objects + for k := range objectParams { + r, err := buildResObj(params, mapKeys, k, additPropsSchema) + if err != nil { + return nil, err + } + if r != nil { + resultMap[k] = r + } + } + } + + return resultMap, nil + case len(schema.Value.AnyOf) > 0: + return buildFromSchemas(schema.Value.AnyOf, params, parentKeys, key) + case len(schema.Value.OneOf) > 0: + return buildFromSchemas(schema.Value.OneOf, params, parentKeys, key) + case len(schema.Value.AllOf) > 0: + return buildFromSchemas(schema.Value.AllOf, params, parentKeys, key) + default: + val, ok := deepGet(params, mapKeys...) + if !ok { + // leave validation up to ValidateParameter. here there really is not parameter set + return nil, nil + } + v, ok := val.(string) + if !ok { + return nil, &ParseError{path: pathFromKeys(mapKeys), Kind: KindInvalidFormat, Value: val, Reason: "path is not convertible to primitive"} } - obj[propName] = value + prim, err := parsePrimitive(v, schema) + if err != nil { + return nil, handlePropParseError(mapKeys, err) + } + + return prim, nil } - return obj, nil +} + +// buildFromSchemas decodes params with anyOf, oneOf, allOf schemas. +func buildFromSchemas(schemas openapi3.SchemaRefs, params map[string]interface{}, mapKeys []string, key string) (interface{}, error) { + resultMap := make(map[string]interface{}) + for _, s := range schemas { + val, err := buildResObj(params, mapKeys, key, s) + if err == nil && val != nil { + + if m, ok := val.(map[string]interface{}); ok { + for k, v := range m { + resultMap[k] = v + } + continue + } + + if a, ok := val.([]interface{}); ok { + if len(a) > 0 { + return a, nil + } + continue + } + + // if its a primitive and not nil just return that and let it be validated + return val, nil + } + } + + if len(resultMap) > 0 { + return resultMap, nil + } + + return nil, nil +} + +func handlePropParseError(path []string, err error) error { + if v, ok := err.(*ParseError); ok { + return &ParseError{path: pathFromKeys(path), Cause: v} + } + return fmt.Errorf("property %q: %w", strings.Join(path, "."), err) +} + +func pathFromKeys(kk []string) []interface{} { + path := make([]interface{}, 0, len(kk)) + for _, v := range kk { + path = append(path, v) + } + return path } // parseArray returns an array that contains items from a raw array. diff --git a/internal/platform/validator/req_resp_decoder_test.go b/internal/platform/validator/req_resp_decoder_test.go index 118beaa..a9fb0a3 100644 --- a/internal/platform/validator/req_resp_decoder_test.go +++ b/internal/platform/validator/req_resp_decoder_test.go @@ -1238,7 +1238,7 @@ func TestDecodeBody(t *testing.T) { WithProperty("d", openapi3.NewObjectSchema().WithProperty("d1", openapi3.NewStringSchema())). WithProperty("f", openapi3.NewStringSchema().WithFormat("binary")). WithProperty("g", openapi3.NewStringSchema()), - want: map[string]interface{}{"a": "a1", "b": int64(10), "c": []interface{}{"c1", "c2"}, "d": map[string]interface{}{"d1": "d1"}, "f": "foo", "g": "g1"}, + want: map[string]interface{}{"a": "a1", "b": json.Number("10"), "c": []interface{}{"c1", "c2"}, "d": map[string]interface{}{"d1": "d1"}, "f": "foo", "g": "g1"}, }, { name: "multipartExtraPart", diff --git a/internal/platform/validator/validate_request.go b/internal/platform/validator/validate_request.go index c811635..8e424bc 100644 --- a/internal/platform/validator/validate_request.go +++ b/internal/platform/validator/validate_request.go @@ -142,24 +142,33 @@ func ValidateParameter(ctx context.Context, input *openapi3filter.RequestValidat } // Set default value if needed - if !options.SkipSettingDefaults && value == nil && schema != nil && schema.Default != nil { + if !options.SkipSettingDefaults && value == nil && schema != nil { value = schema.Default - req := input.Request - switch parameter.In { - case openapi3.ParameterInPath: - // Path parameters are required. - // Next check `parameter.Required && !found` will catch this. - case openapi3.ParameterInQuery: - q := req.URL.Query() - q.Add(parameter.Name, fmt.Sprintf("%v", value)) - req.URL.RawQuery = q.Encode() - case openapi3.ParameterInHeader: - req.Header.Add(parameter.Name, fmt.Sprintf("%v", value)) - case openapi3.ParameterInCookie: - req.AddCookie(&http.Cookie{ - Name: parameter.Name, - Value: fmt.Sprintf("%v", value), - }) + + for _, subSchema := range schema.AllOf { + if subSchema.Value.Default != nil { + value = subSchema.Value.Default + break // This is not a validation of the schema itself, so use the first default value. + } + } + if value != nil { + req := input.Request + switch parameter.In { + case openapi3.ParameterInPath: + // Path parameters are required. + // Next check `parameter.Required && !found` will catch this. + case openapi3.ParameterInQuery: + q := req.URL.Query() + q.Add(parameter.Name, fmt.Sprintf("%v", value)) + req.URL.RawQuery = q.Encode() + case openapi3.ParameterInHeader: + req.Header.Add(parameter.Name, fmt.Sprintf("%v", value)) + case openapi3.ParameterInCookie: + req.AddCookie(&http.Cookie{ + Name: parameter.Name, + Value: fmt.Sprintf("%v", value), + }) + } } } diff --git a/internal/platform/validator/validate_request_test.go b/internal/platform/validator/validate_request_test.go index 069e3e5..ce2f093 100644 --- a/internal/platform/validator/validate_request_test.go +++ b/internal/platform/validator/validate_request_test.go @@ -39,7 +39,7 @@ func TestValidateRequest(t *testing.T) { openapi: 3.0.0 info: title: 'Validator' - version: 0.0.1 + version: 0.0.2 paths: /category: post: @@ -63,6 +63,14 @@ paths: category: type: string default: Sweets + subCategoryInt: + type: integer + minimum: 100 + maximum: 1000 + categoryFloat: + type: number + minimum: 123.10 + maximum: 123.20 responses: '201': description: Created @@ -98,8 +106,10 @@ components: } type testRequestBody struct { - SubCategory string `json:"subCategory"` - Category string `json:"category,omitempty"` + SubCategory string `json:"subCategory"` + Category string `json:"category,omitempty"` + SubCategoryInt int `json:"subCategoryInt,omitempty"` + CategoryFloat float32 `json:"categoryFloat,omitempty"` } type args struct { requestBody *testRequestBody @@ -115,7 +125,7 @@ components: { name: "Valid request with all fields set", args: args{ - requestBody: &testRequestBody{SubCategory: "Chocolate", Category: "Food"}, + requestBody: &testRequestBody{SubCategory: "Chocolate", Category: "Food", SubCategoryInt: 123, CategoryFloat: 123.12}, url: "/category?category=cookies", apiKey: "SomeKey", }, @@ -172,6 +182,26 @@ components: expectedModification: false, expectedErr: &openapi3filter.SecurityRequirementsError{}, }, + { + name: "Invalid SubCategoryInt value", + args: args{ + requestBody: &testRequestBody{SubCategory: "Chocolate", Category: "Food", SubCategoryInt: 1, CategoryFloat: 123.12}, + url: "/category?category=cookies", + apiKey: "SomeKey", + }, + expectedModification: false, + expectedErr: &openapi3filter.RequestError{}, + }, + { + name: "Invalid CategoryFloat value", + args: args{ + requestBody: &testRequestBody{SubCategory: "Chocolate", Category: "Food", SubCategoryInt: 123, CategoryFloat: 123.21}, + url: "/category?category=cookies", + apiKey: "SomeKey", + }, + expectedModification: false, + expectedErr: &openapi3filter.RequestError{}, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { diff --git a/internal/platform/validator/validate_response.go b/internal/platform/validator/validate_response.go index 0c87a38..ec0c195 100644 --- a/internal/platform/validator/validate_response.go +++ b/internal/platform/validator/validate_response.go @@ -45,10 +45,10 @@ func ValidateResponse(ctx context.Context, input *openapi3filter.ResponseValidat // Find input for the current status responses := route.Operation.Responses - if len(responses) == 0 { + if responses.Len() == 0 { return nil } - responseRef := responses.Get(status) // Response + responseRef := responses.Status(status) // Response if responseRef == nil { responseRef = responses.Default() // Default input } diff --git a/internal/platform/web/web.go b/internal/platform/web/web.go index 95e764a..d358b2c 100644 --- a/internal/platform/web/web.go +++ b/internal/platform/web/web.go @@ -2,7 +2,6 @@ package web import ( "bytes" - "github.com/wallarm/api-firewall/internal/platform/router" "os" "runtime/debug" "syscall" @@ -11,6 +10,7 @@ import ( "github.com/savsgio/gotils/strconv" "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" + "github.com/wallarm/api-firewall/internal/platform/router" ) const ( @@ -35,7 +35,6 @@ const ( ProxyMode = "proxy" GraphQLMode = "graphql" - AnyMethod = "any" RequestID = "__wallarm_apifw_request_id" ) From 0c400ff5b7486efe2bd72410d984022d12e45472 Mon Sep 17 00:00:00 2001 From: Nikolay Tkachenko Date: Mon, 15 Apr 2024 01:05:41 +0300 Subject: [PATCH 09/12] Update dependencies --- .../internal/handlers/api/errors.go | 148 ++-- .../internal/updater/wallarm_api2_update.db | Bin 98304 -> 98304 bytes cmd/api-firewall/tests/main_api_mode_test.go | 66 +- cmd/api-firewall/tests/main_test.go | 6 +- go.mod | 2 +- go.sum | 4 +- .../platform/validator/req_resp_decoder.go | 217 +++--- .../validator/req_resp_decoder_test.go | 671 +++++++++++++++--- .../validator/unknown_parameters_request.go | 12 +- .../unknown_parameters_request_test.go | 122 +++- .../validator/validate_request_test.go | 302 ++++++++ 11 files changed, 1268 insertions(+), 282 deletions(-) diff --git a/cmd/api-firewall/internal/handlers/api/errors.go b/cmd/api-firewall/internal/handlers/api/errors.go index 1cc4cd9..22603b4 100644 --- a/cmd/api-firewall/internal/handlers/api/errors.go +++ b/cmd/api-firewall/internal/handlers/api/errors.go @@ -79,12 +79,14 @@ func checkRequiredFields(reqErr *openapi3filter.RequestError, schemaError *opena response.Fields = schemaError.JSONPointer() response.Message = ErrMissedRequiredParameters.Error() - details := web.FieldTypeError{ - Name: reqErr.Parameter.Name, - ExpectedType: schemaError.Schema.Type, - } + for _, t := range schemaError.Schema.Type.Slice() { + details := web.FieldTypeError{ + Name: reqErr.Parameter.Name, + ExpectedType: t, + } - response.FieldsDetails = append(response.FieldsDetails, details) + response.FieldsDetails = append(response.FieldsDetails, details) + } return &response default: @@ -107,23 +109,25 @@ func checkRequiredFields(reqErr *openapi3filter.RequestError, schemaError *opena return &response } - details := web.FieldTypeError{ - Name: reqErr.Parameter.Name, - ExpectedType: schemaError.Schema.Type, - CurrentValue: fmt.Sprintf("%v", schemaError.Value), - } + for _, t := range schemaError.Schema.Type.Slice() { + details := web.FieldTypeError{ + Name: reqErr.Parameter.Name, + ExpectedType: t, + CurrentValue: fmt.Sprintf("%v", schemaError.Value), + } - // handle max, min and pattern cases - switch schemaError.SchemaField { - case "maximum": - details.Pattern = fmt.Sprintf("<=%0.4f", *schemaError.Schema.Max) - case "minimum": - details.Pattern = fmt.Sprintf(">=%0.4f", *schemaError.Schema.Min) - case "pattern": - details.Pattern = schemaError.Schema.Pattern - } + // handle max, min and pattern cases + switch schemaError.SchemaField { + case "maximum": + details.Pattern = fmt.Sprintf("<=%0.4f", *schemaError.Schema.Max) + case "minimum": + details.Pattern = fmt.Sprintf(">=%0.4f", *schemaError.Schema.Min) + case "pattern": + details.Pattern = schemaError.Schema.Pattern + } - response.FieldsDetails = append(response.FieldsDetails, details) + response.FieldsDetails = append(response.FieldsDetails, details) + } } return &response } @@ -152,11 +156,13 @@ func getErrorResponse(validationError error) ([]*web.ValidationError, error) { response.Message = err.Error() response.Fields = []string{err.Parameter.Name} - details := web.FieldTypeError{ - Name: err.Parameter.Name, - ExpectedType: err.Parameter.Schema.Value.Type, + for _, t := range err.Parameter.Schema.Value.Type.Slice() { + details := web.FieldTypeError{ + Name: err.Parameter.Name, + ExpectedType: t, + } + response.FieldsDetails = append(response.FieldsDetails, details) } - response.FieldsDetails = append(response.FieldsDetails, details) responseErrors = append(responseErrors, &response) } @@ -183,12 +189,14 @@ func getErrorResponse(validationError error) ([]*web.ValidationError, error) { schemaError, ok := err.Err.(*openapi3.SchemaError) if ok { if schemaError.SchemaField == "pattern" { - response.FieldsDetails = append(response.FieldsDetails, web.FieldTypeError{ - Name: err.Parameter.Name, - ExpectedType: schemaError.Schema.Type, - Pattern: schemaError.Schema.Pattern, - CurrentValue: fmt.Sprintf("%v", schemaError.Value), - }) + for _, t := range schemaError.Schema.Type.Slice() { + response.FieldsDetails = append(response.FieldsDetails, web.FieldTypeError{ + Name: err.Parameter.Name, + ExpectedType: t, + Pattern: schemaError.Schema.Pattern, + CurrentValue: fmt.Sprintf("%v", schemaError.Value), + }) + } } } @@ -234,11 +242,13 @@ func getErrorResponse(validationError error) ([]*web.ValidationError, error) { for _, f := range response.Fields { if p, lookupErr := schemaError.Schema.Properties.JSONLookup(f); lookupErr == nil { - details := web.FieldTypeError{ - Name: f, - ExpectedType: p.(*openapi3.Schema).Type, + for _, t := range p.(*openapi3.Schema).Type.Slice() { + details := web.FieldTypeError{ + Name: f, + ExpectedType: t, + } + response.FieldsDetails = append(response.FieldsDetails, details) } - response.FieldsDetails = append(response.FieldsDetails, details) } } @@ -257,21 +267,23 @@ func getErrorResponse(validationError error) ([]*web.ValidationError, error) { CurrentValue: parseErr.ValueStr, }) } else { - details := web.FieldTypeError{ - Name: response.Fields[0], - ExpectedType: schemaError.Schema.Type, - CurrentValue: fmt.Sprintf("%v", schemaError.Value), + for _, t := range schemaError.Schema.Type.Slice() { + details := web.FieldTypeError{ + Name: response.Fields[0], + ExpectedType: t, + CurrentValue: fmt.Sprintf("%v", schemaError.Value), + } + switch schemaError.SchemaField { + case "pattern": + details.Pattern = schemaError.Schema.Pattern + case "maximum": + details.Pattern = fmt.Sprintf("<=%0.4f", *schemaError.Schema.Max) + case "minimum": + details.Pattern = fmt.Sprintf(">=%0.4f", *schemaError.Schema.Min) + } + + response.FieldsDetails = append(response.FieldsDetails, details) } - switch schemaError.SchemaField { - case "pattern": - details.Pattern = schemaError.Schema.Pattern - case "maximum": - details.Pattern = fmt.Sprintf("<=%0.4f", *schemaError.Schema.Max) - case "minimum": - details.Pattern = fmt.Sprintf(">=%0.4f", *schemaError.Schema.Min) - } - - response.FieldsDetails = append(response.FieldsDetails, details) } responseErrors = append(responseErrors, &response) } @@ -290,11 +302,13 @@ func getErrorResponse(validationError error) ([]*web.ValidationError, error) { for _, f := range response.Fields { if p, lookupErr := schemaError.Schema.Properties.JSONLookup(f); lookupErr == nil { - details := web.FieldTypeError{ - Name: f, - ExpectedType: p.(*openapi3.Schema).Type, + for _, t := range p.(*openapi3.Schema).Type.Slice() { + details := web.FieldTypeError{ + Name: f, + ExpectedType: t, + } + response.FieldsDetails = append(response.FieldsDetails, details) } - response.FieldsDetails = append(response.FieldsDetails, details) } } @@ -313,21 +327,25 @@ func getErrorResponse(validationError error) ([]*web.ValidationError, error) { CurrentValue: parseErr.ValueStr, }) } else { - details := web.FieldTypeError{ - Name: response.Fields[0], - ExpectedType: schemaError.Schema.Type, - CurrentValue: fmt.Sprintf("%v", schemaError.Value), - } - switch schemaError.SchemaField { - case "pattern": - details.Pattern = schemaError.Schema.Pattern - case "maximum": - details.Pattern = fmt.Sprintf("<=%0.4f", *schemaError.Schema.Max) - case "minimum": - details.Pattern = fmt.Sprintf(">=%0.4f", *schemaError.Schema.Min) + + for _, t := range schemaError.Schema.Type.Slice() { + details := web.FieldTypeError{ + Name: response.Fields[0], + ExpectedType: t, + CurrentValue: fmt.Sprintf("%v", schemaError.Value), + } + switch schemaError.SchemaField { + case "pattern": + details.Pattern = schemaError.Schema.Pattern + case "maximum": + details.Pattern = fmt.Sprintf("<=%0.4f", *schemaError.Schema.Max) + case "minimum": + details.Pattern = fmt.Sprintf(">=%0.4f", *schemaError.Schema.Min) + } + + response.FieldsDetails = append(response.FieldsDetails, details) } - response.FieldsDetails = append(response.FieldsDetails, details) } responseErrors = append(responseErrors, &response) } diff --git a/cmd/api-firewall/internal/updater/wallarm_api2_update.db b/cmd/api-firewall/internal/updater/wallarm_api2_update.db index c4bfe885d8cb04730960a1e156f5450576f592de..0c250b0af094360612e4ab070d28c9a942e3ac1b 100644 GIT binary patch delta 34 qcmZo@U~6b#n;^}2d!mdpCBn|I5)2r+(XGFWE77ytm*R}2gQ delta 34 qcmZo@U~6b#n;^|Nd7_LnCNn|I5)2r=GiGFWE77ytmsw+l=F diff --git a/cmd/api-firewall/tests/main_api_mode_test.go b/cmd/api-firewall/tests/main_api_mode_test.go index b19a444..ea95512 100644 --- a/cmd/api-firewall/tests/main_api_mode_test.go +++ b/cmd/api-firewall/tests/main_api_mode_test.go @@ -143,6 +143,7 @@ paths: content: multipart/form-data: schema: + type: object required: - url properties: @@ -400,6 +401,26 @@ paths: 200: description: Ok content: { } + /query/paramsObject: + get: + parameters: + - name: f.0 + required: true + in: query + style: deepObject + schema: + type: object + properties: + f: + type: object + properties: + '0': + type: string + summary: Get Test Info + responses: + 200: + description: Ok + content: { } /test/body/request: post: summary: Post Request to test Request Body presence @@ -586,6 +607,7 @@ func TestAPIModeBasic(t *testing.T) { // check conflicts in the Path t.Run("testConflictsInThePath", apifwTests.testConflictsInThePath) + t.Run("testObjectInQuery", apifwTests.testObjectInQuery) } func createForm(form map[string]string) (string, io.Reader, error) { @@ -2791,7 +2813,7 @@ func (s *APIModeServiceTests) testConflictsInThePath(t *testing.T) { handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec, nil, nil) - // check all supported methods: GET POST PUT PATCH DELETE TRACE OPTIONS HEAD + // check all related paths for _, path := range []string{"/path/testValue1", "/path/value1.php"} { req := fasthttp.AcquireRequest() req.SetRequestURI(path) @@ -2828,3 +2850,45 @@ func (s *APIModeServiceTests) testConflictsInThePath(t *testing.T) { // check response status code and response body checkResponseForbiddenStatusCode(t, &reqCtx, DefaultSchemaID, []string{handlersAPI.ErrCodeRequiredPathParameterInvalidValue}) } + +func (s *APIModeServiceTests) testObjectInQuery(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec, nil, nil) + + for _, path := range []string{"/query/paramsObject?f.0%5Bf%5D%5B0%5D=test"} { + + req := fasthttp.AcquireRequest() + req.SetRequestURI(path) + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + t.Logf("Name of the test: %s; request method: %s; request uri: %s; request body: %s", t.Name(), string(reqCtx.Request.Header.Method()), string(reqCtx.Request.RequestURI()), string(reqCtx.Request.Body())) + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + // check response status code and response body + checkResponseOkStatusCode(t, &reqCtx, DefaultSchemaID) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/query/paramsObject") + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + t.Logf("Name of the test: %s; request method: %s; request uri: %s; request body: %s", t.Name(), string(reqCtx.Request.Header.Method()), string(reqCtx.Request.RequestURI()), string(reqCtx.Request.Body())) + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + // check response status code and response body + checkResponseForbiddenStatusCode(t, &reqCtx, DefaultSchemaID, []string{handlersAPI.ErrCodeRequiredQueryParameterMissed}) +} diff --git a/cmd/api-firewall/tests/main_test.go b/cmd/api-firewall/tests/main_test.go index 1f691ce..3da4ba0 100644 --- a/cmd/api-firewall/tests/main_test.go +++ b/cmd/api-firewall/tests/main_test.go @@ -2521,7 +2521,7 @@ func (s *ServiceTests) unknownParamPostBody(t *testing.T) { req := fasthttp.AcquireRequest() req.SetRequestURI("/test/signup") req.Header.SetMethod("POST") - req.SetBodyString("firstname=test&lastname=testjob=test&email=test@wallarm.com&url=http://wallarm.com") + req.SetBodyString("firstname=test&lastname=test&job=test&email=test@wallarm.com&url=http://wallarm.com") req.Header.SetContentType("application/x-www-form-urlencoded") resp := fasthttp.AcquireResponse() @@ -2544,7 +2544,7 @@ func (s *ServiceTests) unknownParamPostBody(t *testing.T) { reqCtx.Response.StatusCode()) } - req.SetBodyString("firstname=test&lastname=testjob=test&email=test@wallarm.com&url=http://wallarm.com&test=hello") + req.SetBodyString("firstname=test&lastname=test&email=test@wallarm.com&url=http://wallarm.com&test=hello") reqCtx = fasthttp.RequestCtx{ Request: *req, @@ -2659,7 +2659,7 @@ func (s *ServiceTests) unknownParamUnsupportedMimeType(t *testing.T) { req.SetRequestURI("/test/signup") req.Header.SetMethod("POST") req.Header.SetContentType("application/x-www-form-urlencoded") - req.SetBodyString("firstname=test&lastname=testjob=test&email=test@wallarm.com&url=http://wallarm.com") + req.SetBodyString("firstname=test&lastname=test&job=test&email=test@wallarm.com&url=http://wallarm.com") resp := fasthttp.AcquireResponse() resp.SetStatusCode(fasthttp.StatusOK) diff --git a/go.mod b/go.mod index 07e1cca..6841534 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/dgraph-io/ristretto v0.1.1 github.com/fasthttp/websocket v1.5.8 github.com/gabriel-vasile/mimetype v1.4.3 - github.com/getkin/kin-openapi v0.123.0 + github.com/getkin/kin-openapi v0.124.0 github.com/go-playground/validator v9.31.0+incompatible github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang/mock v1.6.0 diff --git a/go.sum b/go.sum index 58feb97..289c3ee 100644 --- a/go.sum +++ b/go.sum @@ -46,8 +46,8 @@ github.com/foxcpp/go-mockdns v1.1.0 h1:jI0rD8M0wuYAxL7r/ynTrCQQq0BVqfB99Vgk7Dlme github.com/foxcpp/go-mockdns v1.1.0/go.mod h1:IhLeSFGed3mJIAXPH2aiRQB+kqz7oqu8ld2qVbOu7Wk= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= -github.com/getkin/kin-openapi v0.123.0 h1:zIik0mRwFNLyvtXK274Q6ut+dPh6nlxBp0x7mNrPhs8= -github.com/getkin/kin-openapi v0.123.0/go.mod h1:wb1aSZA/iWmorQP9KTAS/phLj/t17B5jT7+fS8ed9NM= +github.com/getkin/kin-openapi v0.124.0 h1:VSFNMB9C9rTKBnQ/fpyDU8ytMTr4dWI9QovSKj9kz/M= +github.com/getkin/kin-openapi v0.124.0/go.mod h1:wb1aSZA/iWmorQP9KTAS/phLj/t17B5jT7+fS8ed9NM= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14= diff --git a/internal/platform/validator/req_resp_decoder.go b/internal/platform/validator/req_resp_decoder.go index 5083e9a..18e3a9f 100644 --- a/internal/platform/validator/req_resp_decoder.go +++ b/internal/platform/validator/req_resp_decoder.go @@ -12,6 +12,7 @@ import ( "mime/multipart" "net/http" "net/url" + "reflect" "regexp" "strconv" "strings" @@ -21,6 +22,7 @@ import ( "github.com/clbanning/mxj/v2" "github.com/getkin/kin-openapi/openapi3" "github.com/getkin/kin-openapi/openapi3filter" + utils "github.com/savsgio/gotils/strconv" "github.com/valyala/fastjson" ) @@ -193,7 +195,7 @@ func defaultContentParameterDecoder(param *openapi3.Parameter, values []string) unmarshal := func(encoded string, paramSchema *openapi3.SchemaRef) (decoded interface{}, err error) { if err = json.Unmarshal([]byte(encoded), &decoded); err != nil { - if paramSchema != nil && paramSchema.Value.Type != "object" { + if paramSchema != nil && !paramSchema.Value.Type.Is("object") { decoded, err = encoded, nil } } @@ -243,10 +245,11 @@ func decodeStyledParameter(param *openapi3.Parameter, input *openapi3filter.Requ } dec = &pathParamDecoder{pathParams: input.PathParams} case openapi3.ParameterInQuery: - if len(input.GetQueryParams()) == 0 { + queryParams := input.GetQueryParams() + if len(queryParams) == 0 { return nil, false, nil } - dec = &urlValuesDecoder{values: input.GetQueryParams()} + dec = &urlValuesDecoder{values: queryParams} case openapi3.ParameterInHeader: dec = &headerParamDecoder{header: input.Request.Header} case openapi3.ParameterInCookie: @@ -314,14 +317,14 @@ func decodeValue(dec valueDecoder, param string, sm *openapi3.SerializationMetho return nil, found, errors.New("not implemented: decoding 'not'") } - if schema.Value.Type != "" { + if schema.Value.Type != nil { var decodeFn func(param string, sm *openapi3.SerializationMethod, schema *openapi3.SchemaRef) (interface{}, bool, error) - switch schema.Value.Type { - case "array": + switch { + case schema.Value.Type.Is("array"): decodeFn = func(param string, sm *openapi3.SerializationMethod, schema *openapi3.SchemaRef) (interface{}, bool, error) { return dec.DecodeArray(param, sm, schema) } - case "object": + case schema.Value.Type.Is("object"): decodeFn = func(param string, sm *openapi3.SerializationMethod, schema *openapi3.SchemaRef) (interface{}, bool, error) { return dec.DecodeObject(param, sm, schema) } @@ -504,7 +507,7 @@ func (d *urlValuesDecoder) DecodePrimitive(param string, sm *openapi3.Serializat return nil, ok, nil } - if schema.Value.Type == "" && schema.Value.Pattern != "" { + if schema.Value.Type == nil && schema.Value.Pattern != "" { return values[0], ok, nil } val, err := parsePrimitive(values[0], schema) @@ -649,12 +652,18 @@ func (d *urlValuesDecoder) DecodeObject(param string, sm *openapi3.Serialization propsFn = func(params url.Values) (map[string]string, error) { props := make(map[string]string) for key, values := range params { - groups := regexp.MustCompile(fmt.Sprintf("%s\\[(.+?)\\]", param)).FindAllStringSubmatch(key, -1) - if len(groups) == 0 { + matches := regexp.MustCompile(`\[(.*?)\]`).FindAllStringSubmatch(key, -1) + switch l := len(matches); { + case l == 0: // A query parameter's name does not match the required format, so skip it. continue + case l >= 1: + kk := []string{} + for _, m := range matches { + kk = append(kk, m[1]) + } + props[strings.Join(kk, urlDecoderDelimiter)] = strings.Join(values, urlDecoderDelimiter) } - props[groups[0][1]] = values[0] } if len(props) == 0 { // HTTP request does not contain query parameters encoded by rules of style "deepObject". @@ -687,7 +696,7 @@ func (d *urlValuesDecoder) DecodeObject(param string, sm *openapi3.Serialization break } - if schema.Value.Type == "array" || schema.Value.Type == "object" { + if schema.Value.Type.Permits("array") || schema.Value.Type.Permits("object") { for k := range props { path := strings.Split(k, urlDecoderDelimiter) if _, ok := deepGet(val, path...); ok { @@ -698,7 +707,7 @@ func (d *urlValuesDecoder) DecodeObject(param string, sm *openapi3.Serialization } } - return val, found, err + return val, found, nil } // headerParamDecoder decodes values of header parameters. @@ -830,7 +839,7 @@ func propsFromString(src, propDelim, valueDelim string) (map[string]string, erro pairs := strings.Split(src, propDelim) // When propDelim and valueDelim is equal the source string follow the next rule: - // every even item of pairs is a properies's name, and the subsequent odd item is a property's value. + // every even item of pairs is a properties's name, and the subsequent odd item is a property's value. if propDelim == valueDelim { // Taking into account the rule above, a valid source string must be splitted by propDelim // to an array with an even number of items. @@ -887,43 +896,25 @@ func deepSet(m map[string]interface{}, keys []string, value interface{}) { m[keys[len(keys)-1]] = value } -//func findNestedSchema(parentSchema *openapi3.SchemaRef, keys []string) (*openapi3.SchemaRef, error) { -// currentSchema := parentSchema -// for _, key := range keys { -// if currentSchema.Value.Type.Includes(openapi3.TypeArray) { -// currentSchema = currentSchema.Value.Items -// } else { -// propertySchema, ok := currentSchema.Value.Properties[key] -// if !ok { -// if currentSchema.Value.AdditionalProperties.Schema == nil { -// return nil, fmt.Errorf("nested schema for key %q not found", key) -// } -// currentSchema = currentSchema.Value.AdditionalProperties.Schema -// continue -// } -// currentSchema = propertySchema -// } -// } -// return currentSchema, nil -//} - -// makeObject returns an object that contains properties from props. -// A value of every property is parsed as a primitive value. -// The function returns an error when an error happened while parse object's properties. -//func makeObject(props map[string]string, schema *openapi3.SchemaRef) (map[string]interface{}, error) { -// obj := make(map[string]interface{}) -// for propName, propSchema := range schema.Value.Properties { -// value, err := parsePrimitive(props[propName], propSchema) -// if err != nil { -// if v, ok := err.(*ParseError); ok { -// return nil, &ParseError{path: []interface{}{propName}, Cause: v} -// } -// return nil, fmt.Errorf("property %q: %w", propName, err) -// } -// obj[propName] = value -// } -// return obj, nil -//} +func findNestedSchema(parentSchema *openapi3.SchemaRef, keys []string) (*openapi3.SchemaRef, error) { + currentSchema := parentSchema + for _, key := range keys { + if currentSchema.Value.Type.Includes(openapi3.TypeArray) { + currentSchema = currentSchema.Value.Items + } else { + propertySchema, ok := currentSchema.Value.Properties[key] + if !ok { + if currentSchema.Value.AdditionalProperties.Schema == nil { + return nil, fmt.Errorf("nested schema for key %q not found", key) + } + currentSchema = currentSchema.Value.AdditionalProperties.Schema + continue + } + currentSchema = propertySchema + } + } + return currentSchema, nil +} // makeObject returns an object that contains properties from props. func makeObject(props map[string]string, schema *openapi3.SchemaRef) (map[string]interface{}, error) { @@ -987,7 +978,7 @@ func buildResObj(params map[string]interface{}, parentKeys []string, key string, } switch { - case schema.Value.Type == "array": + case schema.Value.Type.Is("array"): paramArr, ok := deepGet(params, mapKeys...) if !ok { return nil, nil @@ -1012,7 +1003,7 @@ func buildResObj(params map[string]interface{}, parentKeys []string, key string, } } return resultArr, nil - case schema.Value.Type == "object": + case schema.Value.Type.Is("object"): resultMap := make(map[string]interface{}) additPropsSchema := schema.Value.AdditionalProperties.Schema pp, _ := deepGet(params, mapKeys...) @@ -1144,49 +1135,49 @@ func parseArray(raw []string, schemaRef *openapi3.SchemaRef) ([]interface{}, err // parsePrimitive returns a value that is created by parsing a source string to a primitive type // that is specified by a schema. The function returns nil when the source string is empty. // The function panics when a schema has a non-primitive type. -func parsePrimitive(raw string, schema *openapi3.SchemaRef) (interface{}, error) { +func parsePrimitive(raw string, schema *openapi3.SchemaRef) (v interface{}, err error) { if raw == "" { return nil, nil } + for _, typ := range schema.Value.Type.Slice() { + if v, err = parsePrimitiveCase(raw, schema, typ); err == nil { + return + } + } + return +} - switch schema.Value.Type { +func parsePrimitiveCase(raw string, schema *openapi3.SchemaRef, typ string) (interface{}, error) { + switch typ { case "integer": - if len(schema.Value.Enum) > 0 { - // parse int as float because of the comparison with float enum values - v, err := strconv.ParseFloat(raw, 64) - if err != nil { - return nil, &ParseError{Kind: KindInvalidFormat, ValueStr: raw, ExpectedType: schema.Value.Type, Value: raw, Reason: "an invalid " + schema.Value.Type, Cause: err.(*strconv.NumError).Err} - } - return v, nil - } if schema.Value.Format == "int32" { v, err := strconv.ParseInt(raw, 0, 32) if err != nil { - return nil, &ParseError{Kind: KindInvalidFormat, ValueStr: raw, ExpectedType: schema.Value.Type, Value: raw, Reason: "an invalid " + schema.Value.Type, Cause: err.(*strconv.NumError).Err} + return nil, &ParseError{Kind: KindInvalidFormat, Value: raw, Reason: "an invalid " + typ, Cause: err.(*strconv.NumError).Err} } return int32(v), nil } v, err := strconv.ParseInt(raw, 0, 64) if err != nil { - return nil, &ParseError{Kind: KindInvalidFormat, ValueStr: raw, ExpectedType: schema.Value.Type, Value: raw, Reason: "an invalid " + schema.Value.Type, Cause: err.(*strconv.NumError).Err} + return nil, &ParseError{Kind: KindInvalidFormat, Value: raw, Reason: "an invalid " + typ, Cause: err.(*strconv.NumError).Err} } return v, nil case "number": v, err := strconv.ParseFloat(raw, 64) if err != nil { - return nil, &ParseError{Kind: KindInvalidFormat, ValueStr: raw, ExpectedType: schema.Value.Type, Value: raw, Reason: "an invalid " + schema.Value.Type, Cause: err.(*strconv.NumError).Err} + return nil, &ParseError{Kind: KindInvalidFormat, Value: raw, Reason: "an invalid " + typ, Cause: err.(*strconv.NumError).Err} } return v, nil case "boolean": v, err := strconv.ParseBool(raw) if err != nil { - return nil, &ParseError{Kind: KindInvalidFormat, ValueStr: raw, ExpectedType: schema.Value.Type, Value: raw, Reason: "an invalid " + schema.Value.Type, Cause: err.(*strconv.NumError).Err} + return nil, &ParseError{Kind: KindInvalidFormat, Value: raw, Reason: "an invalid " + typ, Cause: err.(*strconv.NumError).Err} } return v, nil case "string": return raw, nil default: - return nil, &ParseError{Kind: KindOther, Value: raw, Reason: "schema has non primitive type " + schema.Value.Type} + return nil, &ParseError{Kind: KindOther, Value: raw, Reason: "schema has non primitive type " + typ} } } @@ -1293,7 +1284,6 @@ func decodeBody(body io.Reader, header http.Header, schema *openapi3.SchemaRef, return mediaType, value, nil } } - mediaType, suffix := parseMediaType(contentType) decoder, err := getBodyDecoder(mediaType, suffix) if err != nil { @@ -1343,16 +1333,16 @@ func multipartPartBodyDecoder(body io.Reader, header http.Header, schema *openap return nil, &ParseError{Kind: KindInvalidFormat, Cause: err} } - dataStr := string(data) + dataStr := utils.B2S(data) - switch schema.Value.Type { - case "integer", "number": + switch { + case schema.Value.Type.Is("integer") || schema.Value.Type.Is("number"): floatValue, err := strconv.ParseFloat(dataStr, 64) if err != nil { return nil, &ParseError{Kind: KindInvalidFormat, Cause: err} } return floatValue, nil - case "boolean": + case schema.Value.Type.Is("boolean"): boolValue, err := strconv.ParseBool(dataStr) if err != nil { return nil, &ParseError{Kind: KindInvalidFormat, Cause: err} @@ -1407,13 +1397,20 @@ func yamlBodyDecoder(body io.Reader, header http.Header, schema *openapi3.Schema func urlencodedBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn, jsonParser *fastjson.Parser) (interface{}, error) { + // Validate schema of request body. + // By the OpenAPI 3 specification request body's schema must have type "object". + // Properties of the schema describes individual parts of request body. + if !schema.Value.Type.Is("object") { + return nil, errors.New("unsupported schema of request body") + } for propName, propSchema := range schema.Value.Properties { - switch propSchema.Value.Type { - case "object": + propType := propSchema.Value.Type + switch { + case propType.Is("object"): return nil, fmt.Errorf("unsupported schema of request body's property %q", propName) - case "array": + case propType.Is("array"): items := propSchema.Value.Items.Value - if items.Type != "string" && items.Type != "integer" && items.Type != "number" && items.Type != "boolean" { + if !(items.Type.Is("string") || items.Type.Is("integer") || items.Type.Is("number") || items.Type.Is("boolean")) { return nil, fmt.Errorf("unsupported schema of request body's property %q", propName) } } @@ -1432,29 +1429,61 @@ func urlencodedBodyDecoder(body io.Reader, header http.Header, schema *openapi3. // Make an object value from form values. obj := make(map[string]interface{}) dec := &urlValuesDecoder{values: values} - for name, prop := range schema.Value.Properties { - var ( - value interface{} - enc *openapi3.Encoding - ) - if encFn != nil { - enc = encFn(name) - } - sm := enc.SerializationMethod() + // Decode schema constructs (allOf, anyOf, oneOf) + if err := decodeSchemaConstructs(dec, schema.Value.AllOf, obj, encFn); err != nil { + return nil, err + } + if err := decodeSchemaConstructs(dec, schema.Value.AnyOf, obj, encFn); err != nil { + return nil, err + } + if err := decodeSchemaConstructs(dec, schema.Value.OneOf, obj, encFn); err != nil { + return nil, err + } - found := false - if value, found, err = decodeValue(dec, name, sm, prop, false); err != nil { - return nil, err - } - if found { + // Decode properties from the main schema + if err := decodeSchemaConstructs(dec, []*openapi3.SchemaRef{schema}, obj, encFn); err != nil { + return nil, err + } + + return obj, nil +} + +// decodeSchemaConstructs tries to decode properties based on provided schemas. +// This function is for decoding purposes only and not for validation. +func decodeSchemaConstructs(dec *urlValuesDecoder, schemas []*openapi3.SchemaRef, obj map[string]interface{}, encFn EncodingFn) error { + for _, schemaRef := range schemas { + for name, prop := range schemaRef.Value.Properties { + value, _, err := decodeProperty(dec, name, prop, encFn) + if err != nil { + continue + } + if existingValue, exists := obj[name]; exists && !isEqual(existingValue, value) { + return fmt.Errorf("conflicting values for property %q", name) + } obj[name] = value } } - return obj, nil + return nil +} + +func isEqual(value1, value2 interface{}) bool { + return reflect.DeepEqual(value1, value2) +} + +func decodeProperty(dec valueDecoder, name string, prop *openapi3.SchemaRef, encFn EncodingFn) (interface{}, bool, error) { + var enc *openapi3.Encoding + if encFn != nil { + enc = encFn(name) + } + sm := enc.SerializationMethod() + return decodeValue(dec, name, sm, prop, false) } func multipartBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn, jsonParser *fastjson.Parser) (interface{}, error) { + if !schema.Value.Type.Is("object") { + return nil, errors.New("unsupported schema of request body") + } // Parse form. values := make(map[string][]interface{}) @@ -1515,7 +1544,7 @@ func multipartBodyDecoder(body io.Reader, header http.Header, schema *openapi3.S return nil, &ParseError{Kind: KindOther, Cause: fmt.Errorf("part %s: undefined", name)} } } - if valueSchema.Value.Type == "array" { + if valueSchema.Value.Type.Is("array") { valueSchema = valueSchema.Value.Items } } @@ -1560,7 +1589,7 @@ func multipartBodyDecoder(body io.Reader, header http.Header, schema *openapi3.S if len(vv) == 0 { continue } - if prop.Value.Type == "array" { + if prop.Value.Type.Is("array") { obj[name] = vv } else { obj[name] = vv[0] diff --git a/internal/platform/validator/req_resp_decoder_test.go b/internal/platform/validator/req_resp_decoder_test.go index a9fb0a3..829ebbd 100644 --- a/internal/platform/validator/req_resp_decoder_test.go +++ b/internal/platform/validator/req_resp_decoder_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/stretchr/testify/assert" "io" "mime/multipart" "net/http" @@ -21,53 +22,272 @@ import ( "github.com/valyala/fastjson" ) -func TestDecodeParameter(t *testing.T) { - var ( - boolPtr = func(b bool) *bool { return &b } - explode = boolPtr(true) - noExplode = boolPtr(false) - arrayOf = func(items *openapi3.SchemaRef) *openapi3.SchemaRef { - return &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "array", Items: items}} +var ( + explode = openapi3.BoolPtr(true) + noExplode = openapi3.BoolPtr(false) + arrayOf = func(items *openapi3.SchemaRef) *openapi3.SchemaRef { + return &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"array"}, Items: items}} + } + objectOf = func(args ...interface{}) *openapi3.SchemaRef { + s := &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"object"}, Properties: make(map[string]*openapi3.SchemaRef)}} + if len(args)%2 != 0 { + panic("invalid arguments. must be an even number of arguments") } - objectOf = func(args ...interface{}) *openapi3.SchemaRef { - s := &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "object", Properties: make(map[string]*openapi3.SchemaRef)}} - if len(args)%2 != 0 { - panic("invalid arguments. must be an even number of arguments") - } - for i := 0; i < len(args)/2; i++ { - propName := args[i*2].(string) - propSchema := args[i*2+1].(*openapi3.SchemaRef) - s.Value.Properties[propName] = propSchema - } - return s + for i := 0; i < len(args)/2; i++ { + propName := args[i*2].(string) + propSchema := args[i*2+1].(*openapi3.SchemaRef) + s.Value.Properties[propName] = propSchema } + return s + } - integerSchema = &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "integer"}} - numberSchema = &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "number"}} - booleanSchema = &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "boolean"}} - stringSchema = &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "string"}} - allofSchema = &openapi3.SchemaRef{ - Value: &openapi3.Schema{ - AllOf: []*openapi3.SchemaRef{ - integerSchema, - numberSchema, - }}} - anyofSchema = &openapi3.SchemaRef{ - Value: &openapi3.Schema{ - AnyOf: []*openapi3.SchemaRef{ - integerSchema, - stringSchema, - }}} - oneofSchema = &openapi3.SchemaRef{ - Value: &openapi3.Schema{ - OneOf: []*openapi3.SchemaRef{ - booleanSchema, - integerSchema, - }}} - arraySchema = arrayOf(stringSchema) - objectSchema = objectOf("id", stringSchema, "name", stringSchema) - ) + additionalPropertiesObjectOf = func(schema *openapi3.SchemaRef) *openapi3.SchemaRef { + return &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"object"}, AdditionalProperties: openapi3.AdditionalProperties{Schema: schema}}} + } + integerSchema = &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"integer"}}} + numberSchema = &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"number"}}} + booleanSchema = &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"boolean"}}} + stringSchema = &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}} + additionalPropertiesObjectStringSchema = &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"object"}, AdditionalProperties: openapi3.AdditionalProperties{Schema: stringSchema}}} + additionalPropertiesObjectBoolSchema = &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"object"}, AdditionalProperties: openapi3.AdditionalProperties{Schema: booleanSchema}}} + allofSchema = &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + AllOf: []*openapi3.SchemaRef{ + integerSchema, + numberSchema, + }, + }, + } + anyofSchema = &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + AnyOf: []*openapi3.SchemaRef{ + integerSchema, + stringSchema, + }, + }, + } + oneofSchema = &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + OneOf: []*openapi3.SchemaRef{ + booleanSchema, + integerSchema, + }, + }, + } + oneofSchemaObject = &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + OneOf: []*openapi3.SchemaRef{ + objectOneRSchema, + objectTwoRSchema, + }, + }, + } + anyofSchemaObject = &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + AnyOf: []*openapi3.SchemaRef{ + objectOneRSchema, + objectTwoRSchema, + }, + }, + } + stringArraySchema = arrayOf(stringSchema) + integerArraySchema = arrayOf(integerSchema) + objectSchema = objectOf("id", stringSchema, "name", stringSchema) + objectTwoRSchema = func() *openapi3.SchemaRef { + s := objectOf("id2", stringSchema, "name2", stringSchema) + s.Value.Required = []string{"id2"} + + return s + }() + + objectOneRSchema = func() *openapi3.SchemaRef { + s := objectOf("id", stringSchema, "name", stringSchema) + s.Value.Required = []string{"id"} + + return s + }() + + oneofSchemaArrayObject = &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + AnyOf: []*openapi3.SchemaRef{ + stringArraySchema, + objectTwoRSchema, + }, + }, + } +) + +func TestDeepGet(t *testing.T) { + iarray := map[string]interface{}{ + "0": map[string]interface{}{ + "foo": 111, + }, + "1": map[string]interface{}{ + "bar": 222, + }, + } + + tests := []struct { + name string + m map[string]interface{} + keys []string + expected interface{} + shouldFind bool + }{ + { + name: "Simple map - key exists", + m: map[string]interface{}{ + "foo": "bar", + }, + keys: []string{"foo"}, + expected: "bar", + shouldFind: true, + }, + { + name: "Nested map - key exists", + m: map[string]interface{}{ + "foo": map[string]interface{}{ + "bar": "baz", + }, + }, + keys: []string{"foo", "bar"}, + expected: "baz", + shouldFind: true, + }, + { + name: "Nested map - key does not exist", + m: map[string]interface{}{ + "foo": map[string]interface{}{ + "bar": "baz", + }, + }, + keys: []string{"foo", "baz"}, + expected: nil, + shouldFind: false, + }, + { + name: "Array - element exists", + m: map[string]interface{}{ + "array": map[string]interface{}{"0": "a", "1": "b", "2": "c"}, + }, + keys: []string{"array", "1"}, + expected: "b", + shouldFind: true, + }, + { + name: "Array - element does not exist - invalid index", + m: map[string]interface{}{ + "array": map[string]interface{}{"0": "a", "1": "b", "2": "c"}, + }, + keys: []string{"array", "3"}, + expected: nil, + shouldFind: false, + }, + { + name: "Array - element does not exist - invalid keys", + m: map[string]interface{}{ + "array": map[string]interface{}{"0": "a", "1": "b", "2": "c"}, + }, + keys: []string{"array", "a", "999"}, + expected: nil, + shouldFind: false, + }, + { + name: "Array of objects - element exists 1", + m: map[string]interface{}{ + "array": iarray, + }, + keys: []string{"array", "1", "bar"}, + expected: 222, + shouldFind: true, + }, + { + name: "Array of objects - element exists 2", + m: map[string]interface{}{ + "array": iarray, + }, + keys: []string{"array", "0"}, + expected: map[string]interface{}{ + "foo": 111, + }, + shouldFind: true, + }, + { + name: "Array of objects - element exists 3", + m: map[string]interface{}{ + "array": iarray, + }, + keys: []string{"array"}, + expected: iarray, + shouldFind: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tc := tc + + result, found := deepGet(tc.m, tc.keys...) + require.Equal(t, tc.shouldFind, found, "shouldFind mismatch") + require.Equal(t, tc.expected, result, "result mismatch") + }) + } +} + +func TestDeepSet(t *testing.T) { + tests := []struct { + name string + inputMap map[string]interface{} + keys []string + value interface{} + expected map[string]interface{} + }{ + { + name: "simple set", + inputMap: map[string]interface{}{}, + keys: []string{"key"}, + value: "value", + expected: map[string]interface{}{"key": "value"}, + }, + { + name: "intermediate array of objects", + inputMap: map[string]interface{}{}, + keys: []string{"nested", "0", "key"}, + value: true, + expected: map[string]interface{}{ + "nested": map[string]interface{}{ + "0": map[string]interface{}{ + "key": true, + }, + }, + }, + }, + { + name: "existing nested array of objects", + inputMap: map[string]interface{}{"nested": map[string]interface{}{"0": map[string]interface{}{"existingKey": "existingValue"}}}, + keys: []string{"nested", "0", "newKey"}, + value: "newValue", + expected: map[string]interface{}{ + "nested": map[string]interface{}{ + "0": map[string]interface{}{ + "existingKey": "existingValue", + "newKey": "newValue", + }, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + deepSet(tc.inputMap, tc.keys, tc.value) + require.EqualValues(t, tc.expected, tc.inputMap) + }) + } +} + +func TestDecodeParameter(t *testing.T) { type testCase struct { name string param *openapi3.Parameter @@ -220,77 +440,77 @@ func TestDecodeParameter(t *testing.T) { testCases: []testCase{ { name: "simple", - param: &openapi3.Parameter{Name: "param", In: "path", Style: "simple", Explode: noExplode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "param", In: "path", Style: "simple", Explode: noExplode, Schema: stringArraySchema}, path: "/foo,bar", want: []interface{}{"foo", "bar"}, found: true, }, { name: "simple explode", - param: &openapi3.Parameter{Name: "param", In: "path", Style: "simple", Explode: explode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "param", In: "path", Style: "simple", Explode: explode, Schema: stringArraySchema}, path: "/foo,bar", want: []interface{}{"foo", "bar"}, found: true, }, { name: "label", - param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: noExplode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: noExplode, Schema: stringArraySchema}, path: "/.foo,bar", want: []interface{}{"foo", "bar"}, found: true, }, { name: "label invalid", - param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: noExplode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: noExplode, Schema: stringArraySchema}, path: "/foo,bar", found: true, err: &ParseError{Kind: KindInvalidFormat, Value: "foo,bar"}, }, { name: "label explode", - param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: explode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: explode, Schema: stringArraySchema}, path: "/.foo.bar", want: []interface{}{"foo", "bar"}, found: true, }, { name: "label explode invalid", - param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: explode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: explode, Schema: stringArraySchema}, path: "/foo.bar", found: true, err: &ParseError{Kind: KindInvalidFormat, Value: "foo.bar"}, }, { name: "matrix", - param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: noExplode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: noExplode, Schema: stringArraySchema}, path: "/;param=foo,bar", want: []interface{}{"foo", "bar"}, found: true, }, { name: "matrix invalid", - param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: noExplode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: noExplode, Schema: stringArraySchema}, path: "/foo,bar", found: true, err: &ParseError{Kind: KindInvalidFormat, Value: "foo,bar"}, }, { name: "matrix explode", - param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: explode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: explode, Schema: stringArraySchema}, path: "/;param=foo;param=bar", want: []interface{}{"foo", "bar"}, found: true, }, { name: "matrix explode invalid", - param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: explode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: explode, Schema: stringArraySchema}, path: "/foo,bar", found: true, err: &ParseError{Kind: KindInvalidFormat, Value: "foo,bar"}, }, { name: "default", - param: &openapi3.Parameter{Name: "param", In: "path", Schema: arraySchema}, + param: &openapi3.Parameter{Name: "param", In: "path", Schema: stringArraySchema}, path: "/foo,bar", want: []interface{}{"foo", "bar"}, found: true, @@ -565,49 +785,49 @@ func TestDecodeParameter(t *testing.T) { testCases: []testCase{ { name: "form", - param: &openapi3.Parameter{Name: "param", In: "query", Style: "form", Explode: noExplode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "param", In: "query", Style: "form", Explode: noExplode, Schema: stringArraySchema}, query: "param=foo,bar", want: []interface{}{"foo", "bar"}, found: true, }, { name: "form explode", - param: &openapi3.Parameter{Name: "param", In: "query", Style: "form", Explode: explode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "param", In: "query", Style: "form", Explode: explode, Schema: stringArraySchema}, query: "param=foo¶m=bar", want: []interface{}{"foo", "bar"}, found: true, }, { name: "spaceDelimited", - param: &openapi3.Parameter{Name: "param", In: "query", Style: "spaceDelimited", Explode: noExplode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "param", In: "query", Style: "spaceDelimited", Explode: noExplode, Schema: stringArraySchema}, query: "param=foo bar", want: []interface{}{"foo", "bar"}, found: true, }, { name: "spaceDelimited explode", - param: &openapi3.Parameter{Name: "param", In: "query", Style: "spaceDelimited", Explode: explode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "param", In: "query", Style: "spaceDelimited", Explode: explode, Schema: stringArraySchema}, query: "param=foo¶m=bar", want: []interface{}{"foo", "bar"}, found: true, }, { name: "pipeDelimited", - param: &openapi3.Parameter{Name: "param", In: "query", Style: "pipeDelimited", Explode: noExplode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "param", In: "query", Style: "pipeDelimited", Explode: noExplode, Schema: stringArraySchema}, query: "param=foo|bar", want: []interface{}{"foo", "bar"}, found: true, }, { name: "pipeDelimited explode", - param: &openapi3.Parameter{Name: "param", In: "query", Style: "pipeDelimited", Explode: explode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "param", In: "query", Style: "pipeDelimited", Explode: explode, Schema: stringArraySchema}, query: "param=foo¶m=bar", want: []interface{}{"foo", "bar"}, found: true, }, { name: "default", - param: &openapi3.Parameter{Name: "param", In: "query", Schema: arraySchema}, + param: &openapi3.Parameter{Name: "param", In: "query", Schema: stringArraySchema}, query: "param=foo¶m=bar", want: []interface{}{"foo", "bar"}, found: true, @@ -659,6 +879,275 @@ func TestDecodeParameter(t *testing.T) { want: map[string]interface{}{"id": "foo", "name": "bar"}, found: true, }, + { + name: "deepObject explode additionalProperties with object properties - missing index on nested array", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", additionalPropertiesObjectOf(objectOf("item1", integerSchema, "item2", stringArraySchema)), + "objIgnored", objectOf("items", stringArraySchema), + ), + }, + query: "param[obj][prop2][item2]=def", + err: &ParseError{path: []interface{}{"obj", "prop2", "item2"}, Kind: KindInvalidFormat, Reason: "array items must be set with indexes"}, + }, + { + name: "deepObject explode array - missing indexes", + param: &openapi3.Parameter{Name: "param", In: "query", Style: "deepObject", Explode: explode, Schema: objectOf("items", stringArraySchema)}, + query: "param[items]=f%26oo¶m[items]=bar", + found: true, + err: &ParseError{path: []interface{}{"items"}, Kind: KindInvalidFormat, Reason: "array items must be set with indexes"}, + }, + { + name: "deepObject explode array", + param: &openapi3.Parameter{Name: "param", In: "query", Style: "deepObject", Explode: explode, Schema: objectOf("items", integerArraySchema)}, + query: "param[items][1]=456¶m[items][0]=123", + want: map[string]interface{}{"items": []interface{}{int64(123), int64(456)}}, + found: true, + }, + { + name: "deepObject explode nested object additionalProperties", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", additionalPropertiesObjectStringSchema, + "objTwo", stringSchema, + "objIgnored", objectOf("items", stringArraySchema), + ), + }, + query: "param[obj][prop1]=bar¶m[obj][prop2]=foo¶m[objTwo]=string", + want: map[string]interface{}{ + "obj": map[string]interface{}{"prop1": "bar", "prop2": "foo"}, + "objTwo": "string", + }, + found: true, + }, + { + name: "deepObject explode additionalProperties with object properties - sharing property", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", additionalPropertiesObjectOf(objectOf("item1", integerSchema, "item2", stringSchema)), + "objIgnored", objectOf("items", stringArraySchema), + ), + }, + query: "param[obj][prop1][item1]=1¶m[obj][prop1][item2]=abc", + want: map[string]interface{}{ + "obj": map[string]interface{}{"prop1": map[string]interface{}{ + "item1": int64(1), + "item2": "abc", + }}, + }, + found: true, + }, + { + name: "deepObject explode nested object additionalProperties - bad value", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", additionalPropertiesObjectBoolSchema, + "objTwo", stringSchema, + "objIgnored", objectOf("items", stringArraySchema), + ), + }, + query: "param[obj][prop1]=notbool¶m[objTwo]=string", + err: &ParseError{path: []interface{}{"obj", "prop1"}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "notbool"}}, + }, + { + name: "deepObject explode nested object additionalProperties - bad index inside additionalProperties", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", additionalPropertiesObjectStringSchema, + "objTwo", stringSchema, + "objIgnored", objectOf("items", stringArraySchema), + ), + }, + query: "param[obj][prop1]=bar¶m[obj][prop2][badindex]=bad¶m[objTwo]=string", + err: &ParseError{ + path: []interface{}{"obj", "prop2"}, + Reason: `path is not convertible to primitive`, + Kind: KindInvalidFormat, + Value: map[string]interface{}{"badindex": "bad"}, + }, + }, + { + name: "deepObject explode nested object", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", objectOf("nestedObjOne", stringSchema, "nestedObjTwo", stringSchema), + "objTwo", stringSchema, + "objIgnored", objectOf("items", stringArraySchema), + ), + }, + query: "param[obj][nestedObjOne]=bar¶m[obj][nestedObjTwo]=foo¶m[objTwo]=string", + want: map[string]interface{}{ + "obj": map[string]interface{}{"nestedObjOne": "bar", "nestedObjTwo": "foo"}, + "objTwo": "string", + }, + found: true, + }, + { + name: "deepObject explode nested object - extraneous param ignored", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", objectOf("nestedObjOne", stringSchema, "nestedObjTwo", stringSchema), + ), + }, + query: "anotherparam=bar", + want: map[string]interface{}(nil), + }, + { + name: "deepObject explode nested object - bad array item type", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "objTwo", integerArraySchema, + ), + }, + query: "param[objTwo][0]=badint", + err: &ParseError{path: []interface{}{"objTwo", "0"}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "badint"}}, + }, + { + name: "deepObject explode deeply nested object - bad array item type", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", objectOf("nestedObjOne", integerArraySchema), + ), + }, + query: "param[obj][nestedObjOne][0]=badint", + err: &ParseError{path: []interface{}{"obj", "nestedObjOne", "0"}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "badint"}}, + }, + { + name: "deepObject explode deeply nested object - array index not an int", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", objectOf("nestedObjOne", integerArraySchema), + ), + }, + query: "param[obj][nestedObjOne][badindex]=badint", + err: &ParseError{path: []interface{}{"obj", "nestedObjOne"}, Kind: KindInvalidFormat, Reason: "could not convert value map to array: array indexes must be integers: strconv.Atoi: parsing \"badindex\": invalid syntax"}, + }, + { + name: "deepObject explode nested object with array", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", objectOf("nestedObjOne", stringSchema, "nestedObjTwo", stringSchema), + "objTwo", stringArraySchema, + "objIgnored", objectOf("items", stringArraySchema), + ), + }, + query: "param[obj][nestedObjOne]=bar¶m[obj][nestedObjTwo]=foo¶m[objTwo][0]=f%26oo¶m[objTwo][1]=bar", + want: map[string]interface{}{ + "obj": map[string]interface{}{"nestedObjOne": "bar", "nestedObjTwo": "foo"}, + "objTwo": []interface{}{"f%26oo", "bar"}, + }, + found: true, + }, + { + name: "deepObject explode nested object with array - bad value", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", objectOf("nestedObjOne", stringSchema, "nestedObjTwo", booleanSchema), + "objTwo", stringArraySchema, + "objIgnored", objectOf("items", stringArraySchema), + ), + }, + query: "param[obj][nestedObjOne]=bar¶m[obj][nestedObjTwo]=bad¶m[objTwo][0]=f%26oo¶m[objTwo][1]=bar", + err: &ParseError{path: []interface{}{"obj", "nestedObjTwo"}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "bad"}}, + }, + { + name: "deepObject explode nested object with nested array", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", objectOf("nestedObjOne", stringSchema, "nestedObjTwo", stringSchema), + "objTwo", objectOf("items", stringArraySchema), + "objIgnored", objectOf("items", stringArraySchema), + ), + }, + query: "param[obj][nestedObjOne]=bar¶m[obj][nestedObjTwo]=foo¶m[objTwo][items][0]=f%26oo¶m[objTwo][items][1]=bar", + want: map[string]interface{}{ + "obj": map[string]interface{}{"nestedObjOne": "bar", "nestedObjTwo": "foo"}, + "objTwo": map[string]interface{}{"items": []interface{}{"f%26oo", "bar"}}, + }, + found: true, + }, + { + name: "deepObject explode nested object with nested array on different levels", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", objectOf("nestedObjOne", objectOf("items", stringArraySchema)), + "objTwo", objectOf("items", stringArraySchema), + ), + }, + query: "param[obj][nestedObjOne][items][0]=baz¶m[objTwo][items][0]=foo¶m[objTwo][items][1]=bar", + want: map[string]interface{}{ + "obj": map[string]interface{}{"nestedObjOne": map[string]interface{}{"items": []interface{}{"baz"}}}, + "objTwo": map[string]interface{}{"items": []interface{}{"foo", "bar"}}, + }, + found: true, + }, + { + name: "deepObject explode array of arrays", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "arr", arrayOf(arrayOf(integerSchema)), + ), + }, + query: "param[arr][1][1]=123¶m[arr][1][2]=456", + want: map[string]interface{}{ + "arr": []interface{}{ + nil, + []interface{}{nil, int64(123), int64(456)}, + }, + }, + found: true, + }, + { + name: "deepObject explode nested array of objects - missing intermediate array index", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "arr", arrayOf(objectOf("key", booleanSchema)), + ), + }, + query: "param[arr][3][key]=true¶m[arr][0][key]=false", + want: map[string]interface{}{ + "arr": []interface{}{ + map[string]interface{}{"key": false}, + nil, + nil, + map[string]interface{}{"key": true}, + }, + }, + found: true, + }, + { + name: "deepObject explode nested array of objects", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "arr", arrayOf(objectOf("key", booleanSchema)), + ), + }, + query: "param[arr][0][key]=true¶m[arr][1][key]=false", + found: true, + want: map[string]interface{}{ + "arr": []interface{}{ + map[string]interface{}{"key": true}, + map[string]interface{}{"key": false}, + }, + }, + }, { name: "default", param: &openapi3.Parameter{Name: "param", In: "query", Schema: objectSchema}, @@ -769,21 +1258,21 @@ func TestDecodeParameter(t *testing.T) { testCases: []testCase{ { name: "simple", - param: &openapi3.Parameter{Name: "X-Param", In: "header", Style: "simple", Explode: noExplode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "X-Param", In: "header", Style: "simple", Explode: noExplode, Schema: stringArraySchema}, header: "X-Param:foo,bar", want: []interface{}{"foo", "bar"}, found: true, }, { name: "simple explode", - param: &openapi3.Parameter{Name: "X-Param", In: "header", Style: "simple", Explode: explode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "X-Param", In: "header", Style: "simple", Explode: explode, Schema: stringArraySchema}, header: "X-Param:foo,bar", want: []interface{}{"foo", "bar"}, found: true, }, { name: "default", - param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: arraySchema}, + param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: stringArraySchema}, header: "X-Param:foo,bar", want: []interface{}{"foo", "bar"}, found: true, @@ -945,7 +1434,7 @@ func TestDecodeParameter(t *testing.T) { testCases: []testCase{ { name: "form", - param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Style: "form", Explode: noExplode, Schema: arraySchema}, + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Style: "form", Explode: noExplode, Schema: stringArraySchema}, cookie: "X-Param:foo,bar", want: []interface{}{"foo", "bar"}, found: true, @@ -1044,16 +1533,16 @@ func TestDecodeParameter(t *testing.T) { Title: "MyAPI", Version: "0.1", } - spec := &openapi3.T{OpenAPI: "3.0.0", Info: info} + doc := &openapi3.T{OpenAPI: "3.0.0", Info: info, Paths: openapi3.NewPaths()} op := &openapi3.Operation{ OperationID: "test", Parameters: []*openapi3.ParameterRef{{Value: tc.param}}, Responses: openapi3.NewResponses(), } - spec.AddOperation(path, http.MethodGet, op) - err = spec.Validate(context.Background()) + doc.AddOperation(path, http.MethodGet, op) + err = doc.Validate(context.Background()) require.NoError(t, err) - router, err := legacyrouter.NewRouter(spec) + router, err := legacyrouter.NewRouter(doc) require.NoError(t, err) route, pathParams, err := router.FindRoute(req) @@ -1062,16 +1551,17 @@ func TestDecodeParameter(t *testing.T) { input := &openapi3filter.RequestValidationInput{Request: req, PathParams: pathParams, Route: route} got, found, err := decodeStyledParameter(tc.param, input) - require.Truef(t, found == tc.found, "got found: %t, want found: %t", found, tc.found) - if tc.err != nil { require.Error(t, err) - require.Truef(t, matchParseError(err, tc.err), "got error:\n%v\nwant error:\n%v", err, tc.err) + matchParseError(t, err, tc.err) + return } require.NoError(t, err) - require.Truef(t, reflect.DeepEqual(got, tc.want), "got %v, want %v", got, tc.want) + require.EqualValues(t, tc.want, got) + + require.Truef(t, found == tc.found, "got found: %t, want found: %t", found, tc.found) }) } }) @@ -1079,8 +1569,6 @@ func TestDecodeParameter(t *testing.T) { } func TestDecodeBody(t *testing.T) { - boolPtr := func(b bool) *bool { return &b } - urlencodedForm := make(url.Values) urlencodedForm.Set("a", "a1") urlencodedForm.Set("b", "10") @@ -1210,7 +1698,7 @@ func TestDecodeBody(t *testing.T) { WithProperty("b", openapi3.NewIntegerSchema()). WithProperty("c", openapi3.NewArraySchema().WithItems(openapi3.NewStringSchema())), encoding: map[string]*openapi3.Encoding{ - "c": {Style: openapi3.SerializationSpaceDelimited, Explode: boolPtr(false)}, + "c": {Style: openapi3.SerializationSpaceDelimited, Explode: openapi3.BoolPtr(false)}, }, want: map[string]interface{}{"a": "a1", "b": int64(10), "c": []interface{}{"c1", "c2"}}, }, @@ -1223,7 +1711,7 @@ func TestDecodeBody(t *testing.T) { WithProperty("b", openapi3.NewIntegerSchema()). WithProperty("c", openapi3.NewArraySchema().WithItems(openapi3.NewStringSchema())), encoding: map[string]*openapi3.Encoding{ - "c": {Style: openapi3.SerializationPipeDelimited, Explode: boolPtr(false)}, + "c": {Style: openapi3.SerializationPipeDelimited, Explode: openapi3.BoolPtr(false)}, }, want: map[string]interface{}{"a": "a1", "b": int64(10), "c": []interface{}{"c1", "c2"}}, }, @@ -1304,7 +1792,7 @@ func TestDecodeBody(t *testing.T) { if tc.wantErr != nil { require.Error(t, err) - require.Truef(t, matchParseError(err, tc.wantErr), "got error:\n%v\nwant error:\n%v", err, tc.wantErr) + matchParseError(t, err, tc.wantErr) return } @@ -1387,26 +1875,27 @@ func TestRegisterAndUnregisterBodyDecoder(t *testing.T) { }, err) } -func matchParseError(got, want error) bool { +func matchParseError(t *testing.T, got, want error) { + t.Helper() + wErr, ok := want.(*ParseError) if !ok { - return false + t.Errorf("want error is not a ParseError") + return } gErr, ok := got.(*ParseError) if !ok { - return false + t.Errorf("got error is not a ParseError") + return } - if wErr.Kind != gErr.Kind { - return false - } - if !reflect.DeepEqual(wErr.Value, gErr.Value) { - return false - } - if !reflect.DeepEqual(wErr.Path(), gErr.Path()) { - return false + assert.Equalf(t, wErr.Kind, gErr.Kind, "ParseError Kind differs") + assert.Equalf(t, wErr.Value, gErr.Value, "ParseError Value differs") + assert.Equalf(t, wErr.Path(), gErr.Path(), "ParseError Path differs") + + if wErr.Reason != "" { + assert.Equalf(t, wErr.Reason, gErr.Reason, "ParseError Reason differs") } if wErr.Cause != nil { - return matchParseError(gErr.Cause, wErr.Cause) + matchParseError(t, gErr.Cause, wErr.Cause) } - return true } diff --git a/internal/platform/validator/unknown_parameters_request.go b/internal/platform/validator/unknown_parameters_request.go index 57e56a7..56cf3fa 100644 --- a/internal/platform/validator/unknown_parameters_request.go +++ b/internal/platform/validator/unknown_parameters_request.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "strconv" + "strings" "github.com/getkin/kin-openapi/openapi3" "github.com/getkin/kin-openapi/routers" @@ -97,7 +98,16 @@ func ValidateUnknownRequestParameters(ctx *fasthttp.RequestCtx, route *routers.R unknownQueryParams := RequestUnknownParameterError{} // compare list of all query params and list of params defined in the specification ctx.Request.URI().QueryArgs().VisitAll(func(key, value []byte) { - if _, ok := specParams[utils.B2S(key)+openapi3.ParameterInQuery]; !ok { + + keyStr := utils.B2S(key) + + if i := strings.Index(keyStr, "["); i > 0 { + if _, ok := specParams[keyStr[:i]+openapi3.ParameterInQuery]; ok { + return + } + } + + if _, ok := specParams[keyStr+openapi3.ParameterInQuery]; !ok { unknownQueryParams.Message = ErrUnknownQueryParameter.Error() unknownQueryParams.Parameters = append(unknownQueryParams.Parameters, RequestParameterDetails{ Name: utils.B2S(key), diff --git a/internal/platform/validator/unknown_parameters_request_test.go b/internal/platform/validator/unknown_parameters_request_test.go index acea83a..45f11df 100644 --- a/internal/platform/validator/unknown_parameters_request_test.go +++ b/internal/platform/validator/unknown_parameters_request_test.go @@ -3,10 +3,8 @@ package validator import ( "bytes" "encoding/json" - "errors" "io" "net/http" - "reflect" "strings" "testing" @@ -23,7 +21,7 @@ func TestUnknownParametersRequest(t *testing.T) { openapi: 3.0.0 info: title: 'Validator' - version: 0.0.1 + version: 0.0.2 paths: /category: post: @@ -67,9 +65,21 @@ paths: required: true content: application/json: - schema: {} + schema: + type: object + required: + - subCategory + properties: + subCategory: + type: string application/x-www-form-urlencoded: - schema: {} + schema: + type: object + required: + - subCategory + properties: + subCategory: + type: string responses: '201': description: Created @@ -221,9 +231,9 @@ paths: expectedResp: nil, }, { - name: "Valid POST unknown params", + name: "Valid POST unknown params 0", args: args{ - requestBody: &testRequestBody{SubCategory: "Chocolate", Category: &categoryFood}, + requestBody: &testRequestBody{SubCategory: "Chocolate", UnknownParameter: "unknownValue"}, url: "/unknown", ct: "application/x-www-form-urlencoded", }, @@ -231,31 +241,30 @@ paths: expectedResp: []*RequestUnknownParameterError{ { Parameters: []RequestParameterDetails{{ - Name: "subCategory", + Name: "unknown", Placeholder: "body", Type: "string", - }, - { - Name: "category", - Placeholder: "body", - Type: "string", - }}, + }}, Message: ErrUnknownBodyParameter.Error(), }, }, }, { - name: "Valid JSON unknown params", + name: "Valid POST unknown params 1", args: args{ - requestBody: &testRequestBody{SubCategory: "Chocolate"}, + requestBody: &testRequestBody{SubCategory: "Chocolate", Category: &categoryFood, UnknownParameter: "unknownValue"}, url: "/unknown", - ct: "application/json", + ct: "application/x-www-form-urlencoded", }, expectedErr: nil, expectedResp: []*RequestUnknownParameterError{ { Parameters: []RequestParameterDetails{{ - Name: "subCategory", + Name: "unknown", + Placeholder: "body", + Type: "string", + }, { + Name: "category", Placeholder: "body", Type: "string", }}, @@ -263,6 +272,30 @@ paths: }, }, }, + { + name: "Valid JSON unknown params 2", + args: args{ + requestBody: &testRequestBody{SubCategory: "Chocolate", Category: &categoryFood, UnknownParameter: "unknownValue"}, + url: "/unknown", + ct: "application/json", + }, + expectedErr: nil, + expectedResp: []*RequestUnknownParameterError{ + { + Parameters: []RequestParameterDetails{{ + Name: "unknown", + Placeholder: "body", + Type: "string", + }, + { + Name: "category", + Placeholder: "body", + Type: "string", + }}, + Message: ErrUnknownBodyParameter.Error(), + }, + }, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { @@ -281,8 +314,10 @@ paths: if tc.args.requestBody.SubCategory != "" { req.PostArgs().Add("subCategory", tc.args.requestBody.SubCategory) } - if *tc.args.requestBody.Category != "" { - req.PostArgs().Add("category", *tc.args.requestBody.Category) + if tc.args.requestBody.Category != nil { + if *tc.args.requestBody.Category != "" { + req.PostArgs().Add("category", *tc.args.requestBody.Category) + } } requestBody = strings.NewReader(req.PostArgs().String()) case "application/json": @@ -316,13 +351,52 @@ paths: if tc.expectedErr != nil { return } - if tc.expectedResp != nil || len(tc.expectedResp) > 0 { + if tc.expectedResp != nil && len(tc.expectedResp) > 0 { assert.Equal(t, len(tc.expectedResp), len(upRes), "expect the number of unknown parameters: %t, got %t", len(tc.expectedResp), len(upRes)) + assert.Equal(t, true, matchUnknownParamsResp(tc.expectedResp, upRes), "expect unknown parameters: %v, got %v", tc.expectedResp, upRes) + } + }) + } +} - if isEq := reflect.DeepEqual(tc.expectedResp, upRes); !isEq { - assert.Errorf(t, errors.New("got unexpected unknown parameters"), "expect unknown parameters: %v, got %v", tc.expectedResp, upRes) +func matchUnknownParamsResp(expected []*RequestUnknownParameterError, actual []RequestUnknownParameterError) bool { + for _, expectedValue := range expected { + for _, expectedParam := range expectedValue.Parameters { + var found bool + // search for the same param in the actual resp + for _, actualValue := range actual { + for _, actualParam := range actualValue.Parameters { + if expectedParam.Name == actualParam.Name && + expectedParam.Type == actualParam.Type && + expectedParam.Placeholder == actualParam.Placeholder { + found = true + } } } - }) + if !found { + return false + } + } + } + + for _, actualValue := range actual { + for _, actualParam := range actualValue.Parameters { + var found bool + // search for the same param in the actual resp + for _, expectedValue := range expected { + for _, expectedParam := range expectedValue.Parameters { + if expectedParam.Name == actualParam.Name && + expectedParam.Type == actualParam.Type && + expectedParam.Placeholder == actualParam.Placeholder { + found = true + } + } + } + if !found { + return false + } + } } + + return true } diff --git a/internal/platform/validator/validate_request_test.go b/internal/platform/validator/validate_request_test.go index ce2f093..c42c04b 100644 --- a/internal/platform/validator/validate_request_test.go +++ b/internal/platform/validator/validate_request_test.go @@ -13,6 +13,7 @@ import ( "github.com/getkin/kin-openapi/openapi3filter" "github.com/getkin/kin-openapi/routers" "github.com/getkin/kin-openapi/routers/gorillamux" + legacyrouter "github.com/getkin/kin-openapi/routers/legacy" "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -252,3 +253,304 @@ components: }) } } + +func TestValidateQueryParams(t *testing.T) { + type testCase struct { + name string + param *openapi3.Parameter + query string + want map[string]interface{} + err *openapi3.SchemaError // test ParseError in decoder tests + } + + testCases := []testCase{ + { + name: "deepObject explode additionalProperties with object properties - missing required property", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", additionalPropertiesObjectOf(func() *openapi3.SchemaRef { + s := objectOf( + "item1", integerSchema, + "requiredProp", stringSchema, + ) + s.Value.Required = []string{"requiredProp"} + + return s + }()), + "objIgnored", objectOf("items", stringArraySchema), + ), + }, + query: "param[obj][prop1][item1]=1", + err: &openapi3.SchemaError{SchemaField: "required", Reason: "property \"requiredProp\" is missing"}, + }, + { + // XXX should this error out? + name: "deepObject explode additionalProperties with object properties - extraneous nested param property ignored", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", additionalPropertiesObjectOf(objectOf( + "item1", integerSchema, + "requiredProp", stringSchema, + )), + "objIgnored", objectOf("items", stringArraySchema), + ), + }, + query: "param[obj][prop1][inexistent]=1", + want: map[string]interface{}{ + "obj": map[string]interface{}{ + "prop1": map[string]interface{}{}, + }, + }, + }, + { + name: "deepObject explode additionalProperties with object properties", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", additionalPropertiesObjectOf(objectOf( + "item1", numberSchema, + "requiredProp", stringSchema, + )), + "objIgnored", objectOf("items", stringArraySchema), + ), + }, + query: "param[obj][prop1][item1]=1.123", + want: map[string]interface{}{ + "obj": map[string]interface{}{ + "prop1": map[string]interface{}{ + "item1": float64(1.123), + }, + }, + }, + }, + { + name: "deepObject explode nested objects - misplaced parameter", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", objectOf("nestedObjOne", objectOf("items", stringArraySchema)), + ), + }, + query: "param[obj][nestedObjOne]=baz", + err: &openapi3.SchemaError{ + SchemaField: "type", Reason: "value must be an object", Value: "baz", Schema: objectOf("items", stringArraySchema).Value, + }, + }, + { + name: "deepObject explode nested object - extraneous param ignored", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", objectOf("nestedObjOne", stringSchema, "nestedObjTwo", stringSchema), + ), + }, + query: "anotherparam=bar", + want: map[string]interface{}(nil), + }, + { + name: "deepObject explode additionalProperties with object properties - multiple properties", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", additionalPropertiesObjectOf(objectOf("item1", integerSchema, "item2", stringArraySchema)), + "objIgnored", objectOf("items", stringArraySchema), + ), + }, + query: "param[obj][prop1][item1]=1¶m[obj][prop1][item2][0]=abc¶m[obj][prop2][item1]=2¶m[obj][prop2][item2][0]=def", + want: map[string]interface{}{ + "obj": map[string]interface{}{ + "prop1": map[string]interface{}{ + "item1": int64(1), + "item2": []interface{}{"abc"}, + }, + "prop2": map[string]interface{}{ + "item1": int64(2), + "item2": []interface{}{"def"}, + }, + }, + }, + }, + + // + // + { + name: "deepObject explode nested object anyOf", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", anyofSchema, + ), + }, + query: "param[obj]=1", + want: map[string]interface{}{ + "obj": int64(1), + }, + }, + { + name: "deepObject explode nested object allOf", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", allofSchema, + ), + }, + query: "param[obj]=1", + want: map[string]interface{}{ + "obj": int64(1), + }, + }, + { + name: "deepObject explode nested object oneOf", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", oneofSchema, + ), + }, + query: "param[obj]=true", + want: map[string]interface{}{ + "obj": true, + }, + }, + { + name: "deepObject explode nested object oneOf - object", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", oneofSchemaObject, + ), + }, + query: "param[obj][id2]=1¶m[obj][name2]=abc", + want: map[string]interface{}{ + "obj": map[string]interface{}{ + "id2": "1", + "name2": "abc", + }, + }, + }, + { + name: "deepObject explode nested object oneOf - object - more than one match", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", oneofSchemaObject, + ), + }, + query: "param[obj][id]=1¶m[obj][id2]=2", + err: &openapi3.SchemaError{ + SchemaField: "oneOf", + Value: map[string]interface{}{"id": "1", "id2": "2"}, + Reason: "value matches more than one schema from \"oneOf\" (matches schemas at indices [0 1])", + Schema: oneofSchemaObject.Value, + }, + }, + { + name: "deepObject explode nested object oneOf - array", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "deepObject", Explode: explode, + Schema: objectOf( + "obj", oneofSchemaArrayObject, + ), + }, + query: "param[obj][0]=a¶m[obj][1]=b", + want: map[string]interface{}{ + "obj": []interface{}{ + "a", + "b", + }, + }, + }, + // + // + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + info := &openapi3.Info{ + Title: "MyAPI", + Version: "0.1", + } + doc := &openapi3.T{OpenAPI: "3.0.0", Info: info, Paths: openapi3.NewPaths()} + op := &openapi3.Operation{ + OperationID: "test", + Parameters: []*openapi3.ParameterRef{{Value: tc.param}}, + Responses: openapi3.NewResponses(), + } + doc.AddOperation("/test", http.MethodGet, op) + err := doc.Validate(context.Background()) + require.NoError(t, err) + router, err := legacyrouter.NewRouter(doc) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "http://test.org/test?"+tc.query, nil) + route, pathParams, err := router.FindRoute(req) + require.NoError(t, err) + + input := &openapi3filter.RequestValidationInput{Request: req, PathParams: pathParams, Route: route} + err = ValidateParameter(context.Background(), input, tc.param) + + if tc.err != nil { + require.Error(t, err) + re, ok := err.(*openapi3filter.RequestError) + if !ok { + t.Errorf("error is not a RequestError") + + return + } + + gErr, ok := re.Unwrap().(*openapi3.SchemaError) + if !ok { + t.Errorf("unknown RequestError wrapped error type") + } + matchSchemaError(t, gErr, tc.err) + + return + } + + require.NoError(t, err) + + got, _, err := decodeStyledParameter(tc.param, input) + require.EqualValues(t, tc.want, got) + }) + } +} + +func matchSchemaError(t *testing.T, got, want error) { + t.Helper() + + wErr, ok := want.(*openapi3.SchemaError) + if !ok { + t.Errorf("want error is not a SchemaError") + return + } + gErr, ok := got.(*openapi3.SchemaError) + if !ok { + t.Errorf("got error is not a SchemaError") + return + } + assert.Equalf(t, wErr.SchemaField, gErr.SchemaField, "SchemaError SchemaField differs") + assert.Equalf(t, wErr.Reason, gErr.Reason, "SchemaError Reason differs") + + if wErr.Schema != nil { + assert.EqualValuesf(t, wErr.Schema, gErr.Schema, "SchemaError Schema differs") + } + if wErr.Value != nil { + assert.EqualValuesf(t, wErr.Value, gErr.Value, "SchemaError Value differs") + } + + if gErr.Origin == nil && wErr.Origin != nil { + t.Errorf("expected error origin but got nothing") + } + if gErr.Origin != nil && wErr.Origin != nil { + switch gErrOrigin := gErr.Origin.(type) { + case *openapi3.SchemaError: + matchSchemaError(t, gErrOrigin, wErr.Origin) + case *ParseError: + matchParseError(t, gErrOrigin, wErr.Origin) + default: + t.Errorf("unknown origin error") + } + } +} From 14ce55d600df7dd3fb4e2748cf3b18209b17e823 Mon Sep 17 00:00:00 2001 From: Nikolay Tkachenko Date: Mon, 15 Apr 2024 02:09:01 +0300 Subject: [PATCH 10/12] Small fixes --- .../internal/handlers/proxy/openapi.go | 4 ++-- .../internal/updater/wallarm_api2_update.db | Bin 98304 -> 98304 bytes .../platform/validator/validate_request.go | 16 ++++++++++++---- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/cmd/api-firewall/internal/handlers/proxy/openapi.go b/cmd/api-firewall/internal/handlers/proxy/openapi.go index a20ad3c..72e8b99 100644 --- a/cmd/api-firewall/internal/handlers/proxy/openapi.go +++ b/cmd/api-firewall/internal/handlers/proxy/openapi.go @@ -261,7 +261,7 @@ func (s *openapiWaf) openapiWafHandler(ctx *fasthttp.RequestCtx) error { ctx.SetUserValue(web.RequestBlocked, true) s.logger.WithFields(logrus.Fields{ - "error": err, + "error": strings.ReplaceAll(err.Error(), "\n", " "), "host": strconv.B2S(ctx.Request.Header.Host()), "path": strconv.B2S(ctx.Path()), "method": strconv.B2S(ctx.Request.Header.Method()), @@ -311,7 +311,7 @@ func (s *openapiWaf) openapiWafHandler(ctx *fasthttp.RequestCtx) error { case web.ValidationLog: if err := validator.ValidateRequest(ctx, requestValidationInput, jsonParser); err != nil { s.logger.WithFields(logrus.Fields{ - "error": err, + "error": strings.ReplaceAll(err.Error(), "\n", " "), "host": strconv.B2S(ctx.Request.Header.Host()), "path": strconv.B2S(ctx.Path()), "method": strconv.B2S(ctx.Request.Header.Method()), diff --git a/cmd/api-firewall/internal/updater/wallarm_api2_update.db b/cmd/api-firewall/internal/updater/wallarm_api2_update.db index 0c250b0af094360612e4ab070d28c9a942e3ac1b..776fc409fb101e1014afb6634ee3be001eb229df 100644 GIT binary patch delta 34 qcmZo@U~6b#n;^~jcA|_kCxn|I5)2r>R@GFWE77ytm;{|r0; delta 34 qcmZo@U~6b#n;^}2d!mdpCBn|I5)2r+(XGFWE77ytm*R}2gQ diff --git a/internal/platform/validator/validate_request.go b/internal/platform/validator/validate_request.go index 8e424bc..62f6ac2 100644 --- a/internal/platform/validator/validate_request.go +++ b/internal/platform/validator/validate_request.go @@ -79,6 +79,9 @@ func ValidateRequest(ctx context.Context, input *openapi3filter.RequestValidatio // For each parameter of the Operation for _, parameter := range operationParameters { + if options.ExcludeRequestQueryParams && parameter.Value.In == openapi3.ParameterInQuery { + continue + } if err = ValidateParameter(ctx, input, parameter.Value); err != nil && !options.MultiError { return } @@ -289,7 +292,7 @@ func ValidateRequestBody(ctx context.Context, input *openapi3filter.RequestValid } defaultsSet := false - opts := make([]openapi3.SchemaValidationOption, 0, 3) // 3 potential opts here + opts := make([]openapi3.SchemaValidationOption, 0, 4) // 4 potential opts here opts = append(opts, openapi3.VisitAsRequest()) if !options.SkipSettingDefaults { opts = append(opts, openapi3.DefaultsSet(func() { defaultsSet = true })) @@ -297,6 +300,9 @@ func ValidateRequestBody(ctx context.Context, input *openapi3filter.RequestValid if options.MultiError { opts = append(opts, openapi3.MultiErrors()) } + if options.ExcludeReadOnlyValidations { + opts = append(opts, openapi3.DisableReadOnlyValidation()) + } // Validate JSON with the schema if err := contentType.Schema.Value.VisitJSON(value, opts...); err != nil { @@ -359,9 +365,6 @@ func ValidateSecurityRequirements(ctx context.Context, input *openapi3filter.Req // validateSecurityRequirement validates a single OpenAPI 3 security requirement func validateSecurityRequirement(ctx context.Context, input *openapi3filter.RequestValidationInput, securityRequirement openapi3.SecurityRequirement) error { - doc := input.Route.Spec - securitySchemes := doc.Components.SecuritySchemes - // Ensure deterministic order names := make([]string, 0, len(securityRequirement)) for name := range securityRequirement { @@ -379,6 +382,11 @@ func validateSecurityRequirement(ctx context.Context, input *openapi3filter.Requ return ErrAuthenticationServiceMissing } + var securitySchemes openapi3.SecuritySchemes + if components := input.Route.Spec.Components; components != nil { + securitySchemes = components.SecuritySchemes + } + // For each scheme for the requirement for _, name := range names { var securityScheme *openapi3.SecurityScheme From b3cbe94cb8c5d53f802c0541aaefa0284a490da4 Mon Sep 17 00:00:00 2001 From: Nikolay Tkachenko Date: Mon, 15 Apr 2024 10:44:21 +0300 Subject: [PATCH 11/12] Small log msg update --- .../internal/updater/wallarm_api2_update.db | Bin 98304 -> 98304 bytes cmd/api-firewall/main.go | 6 +++--- cmd/api-firewall/tests/main_api_mode_test.go | 4 ++-- cmd/api-firewall/tests/main_json_test.go | 4 ++-- cmd/api-firewall/tests/main_modsec_test.go | 4 ++-- cmd/api-firewall/tests/main_test.go | 4 ++-- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/cmd/api-firewall/internal/updater/wallarm_api2_update.db b/cmd/api-firewall/internal/updater/wallarm_api2_update.db index 776fc409fb101e1014afb6634ee3be001eb229df..bd3eb5280c2b00bc140b92eb332907efbf1b9f6d 100644 GIT binary patch delta 34 qcmZo@U~6b#n;^~jZK8}bD+n|I5)2r>R`GFWE77ytm=(+pGq delta 34 qcmZo@U~6b#n;^~jcA|_kCxn|I5)2r>R@GFWE77ytm;{|r0; diff --git a/cmd/api-firewall/main.go b/cmd/api-firewall/main.go index 01131eb..2d78895 100644 --- a/cmd/api-firewall/main.go +++ b/cmd/api-firewall/main.go @@ -751,18 +751,18 @@ func runProxyMode(logger *logrus.Logger) error { case nil: swagger, err = openapi3.NewLoader().LoadFromFile(cfg.APISpecs) if err != nil { - return errors.Wrap(err, "loading swagwaf file") + return errors.Wrap(err, "loading OpenAPI specification from file") } default: swagger, err = openapi3.NewLoader().LoadFromURI(apiSpecURL) if err != nil { - return errors.Wrap(err, "loading swagwaf url") + return errors.Wrap(err, "loading OpenAPI specification from URL") } } swagRouter, err := loader.NewRouter(swagger, true) if err != nil { - return errors.Wrap(err, "parsing swagwaf file") + return errors.Wrap(err, "parsing OpenAPI specification") } // ========================================================================= diff --git a/cmd/api-firewall/tests/main_api_mode_test.go b/cmd/api-firewall/tests/main_api_mode_test.go index ea95512..2449427 100644 --- a/cmd/api-firewall/tests/main_api_mode_test.go +++ b/cmd/api-firewall/tests/main_api_mode_test.go @@ -522,12 +522,12 @@ func TestAPIModeBasic(t *testing.T) { swagger, err := openapi3.NewLoader().LoadFromData([]byte(apiModeOpenAPISpecAPIModeTest)) if err != nil { - t.Fatalf("loading swagwaf file: %s", err.Error()) + t.Fatalf("loading OpenAPI specification file: %s", err.Error()) } secondSwagger, err := openapi3.NewLoader().LoadFromData([]byte(secondApiModeOpenAPISpecAPIModeTest)) if err != nil { - t.Fatalf("loading swagwaf file: %s", err.Error()) + t.Fatalf("loading OpenAPI specification file: %s", err.Error()) } dbSpec.EXPECT().SchemaIDs().Return([]int{DefaultSchemaID, DefaultCopySchemaID, SecondSchemaID}).AnyTimes() diff --git a/cmd/api-firewall/tests/main_json_test.go b/cmd/api-firewall/tests/main_json_test.go index c1c5538..8087f54 100644 --- a/cmd/api-firewall/tests/main_json_test.go +++ b/cmd/api-firewall/tests/main_json_test.go @@ -102,12 +102,12 @@ func TestJSONBasic(t *testing.T) { swagger, err := openapi3.NewLoader().LoadFromData([]byte(openAPIJSONSpecTest)) if err != nil { - t.Fatalf("loading swagwaf file: %s", err.Error()) + t.Fatalf("loading OpenAPI specification file: %s", err.Error()) } swagRouter, err := loader.NewRouter(swagger, true) if err != nil { - t.Fatalf("parsing swagwaf file: %s", err.Error()) + t.Fatalf("parsing OpenAPI specification file: %s", err.Error()) } shutdown := make(chan os.Signal, 1) diff --git a/cmd/api-firewall/tests/main_modsec_test.go b/cmd/api-firewall/tests/main_modsec_test.go index ff72f53..e7a73cf 100644 --- a/cmd/api-firewall/tests/main_modsec_test.go +++ b/cmd/api-firewall/tests/main_modsec_test.go @@ -122,12 +122,12 @@ func TestModSec(t *testing.T) { swagger, err := openapi3.NewLoader().LoadFromData([]byte(openAPISpecModSecTest)) if err != nil { - t.Fatalf("loading swagwaf file: %s", err.Error()) + t.Fatalf("loading OpenAPI specification file: %s", err.Error()) } swagRouter, err := loader.NewRouter(swagger, true) if err != nil { - t.Fatalf("parsing swagwaf file: %s", err.Error()) + t.Fatalf("parsing OpenAPI specification file: %s", err.Error()) } shutdown := make(chan os.Signal, 1) diff --git a/cmd/api-firewall/tests/main_test.go b/cmd/api-firewall/tests/main_test.go index 3da4ba0..36fb91d 100644 --- a/cmd/api-firewall/tests/main_test.go +++ b/cmd/api-firewall/tests/main_test.go @@ -424,12 +424,12 @@ func TestBasic(t *testing.T) { swagger, err := openapi3.NewLoader().LoadFromData([]byte(openAPISpecTest)) if err != nil { - t.Fatalf("loading swagwaf file: %s", err.Error()) + t.Fatalf("loading OpenAPI specification file: %s", err.Error()) } swagRouter, err := loader.NewRouter(swagger, true) if err != nil { - t.Fatalf("parsing swagwaf file: %s", err.Error()) + t.Fatalf("parsing OpenAPI specification file: %s", err.Error()) } shutdown := make(chan os.Signal, 1) From 99292362ddcc691368199c18b925d2a4c79b1d6e Mon Sep 17 00:00:00 2001 From: Nikolay Tkachenko Date: Mon, 15 Apr 2024 14:32:24 +0300 Subject: [PATCH 12/12] Add panic handlers to the updated and main handler --- .../internal/handlers/api/openapi.go | 25 +++++++++++++++++- cmd/api-firewall/internal/updater/updater.go | 13 +++++++++ .../internal/updater/wallarm_api2_update.db | Bin 98304 -> 98304 bytes 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/cmd/api-firewall/internal/handlers/api/openapi.go b/cmd/api-firewall/internal/handlers/api/openapi.go index 334526f..cb8a7bc 100644 --- a/cmd/api-firewall/internal/handlers/api/openapi.go +++ b/cmd/api-firewall/internal/handlers/api/openapi.go @@ -3,8 +3,8 @@ package api import ( "context" "fmt" - "github.com/wallarm/api-firewall/internal/platform/router" "net/http" + "runtime/debug" strconv2 "strconv" "strings" "sync" @@ -18,6 +18,7 @@ import ( "github.com/valyala/fastjson" "github.com/wallarm/api-firewall/internal/config" "github.com/wallarm/api-firewall/internal/platform/loader" + "github.com/wallarm/api-firewall/internal/platform/router" "github.com/wallarm/api-firewall/internal/platform/validator" "github.com/wallarm/api-firewall/internal/platform/web" ) @@ -87,6 +88,17 @@ type RequestValidator struct { // Handler validates request and respond with 200, 403 (with error) or 500 status code func (s *RequestValidator) Handler(ctx *fasthttp.RequestCtx) error { + // handle panic + defer func() { + if r := recover(); r != nil { + s.Log.Errorf("panic: %v", r) + + // Log the Go stack trace for this panic'd goroutine. + s.Log.Debugf("%s", debug.Stack()) + return + } + }() + keyValidationErrors := strconv2.Itoa(s.SchemaID) + web.APIModePostfixValidationErrors keyStatusCode := strconv2.Itoa(s.SchemaID) + web.APIModePostfixStatusCode @@ -158,6 +170,17 @@ func (s *RequestValidator) Handler(ctx *fasthttp.RequestCtx) error { go func() { defer wg.Done() + // handle panic + defer func() { + if r := recover(); r != nil { + s.Log.Errorf("panic: %v", r) + + // Log the Go stack trace for this panic'd goroutine. + s.Log.Debugf("%s", debug.Stack()) + return + } + }() + // Get fastjson parser jsonParser := s.ParserPool.Get() defer s.ParserPool.Put(jsonParser) diff --git a/cmd/api-firewall/internal/updater/updater.go b/cmd/api-firewall/internal/updater/updater.go index b0c36e4..27c9cb7 100644 --- a/cmd/api-firewall/internal/updater/updater.go +++ b/cmd/api-firewall/internal/updater/updater.go @@ -2,6 +2,7 @@ package updater import ( "os" + "runtime/debug" "sync" "time" @@ -53,6 +54,18 @@ func NewController(lock *sync.RWMutex, logger *logrus.Logger, sqlLiteStorage dat // Run function performs update of the specification func (s *Specification) Run() { + + // handle panic + defer func() { + if r := recover(); r != nil { + s.logger.Errorf("panic: %v", r) + + // Log the Go stack trace for this panic'd goroutine. + s.logger.Debugf("%s", debug.Stack()) + return + } + }() + updateTicker := time.NewTicker(s.updateTime) for { select { diff --git a/cmd/api-firewall/internal/updater/wallarm_api2_update.db b/cmd/api-firewall/internal/updater/wallarm_api2_update.db index bd3eb5280c2b00bc140b92eb332907efbf1b9f6d..48f2d9351e8eafb01c509a24bbc39b7526753d03 100644 GIT binary patch delta 38 scmZo@U~6b#n;^x+%rQ~M2}o{Ch?He&V&1%4)>(*&k-5oWnE_(}0KpFljQ{`u delta 38 scmZo@U~6b#n;^x+^lhSy6Oi1P5Gl*lSi5<*tg{dk!`~)@Wd@7^0R1}**#H0l