Skip to content

Commit

Permalink
Update ValidateRequest method
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikolay Tkachenko committed May 7, 2024
1 parent 0167657 commit 5ef0449
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 11 deletions.
Binary file modified cmd/api-firewall/tests/wallarm_api2_update.db
Binary file not shown.
4 changes: 2 additions & 2 deletions demo/interface/api-mode/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bufio"
"bytes"
"encoding/json"
strconv2 "github.com/savsgio/gotils/strconv"
"net/http"
"os"
"os/signal"
Expand All @@ -13,6 +12,7 @@ import (
"time"

"github.com/pkg/errors"
strconv2 "github.com/savsgio/gotils/strconv"
"github.com/sirupsen/logrus"
"github.com/valyala/fasthttp"
"github.com/wallarm/api-firewall/demo/interface/api-mode/internal/updater"
Expand Down Expand Up @@ -69,7 +69,7 @@ func main() {
headers.Set(sk, sv)
})

result, err := apiFirewall.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), headers)
result, err := apiFirewall.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), headers)
if err != nil {
logger.Error(err)
}
Expand Down
5 changes: 2 additions & 3 deletions pkg/apifw/apifw.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ var (

type APIFirewall interface {
ValidateRequestFromReader(schemaID int, r *bufio.Reader) (*web.APIModeResponse, error)
ValidateRequest(schemaID int, uri, host, method, body []byte, headers map[string][]string) (*web.APIModeResponse, error)
ValidateRequest(schemaID int, uri, method, body []byte, headers map[string][]string) (*web.APIModeResponse, error)
UpdateSpecsStorage() ([]int, bool, error)
}

Expand Down Expand Up @@ -164,13 +164,12 @@ func (a *APIMode) UpdateSpecsStorage() ([]int, bool, error) {
}

// ValidateRequest method validates request against the spec with provided schema ID
func (a *APIMode) ValidateRequest(schemaID int, uri, host, method, body []byte, headers map[string][]string) (*web.APIModeResponse, error) {
func (a *APIMode) ValidateRequest(schemaID int, uri, method, body []byte, headers map[string][]string) (*web.APIModeResponse, error) {

// build fasthttp RequestCTX
ctx := new(fasthttp.RequestCtx)

ctx.Request.Header.SetRequestURIBytes(uri)
ctx.Request.Header.SetHostBytes(host)
ctx.Request.Header.SetMethodBytes(method)
ctx.Request.SetBody(body)

Expand Down
12 changes: 6 additions & 6 deletions pkg/apifw/apifw_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func validate200req(t *testing.T, apifw APIFirewall, schemaID int) {
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetHost("localhost")

res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{})
res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{})
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -111,7 +111,7 @@ func validate403WrongMethodReq(t *testing.T, apifw APIFirewall, schemaID int) {
ctx.Request.Header.SetMethod("PUT")
ctx.Request.Header.SetHost("localhost")

res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{})
res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{})
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -139,7 +139,7 @@ func validate403UnknownParamReq(t *testing.T, apifw APIFirewall, schemaID int) {
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetHost("localhost")

res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{})
res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{})
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -167,7 +167,7 @@ func validate403RequiredParamMissedReq(t *testing.T, apifw APIFirewall, schemaID
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetHost("localhost")

res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{})
res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{})
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -205,7 +205,7 @@ func validate500UnknownCTReq(t *testing.T, apifw APIFirewall, schemaID int) {
headers.Set(sk, sv)
})

res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), headers)
res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), headers)
if !errors.Is(err, ErrRequestParsing) {
t.Error(err)
}
Expand Down Expand Up @@ -233,7 +233,7 @@ func validate200OptionsReq(t *testing.T, apifw APIFirewall, schemaID int) {
ctx.Request.Header.SetMethod("OPTIONS")
ctx.Request.Header.SetHost("localhost")

res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Host(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{})
res, err := apifw.ValidateRequest(schemaID, ctx.Request.Header.RequestURI(), ctx.Request.Header.Method(), ctx.Request.Body(), http.Header{})
if err != nil {
t.Error(err)
}
Expand Down

0 comments on commit 5ef0449

Please sign in to comment.