diff --git a/cmd/api-firewall/tests/wallarm_api2_update.db b/cmd/api-firewall/tests/wallarm_api2_update.db index 9860857..1111e92 100644 Binary files a/cmd/api-firewall/tests/wallarm_api2_update.db and b/cmd/api-firewall/tests/wallarm_api2_update.db differ diff --git a/demo/interface/api-mode/main.go b/demo/interface/api-mode/main.go index a90a4f5..b43f883 100644 --- a/demo/interface/api-mode/main.go +++ b/demo/interface/api-mode/main.go @@ -4,7 +4,6 @@ import ( "bufio" "bytes" "encoding/json" - strconv2 "github.com/savsgio/gotils/strconv" "net/http" "os" "os/signal" @@ -13,6 +12,7 @@ import ( "time" "github.com/pkg/errors" + strconv2 "github.com/savsgio/gotils/strconv" "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" "github.com/wallarm/api-firewall/demo/interface/api-mode/internal/updater" @@ -69,7 +69,7 @@ func main() { headers.Set(sk, sv) }) - result, err := apiFirewall.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), headers) + result, err := apiFirewall.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), headers) if err != nil { logger.Error(err) } diff --git a/pkg/apifw/apifw.go b/pkg/apifw/apifw.go index 9720158..d80c00f 100644 --- a/pkg/apifw/apifw.go +++ b/pkg/apifw/apifw.go @@ -31,7 +31,7 @@ var ( type APIFirewall interface { ValidateRequestFromReader(schemaID int, r *bufio.Reader) (*web.APIModeResponse, error) - ValidateRequest(schemaID int, uri, host, method, body []byte, headers map[string][]string) (*web.APIModeResponse, error) + ValidateRequest(schemaID int, uri, method, body []byte, headers map[string][]string) (*web.APIModeResponse, error) UpdateSpecsStorage() ([]int, bool, error) } @@ -164,13 +164,12 @@ func (a *APIMode) UpdateSpecsStorage() ([]int, bool, error) { } // ValidateRequest method validates request against the spec with provided schema ID -func (a *APIMode) ValidateRequest(schemaID int, uri, host, method, body []byte, headers map[string][]string) (*web.APIModeResponse, error) { +func (a *APIMode) ValidateRequest(schemaID int, uri, method, body []byte, headers map[string][]string) (*web.APIModeResponse, error) { // build fasthttp RequestCTX ctx := new(fasthttp.RequestCtx) ctx.Request.Header.SetRequestURIBytes(uri) - ctx.Request.Header.SetHostBytes(host) ctx.Request.Header.SetMethodBytes(method) ctx.Request.SetBody(body) diff --git a/pkg/apifw/apifw_test.go b/pkg/apifw/apifw_test.go index 4a788c4..d97b026 100644 --- a/pkg/apifw/apifw_test.go +++ b/pkg/apifw/apifw_test.go @@ -83,7 +83,7 @@ func validate200req(t *testing.T, apifw APIFirewall, schemaID int) { ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetHost("localhost") - res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) + res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) if err != nil { t.Error(err) } @@ -111,7 +111,7 @@ func validate403WrongMethodReq(t *testing.T, apifw APIFirewall, schemaID int) { ctx.Request.Header.SetMethod("PUT") ctx.Request.Header.SetHost("localhost") - res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) + res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) if err != nil { t.Error(err) } @@ -139,7 +139,7 @@ func validate403UnknownParamReq(t *testing.T, apifw APIFirewall, schemaID int) { ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetHost("localhost") - res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) + res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) if err != nil { t.Error(err) } @@ -167,7 +167,7 @@ func validate403RequiredParamMissedReq(t *testing.T, apifw APIFirewall, schemaID ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetHost("localhost") - res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) + res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) if err != nil { t.Error(err) } @@ -205,7 +205,7 @@ func validate500UnknownCTReq(t *testing.T, apifw APIFirewall, schemaID int) { headers.Set(sk, sv) }) - res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), headers) + res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), headers) if !errors.Is(err, ErrRequestParsing) { t.Error(err) } @@ -233,7 +233,7 @@ func validate200OptionsReq(t *testing.T, apifw APIFirewall, schemaID int) { ctx.Request.Header.SetMethod("OPTIONS") ctx.Request.Header.SetHost("localhost") - res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) + res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{}) if err != nil { t.Error(err) }