From 810e9ed15e6ea12556a41800dd958d639b3a9463 Mon Sep 17 00:00:00 2001 From: Johannes Koch Date: Fri, 9 Oct 2020 18:29:34 +0200 Subject: [PATCH] extracted OpenAPI validator (#21) --- handler/openapi_validator.go | 84 +++++++++++++++++++++++++++++++ handler/proxy.go | 97 ++++++------------------------------ handler/proxy_options.go | 9 +++- handler/proxy_test.go | 19 ++++--- 4 files changed, 117 insertions(+), 92 deletions(-) create mode 100644 handler/openapi_validator.go diff --git a/handler/openapi_validator.go b/handler/openapi_validator.go new file mode 100644 index 000000000..3dfdf36ae --- /dev/null +++ b/handler/openapi_validator.go @@ -0,0 +1,84 @@ +package handler + +import ( + "context" + "io/ioutil" + "net/http" + "os" + + "github.com/avenga/couper/config" + "github.com/getkin/kin-openapi/openapi3filter" +) + +type OpenAPIValidatorFactory struct { + router *openapi3filter.Router + ignoreRequestViolations bool + ignoreResponseViolations bool +} + +func NewOpenAPIValidatorFactory(openapi *config.OpenAPI) (*OpenAPIValidatorFactory, error) { + if openapi == nil { + return nil, nil + } + dir, err := os.Getwd() + if err != nil { + return nil, err + } + return &OpenAPIValidatorFactory{ + router: openapi3filter.NewRouter().WithSwaggerFromFile(dir + "/" + openapi.File), + ignoreRequestViolations: openapi.IgnoreRequestViolations, + ignoreResponseViolations: openapi.IgnoreResponseViolations, + }, nil +} + +func (f *OpenAPIValidatorFactory) NewOpenAPIValidator() *OpenAPIValidator { + return &OpenAPIValidator{ + factory: f, + validationCtx: context.Background(), + } +} + +type OpenAPIValidator struct { + factory *OpenAPIValidatorFactory + route *openapi3filter.Route + requestValidationInput *openapi3filter.RequestValidationInput + validationCtx context.Context + Body []byte +} + +func (v *OpenAPIValidator) ValidateRequest(req *http.Request) (bool, error) { + route, pathParams, _ := v.factory.router.FindRoute(req.Method, req.URL) + v.route = route + + v.requestValidationInput = &openapi3filter.RequestValidationInput{ + Request: req, + PathParams: pathParams, + Route: route, + } + + return v.factory.ignoreRequestViolations, openapi3filter.ValidateRequest(v.validationCtx, v.requestValidationInput) +} + +func (v *OpenAPIValidator) ValidateResponse(res *http.Response) (bool, error) { + if v.route == nil { + return v.factory.ignoreResponseViolations, nil + } + responseValidationInput := &openapi3filter.ResponseValidationInput{ + RequestValidationInput: v.requestValidationInput, + Status: res.StatusCode, + Header: res.Header, + Options: &openapi3filter.Options{IncludeResponseStatus: true /* undefined response codes are invalid */}, + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return v.factory.ignoreResponseViolations, err + } + responseValidationInput.SetBodyBytes(body) + v.Body = body + + err = openapi3filter.ValidateResponse(v.validationCtx, responseValidationInput) + if err != nil { + return v.factory.ignoreResponseViolations, err + } + return v.factory.ignoreResponseViolations, nil +} diff --git a/handler/proxy.go b/handler/proxy.go index 854d249d8..c1a515b88 100644 --- a/handler/proxy.go +++ b/handler/proxy.go @@ -7,7 +7,6 @@ import ( "crypto/tls" "fmt" "io" - "io/ioutil" "math" "net" "net/http" @@ -19,7 +18,6 @@ import ( "sync" "time" - "github.com/getkin/kin-openapi/openapi3filter" "github.com/hashicorp/hcl/v2" "github.com/sirupsen/logrus" "golang.org/x/net/http/httpguts" @@ -108,23 +106,6 @@ func (c *CORSOptions) AllowsOrigin(origin string) bool { return false } -type OpenAPIOptions struct { - File string - IgnoreRequestViolations bool - IgnoreResponseViolations bool -} - -func NewOpenAPIOptions(openapi *config.OpenAPI) *OpenAPIOptions { - if openapi == nil { - return nil - } - return &OpenAPIOptions{ - File: openapi.File, - IgnoreRequestViolations: openapi.IgnoreRequestViolations, - IgnoreResponseViolations: openapi.IgnoreResponseViolations, - } -} - func NewProxy(options *ProxyOptions, log *logrus.Entry, srvOpts *server.Options, evalCtx *hcl.EvalContext) (http.Handler, error) { logConf := *logging.DefaultConfig logConf.TypeFieldKey = "couper_backend" @@ -188,44 +169,6 @@ func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { p.upstreamLog.ServeHTTP(rw, req, logging.RoundtripHandlerFunc(p.roundtrip), startTime) } -func (p *Proxy) prepareRequestValidation(outreq *http.Request) (context.Context, *openapi3filter.Route, *openapi3filter.RequestValidationInput, error) { - if p.options.OpenAPI != nil { - dir, err := os.Getwd() - if err != nil { - return nil, nil, nil, err - } - router := openapi3filter.NewRouter().WithSwaggerFromFile(dir + "/" + p.options.OpenAPI.File) - validationCtx := context.Background() - route, pathParams, _ := router.FindRoute(outreq.Method, outreq.URL) - - requestValidationInput := &openapi3filter.RequestValidationInput{ - Request: outreq, - PathParams: pathParams, - Route: route, - } - return validationCtx, route, requestValidationInput, nil - } - return nil, nil, nil, nil -} - -func (p *Proxy) prepareResponseValidation(requestValidationInput *openapi3filter.RequestValidationInput, res *http.Response) (*openapi3filter.ResponseValidationInput, []byte, error) { - if p.options.OpenAPI != nil { - responseValidationInput := &openapi3filter.ResponseValidationInput{ - RequestValidationInput: requestValidationInput, - Status: res.StatusCode, - Header: res.Header, - Options: &openapi3filter.Options{IncludeResponseStatus: true /* undefined response codes are invalid */}, - } - body, err := ioutil.ReadAll(res.Body) - if err != nil { - return nil, nil, err - } - responseValidationInput.SetBodyBytes(body) - return responseValidationInput, body, nil - } - return nil, nil, nil -} - func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) { ctx := req.Context() if p.options.Timeout > 0 { @@ -280,19 +223,14 @@ func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) { roundtripInfo := req.Context().Value(request.RoundtripInfo).(*logging.RoundtripInfo) - validationCtx, route, requestValidationInput, err := p.prepareRequestValidation(outreq) - if err != nil { - // this only happens if os.Getwd() fails - // TODO: use error template from parent endpoint>api>server - roundtripInfo.Err = err - couperErr.DefaultJSON.ServeError(couperErr.UpstreamRequestValidationFailed).ServeHTTP(rw, req) - return - } - if requestValidationInput != nil { - if err := openapi3filter.ValidateRequest(validationCtx, requestValidationInput); err != nil { - // TODO: use error template from parent endpoint>api>server + var openapiValidator *OpenAPIValidator + if p.options.OpenAPI != nil { + openapiValidator = p.options.OpenAPI.NewOpenAPIValidator() + ignoreRequestViolations, err := openapiValidator.ValidateRequest(outreq) + if err != nil { roundtripInfo.Err = err - if !p.options.OpenAPI.IgnoreRequestViolations { + if !ignoreRequestViolations { + // TODO: use error template from parent endpoint>api>server couperErr.DefaultJSON.ServeError(couperErr.UpstreamRequestValidationFailed).ServeHTTP(rw, req) return } @@ -330,19 +268,12 @@ func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) { res.Body = eval.NewReadCloser(src, res.Body) } - responseValidationInput, body, err := p.prepareResponseValidation(requestValidationInput, res) - if err != nil { - // this only happens if response body buffering fails - // TODO: use error template from parent endpoint>api>server - roundtripInfo.Err = err - couperErr.DefaultJSON.ServeError(couperErr.UpstreamResponseBufferingFailed).ServeHTTP(rw, req) - return - } - if responseValidationInput != nil && route != nil { - if err := openapi3filter.ValidateResponse(validationCtx, responseValidationInput); err != nil { - // TODO: use error template from parent endpoint>api>server + if openapiValidator != nil { + ignoreResponseViolations, err := openapiValidator.ValidateResponse(res) + if err != nil { roundtripInfo.Err = err - if !p.options.OpenAPI.IgnoreResponseViolations { + if !ignoreResponseViolations { + // TODO: use error template from parent endpoint>api>server couperErr.DefaultJSON.ServeError(couperErr.UpstreamResponseValidationFailed).ServeHTTP(rw, req) return } @@ -372,8 +303,8 @@ func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(res.StatusCode) - if body != nil { - rw.Write(body) + if openapiValidator != nil && openapiValidator.Body != nil { + rw.Write(openapiValidator.Body) } else { _, err = io.Copy(rw, res.Body) if err != nil { diff --git a/handler/proxy_options.go b/handler/proxy_options.go index 7556c1112..1389cab9e 100644 --- a/handler/proxy_options.go +++ b/handler/proxy_options.go @@ -15,7 +15,7 @@ type ProxyOptions struct { Context []hcl.Body BackendName string CORS *CORSOptions - OpenAPI *OpenAPIOptions + OpenAPI *OpenAPIValidatorFactory RequestBodyLimit int64 } @@ -43,12 +43,17 @@ func NewProxyOptions(conf *config.Backend, corsOpts *CORSOptions, remainCtx []hc cors = &CORSOptions{} } + openapiValidatorFactory, err := NewOpenAPIValidatorFactory(conf.OpenAPI) + if err != nil { + return nil, err + } + return &ProxyOptions{ BackendName: conf.Name, CORS: cors, ConnectTimeout: connectD, Context: remainCtx, - OpenAPI: NewOpenAPIOptions(conf.OpenAPI), + OpenAPI: openapiValidatorFactory, RequestBodyLimit: bodyLimit, TTFBTimeout: ttfbD, Timeout: totalD, diff --git a/handler/proxy_test.go b/handler/proxy_test.go index 635d4c331..3cc18e819 100644 --- a/handler/proxy_test.go +++ b/handler/proxy_test.go @@ -21,6 +21,7 @@ import ( "github.com/sirupsen/logrus" logrustest "github.com/sirupsen/logrus/hooks/test" + "github.com/avenga/couper/config" "github.com/avenga/couper/config/request" "github.com/avenga/couper/errors" "github.com/avenga/couper/eval" @@ -430,7 +431,7 @@ paths: tests := []struct { name string - openapiOptions *handler.OpenAPIOptions + openapi *config.OpenAPI requestMethod string requestPath string expectedStatusCode int @@ -438,7 +439,7 @@ paths: }{ { "valid request / valid response", - &handler.OpenAPIOptions{File: "testdata/upstream.yaml"}, + &config.OpenAPI{File: "testdata/upstream.yaml"}, http.MethodGet, "/get", http.StatusOK, @@ -446,7 +447,7 @@ paths: }, { "invalid request", - &handler.OpenAPIOptions{File: "testdata/upstream.yaml"}, + &config.OpenAPI{File: "testdata/upstream.yaml"}, http.MethodPost, "/get", http.StatusBadRequest, @@ -454,7 +455,7 @@ paths: }, { "invalid request, IgnoreRequestViolations", - &handler.OpenAPIOptions{File: "testdata/upstream.yaml", IgnoreRequestViolations: true}, + &config.OpenAPI{File: "testdata/upstream.yaml", IgnoreRequestViolations: true}, http.MethodPost, "/get", http.StatusOK, @@ -462,7 +463,7 @@ paths: }, { "invalid response", - &handler.OpenAPIOptions{File: "testdata/upstream.yaml"}, + &config.OpenAPI{File: "testdata/upstream.yaml"}, http.MethodGet, "/get?404", http.StatusBadGateway, @@ -470,7 +471,7 @@ paths: }, { "invalid response, IgnoreResponseViolations", - &handler.OpenAPIOptions{File: "testdata/upstream.yaml", IgnoreResponseViolations: true}, + &config.OpenAPI{File: "testdata/upstream.yaml", IgnoreResponseViolations: true}, http.MethodGet, "/get?404", http.StatusNotFound, @@ -481,7 +482,11 @@ paths: for _, tt := range tests { t.Run(tt.name, func(subT *testing.T) { logger, hook := logrustest.NewNullLogger() - p, err := handler.NewProxy(&handler.ProxyOptions{Origin: origin.URL, OpenAPI: tt.openapiOptions}, logger.WithContext(context.Background()), nil, eval.NewENVContext(nil)) + openapiValidatorFactory, err := handler.NewOpenAPIValidatorFactory(tt.openapi) + if err != nil { + subT.Fatal(err) + } + p, err := handler.NewProxy(&handler.ProxyOptions{Origin: origin.URL, OpenAPI: openapiValidatorFactory}, logger.WithContext(context.Background()), nil, eval.NewENVContext(nil)) if err != nil { subT.Fatal(err) }