From 5ef04491f356594aeaef52d0d7c574946d1de6f0 Mon Sep 17 00:00:00 2001 From: Nikolay Tkachenko Date: Tue, 7 May 2024 13:42:37 +0300 Subject: [PATCH] Update ValidateRequest method --- cmd/api-firewall/tests/wallarm_api2_update.db | Bin 98304 -> 98304 bytes demo/interface/api-mode/main.go | 4 ++-- pkg/apifw/apifw.go | 5 ++--- pkg/apifw/apifw_test.go | 12 ++++++------ 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/cmd/api-firewall/tests/wallarm_api2_update.db b/cmd/api-firewall/tests/wallarm_api2_update.db index 98608578c937293ebdaa1b4a5fd6daeea17d13bc..1111e92ee9593bceef544d421480e5c5705bdb91 100644 GIT binary patch delta 34 pcmZo@U~6b#n;^{?F;T{uF=Asvlq_TM=H0R`LX5>t2FnZ>0|30c3WNXv delta 34 pcmZo@U~6b#n;^~TKT*b+(SKt?lq_Sx=H0R`LW~7X2FnZ>0|2_`3TyxX diff --git a/demo/interface/api-mode/main.go b/demo/interface/api-mode/main.go index a90a4f5..b43f883 100644 --- a/demo/interface/api-mode/main.go +++ b/demo/interface/api-mode/main.go @@ -4,7 +4,6 @@ import ( "bufio" "bytes" "encoding/json" - strconv2 "github.com/savsgio/gotils/strconv" "net/http" "os" "os/signal" @@ -13,6 +12,7 @@ import ( "time" "github.com/pkg/errors" + strconv2 "github.com/savsgio/gotils/strconv" "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" "github.com/wallarm/api-firewall/demo/interface/api-mode/internal/updater" @@ -69,7 +69,7 @@ func main() { headers.Set(sk, sv) }) - result, err := apiFirewall.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), headers) + result, err := apiFirewall.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), headers) if err != nil { logger.Error(err) } diff --git a/pkg/apifw/apifw.go b/pkg/apifw/apifw.go index 9720158..d80c00f 100644 --- a/pkg/apifw/apifw.go +++ b/pkg/apifw/apifw.go @@ -31,7 +31,7 @@ var ( type APIFirewall interface { ValidateRequestFromReader(schemaID int, r *bufio.Reader) (*web.APIModeResponse, error) - ValidateRequest(schemaID int, uri, host, method, body []byte, headers map[string][]string) (*web.APIModeResponse, error) + ValidateRequest(schemaID int, uri, method, body []byte, headers map[string][]string) (*web.APIModeResponse, error) UpdateSpecsStorage() ([]int, bool, error) } @@ -164,13 +164,12 @@ func (a *APIMode) UpdateSpecsStorage() ([]int, bool, error) { } // ValidateRequest method validates request against the spec with provided schema ID -func (a *APIMode) ValidateRequest(schemaID int, uri, host, method, body []byte, headers map[string][]string) (*web.APIModeResponse, error) { +func (a *APIMode) ValidateRequest(schemaID int, uri, method, body []byte, headers map[string][]string) (*web.APIModeResponse, error) { // build fasthttp RequestCTX ctx := new(fasthttp.RequestCtx) ctx.Request.Header.SetRequestURIBytes(uri) - ctx.Request.Header.SetHostBytes(host) ctx.Request.Header.SetMethodBytes(method) ctx.Request.SetBody(body) diff --git a/pkg/apifw/apifw_test.go b/pkg/apifw/apifw_test.go index 4a788c4..d97b026 100644 --- a/pkg/apifw/apifw_test.go +++ b/pkg/apifw/apifw_test.go @@ -83,7 +83,7 @@ func validate200req(t *testing.T, apifw APIFirewall, schemaID int) { ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetHost("localhost") - res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) + res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) if err != nil { t.Error(err) } @@ -111,7 +111,7 @@ func validate403WrongMethodReq(t *testing.T, apifw APIFirewall, schemaID int) { ctx.Request.Header.SetMethod("PUT") ctx.Request.Header.SetHost("localhost") - res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) + res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) if err != nil { t.Error(err) } @@ -139,7 +139,7 @@ func validate403UnknownParamReq(t *testing.T, apifw APIFirewall, schemaID int) { ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetHost("localhost") - res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) + res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) if err != nil { t.Error(err) } @@ -167,7 +167,7 @@ func validate403RequiredParamMissedReq(t *testing.T, apifw APIFirewall, schemaID ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetHost("localhost") - res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) + res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) if err != nil { t.Error(err) } @@ -205,7 +205,7 @@ func validate500UnknownCTReq(t *testing.T, apifw APIFirewall, schemaID int) { headers.Set(sk, sv) }) - res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), headers) + res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), headers) if !errors.Is(err, ErrRequestParsing) { t.Error(err) } @@ -233,7 +233,7 @@ func validate200OptionsReq(t *testing.T, apifw APIFirewall, schemaID int) { ctx.Request.Header.SetMethod("OPTIONS") ctx.Request.Header.SetHost("localhost") - res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) + res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) if err != nil { t.Error(err) }