Skip to content

Commit

Permalink
Handle body rewind
Browse files Browse the repository at this point in the history
Refactor openAPI error handling
  • Loading branch information
Marcel Ludwig committed Dec 4, 2020
1 parent 1b75f73 commit 3235383
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 76 deletions.
12 changes: 11 additions & 1 deletion config/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,17 @@ func (b *Backend) Merge(other *Backend) (*Backend, []hcl.Body) {
}

if other.OpenAPI != nil {
result.OpenAPI = other.OpenAPI
if other.OpenAPI.File != "" {
result.OpenAPI.File = other.OpenAPI.File
}

if other.OpenAPI.IgnoreRequestViolations != result.OpenAPI.IgnoreRequestViolations {
result.OpenAPI.IgnoreRequestViolations = other.OpenAPI.IgnoreRequestViolations
}

if other.OpenAPI.IgnoreResponseViolations != result.OpenAPI.IgnoreResponseViolations {
result.OpenAPI.IgnoreResponseViolations = other.OpenAPI.IgnoreResponseViolations
}
}

return &result, bodies
Expand Down
1 change: 1 addition & 0 deletions config/runtime/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ func NewServerConfiguration(conf *config.Gateway, httpConf *HTTPConfig, log *log
if srvConf.API != nil {
// map backends to endpoint
endpoints := make(map[string]bool)

for _, endpoint := range srvConf.API.Endpoint {
pattern := utils.JoinPath("/", serverOptions.APIBasePath, endpoint.Pattern)

Expand Down
81 changes: 58 additions & 23 deletions handler/openapi_validator.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package handler

import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
Expand All @@ -11,15 +14,18 @@ import (
"github.com/getkin/kin-openapi/openapi3filter"

"github.com/avenga/couper/config"
"github.com/avenga/couper/eval"
"github.com/avenga/couper/logging"
)

type OpenAPIValidatorFactory struct {
type OpenAPIValidatorOptions struct {
router *openapi3filter.Router
buffer eval.BufferOption
ignoreRequestViolations bool
ignoreResponseViolations bool
}

func NewOpenAPIValidatorFactory(openapi *config.OpenAPI) (*OpenAPIValidatorFactory, error) {
func NewOpenAPIValidatorOptions(openapi *config.OpenAPI) (*OpenAPIValidatorOptions, error) {
if openapi == nil {
return nil, nil
}
Expand All @@ -28,14 +34,14 @@ func NewOpenAPIValidatorFactory(openapi *config.OpenAPI) (*OpenAPIValidatorFacto
return nil, err
}

bytes, err := ioutil.ReadFile(filepath.Join(dir, openapi.File))
b, err := ioutil.ReadFile(filepath.Join(dir, openapi.File))
if err != nil {
return nil, err
}
return NewOpenAPIValidatorFactoryFromBytes(openapi, bytes)
return NewOpenAPIValidatorOptionsFromBytes(openapi, b)
}

func NewOpenAPIValidatorFactoryFromBytes(openapi *config.OpenAPI, bytes []byte) (*OpenAPIValidatorFactory, error) {
func NewOpenAPIValidatorOptionsFromBytes(openapi *config.OpenAPI, bytes []byte) (*OpenAPIValidatorOptions, error) {
if openapi == nil || bytes == nil {
return nil, nil
}
Expand All @@ -50,30 +56,38 @@ func NewOpenAPIValidatorFactoryFromBytes(openapi *config.OpenAPI, bytes []byte)
return nil, err
}

return &OpenAPIValidatorFactory{
apiValidation := eval.BufferRequest | eval.BufferResponse
if openapi.IgnoreRequestViolations {
apiValidation ^= eval.BufferRequest
}
if openapi.IgnoreResponseViolations {
apiValidation ^= eval.BufferResponse
}

return &OpenAPIValidatorOptions{
buffer: apiValidation,
router: router,
ignoreRequestViolations: openapi.IgnoreRequestViolations,
ignoreResponseViolations: openapi.IgnoreResponseViolations,
}, nil
}

func (f *OpenAPIValidatorFactory) NewOpenAPIValidator() *OpenAPIValidator {
func (o *OpenAPIValidatorOptions) NewOpenAPIValidator() *OpenAPIValidator {
return &OpenAPIValidator{
factory: f,
options: o,
validationCtx: context.Background(),
}
}

type OpenAPIValidator struct {
factory *OpenAPIValidatorFactory
options *OpenAPIValidatorOptions
route *openapi3filter.Route
requestValidationInput *openapi3filter.RequestValidationInput
validationCtx context.Context
Body []byte
}

func (v *OpenAPIValidator) ValidateRequest(req *http.Request) (bool, error) {
route, pathParams, _ := v.factory.router.FindRoute(req.Method, req.URL)
func (v *OpenAPIValidator) ValidateRequest(req *http.Request, tripInfo *logging.RoundtripInfo) error {
route, pathParams, _ := v.options.router.FindRoute(req.Method, req.URL)
v.route = route

v.requestValidationInput = &openapi3filter.RequestValidationInput{
Expand All @@ -82,29 +96,50 @@ func (v *OpenAPIValidator) ValidateRequest(req *http.Request) (bool, error) {
Route: route,
}

return v.factory.ignoreRequestViolations, openapi3filter.ValidateRequest(v.validationCtx, v.requestValidationInput)
err := openapi3filter.ValidateRequest(v.validationCtx, v.requestValidationInput)

if req.GetBody != nil {
req.Body, _ = req.GetBody() // rewind
}

if err != nil {
err = fmt.Errorf("request validation: %w", err)
if !v.options.ignoreRequestViolations {
return err
}
tripInfo.ValidationError = append(tripInfo.ValidationError, err)
}

return nil
}

func (v *OpenAPIValidator) ValidateResponse(res *http.Response) (bool, error) {
func (v *OpenAPIValidator) ValidateResponse(beresp *http.Response, tripInfo *logging.RoundtripInfo) error {
if v.route == nil {
return v.factory.ignoreResponseViolations, nil
return nil
}
responseValidationInput := &openapi3filter.ResponseValidationInput{
RequestValidationInput: v.requestValidationInput,
Status: res.StatusCode,
Header: res.Header,
Status: beresp.StatusCode,
Header: beresp.Header.Clone(),
Options: &openapi3filter.Options{IncludeResponseStatus: true /* undefined response codes are invalid */},
}
body, err := ioutil.ReadAll(res.Body)

buf := &bytes.Buffer{}
_, err := io.Copy(buf, beresp.Body)
if err != nil {
return v.factory.ignoreResponseViolations, err
return err
}
responseValidationInput.SetBodyBytes(body)
v.Body = body
// reset
beresp.Body = eval.NewReadCloser(bytes.NewBuffer(buf.Bytes()), beresp.Body)
responseValidationInput.SetBodyBytes(buf.Bytes())

err = openapi3filter.ValidateResponse(v.validationCtx, responseValidationInput)
if err != nil {
return v.factory.ignoreResponseViolations, err
err = fmt.Errorf("response validation: %w", err)
if !v.options.ignoreResponseViolations {
return err
}
tripInfo.ValidationError = append(tripInfo.ValidationError, err)
}
return v.factory.ignoreResponseViolations, nil
return nil
}
68 changes: 31 additions & 37 deletions handler/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@ func NewCORSOptions(cors *config.CORS) (*CORSOptions, error) {
}
corsMaxAge := strconv.Itoa(int(math.Floor(dur.Seconds())))

allowed_origins := seetie.ValueToStringSlice(cors.AllowedOrigins)
for i, a := range allowed_origins {
allowed_origins[i] = strings.ToLower(a)
allowedOrigins := seetie.ValueToStringSlice(cors.AllowedOrigins)
for i, a := range allowedOrigins {
allowedOrigins[i] = strings.ToLower(a)
}

return &CORSOptions{
AllowedOrigins: allowed_origins,
AllowedOrigins: allowedOrigins,
AllowCredentials: cors.AllowCredentials,
MaxAge: corsMaxAge,
}, nil
Expand All @@ -97,11 +97,16 @@ func (c *CORSOptions) NeedsVary() bool {
}

func (c *CORSOptions) AllowsOrigin(origin string) bool {
if c == nil {
return false
}

for _, a := range c.AllowedOrigins {
if a == strings.ToLower(origin) || a == "*" {
return true
}
}

return false
}

Expand All @@ -110,8 +115,13 @@ func NewProxy(options *ProxyOptions, log *logrus.Entry, srvOpts *server.Options,
logConf.TypeFieldKey = "couper_backend"
env.DecodeWithPrefix(&logConf, "BACKEND_")

var apiValidation eval.BufferOption
if options.OpenAPI != nil {
apiValidation = options.OpenAPI.buffer
}

proxy := &Proxy{
bufferOption: eval.MustBuffer(options.Context),
bufferOption: apiValidation | eval.MustBuffer(options.Context),
evalContext: evalCtx,
log: log,
options: options,
Expand Down Expand Up @@ -225,23 +235,16 @@ func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) {
var openapiValidator *OpenAPIValidator
if p.options.OpenAPI != nil {
openapiValidator = p.options.OpenAPI.NewOpenAPIValidator()
ignoreRequestViolations, err := openapiValidator.ValidateRequest(outreq)
if err != nil {
roundtripInfo.Err = err
if !ignoreRequestViolations {
// TODO: use error template from parent endpoint>api>server
couperErr.DefaultJSON.ServeError(couperErr.UpstreamRequestValidationFailed).ServeHTTP(rw, req)
return
}
if roundtripInfo.Err = openapiValidator.ValidateRequest(outreq, roundtripInfo); roundtripInfo.Err != nil {
p.srvOptions.APIErrTpl.ServeError(couperErr.UpstreamRequestValidationFailed).ServeHTTP(rw, req)
return
}
}

res, err := p.getTransport(outreq.URL.Scheme, outreq.URL.Host, outreq.Host).RoundTrip(outreq)
roundtripInfo.BeReq, roundtripInfo.BeResp = outreq, res
if err != nil {
roundtripInfo.Err = err
}
if err != nil {
p.srvOptions.APIErrTpl.ServeError(couperErr.APIConnect).ServeHTTP(rw, req)
return
}
Expand All @@ -267,24 +270,20 @@ func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) {
res.Body = eval.NewReadCloser(src, res.Body)
}

if openapiValidator != nil {
ignoreResponseViolations, err := openapiValidator.ValidateResponse(res)
if err != nil {
roundtripInfo.Err = err
if !ignoreResponseViolations {
// TODO: use error template from parent endpoint>api>server
couperErr.DefaultJSON.ServeError(couperErr.UpstreamResponseValidationFailed).ServeHTTP(rw, req)
return
}
}
}

removeConnectionHeaders(res.Header)

for _, h := range hopHeaders {
res.Header.Del(h)
}

if openapiValidator != nil {
roundtripInfo.Err = openapiValidator.ValidateResponse(res, roundtripInfo)
if roundtripInfo.Err != nil {
p.srvOptions.APIErrTpl.ServeError(couperErr.UpstreamResponseValidationFailed).ServeHTTP(rw, req)
return
}
}

p.SetRoundtripContext(req, res)

copyHeader(rw.Header(), res.Header)
Expand All @@ -302,19 +301,14 @@ func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) {

rw.WriteHeader(res.StatusCode)

if openapiValidator != nil && openapiValidator.Body != nil {
rw.Write(openapiValidator.Body)
} else {
_, err = io.Copy(rw, res.Body)
if err != nil {
defer res.Body.Close()
roundtripInfo.Err = err
return
}
}
_, roundtripInfo.Err = io.Copy(rw, res.Body)

res.Body.Close() // close now, instead of defer, to populate res.Trailer

if roundtripInfo.Err != nil {
return
}

if len(res.Trailer) > 0 {
// Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short
Expand Down
6 changes: 3 additions & 3 deletions handler/proxy_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type ProxyOptions struct {
Context []hcl.Body
BackendName string
CORS *CORSOptions
OpenAPI *OpenAPIValidatorFactory
OpenAPI *OpenAPIValidatorOptions
RequestBodyLimit int64
}

Expand Down Expand Up @@ -43,7 +43,7 @@ func NewProxyOptions(conf *config.Backend, corsOpts *CORSOptions, remainCtx []hc
cors = &CORSOptions{}
}

openapiValidatorFactory, err := NewOpenAPIValidatorFactory(conf.OpenAPI)
openAPIValidatorOptions, err := NewOpenAPIValidatorOptions(conf.OpenAPI)
if err != nil {
return nil, err
}
Expand All @@ -53,7 +53,7 @@ func NewProxyOptions(conf *config.Backend, corsOpts *CORSOptions, remainCtx []hc
CORS: cors,
ConnectTimeout: connectD,
Context: remainCtx,
OpenAPI: openapiValidatorFactory,
OpenAPI: openAPIValidatorOptions,
RequestBodyLimit: bodyLimit,
TTFBTimeout: ttfbD,
Timeout: totalD,
Expand Down
Loading

0 comments on commit 3235383

Please sign in to comment.