From 60764c7748eec168a881d24e23c093ab9b9b0143 Mon Sep 17 00:00:00 2001 From: Johannes Koch Date: Thu, 17 Sep 2020 12:01:51 +0200 Subject: [PATCH] upstream validation: implementation, first try (#21) --- config/backend.go | 15 +++++++ config/runtime/handler.go | 12 +++++ errors/code.go | 10 +++++ handler/proxy.go | 95 ++++++++++++++++++++++++++++++++++++--- 4 files changed, 126 insertions(+), 6 deletions(-) diff --git a/config/backend.go b/config/backend.go index d9e76131b..b30ce6e76 100644 --- a/config/backend.go +++ b/config/backend.go @@ -13,6 +13,9 @@ type Backend struct { Path string `hcl:"path,optional"` Timeout string `hcl:"timeout,optional"` TTFBTimeout string `hcl:"ttfb_timeout,optional"` + SwaggerDef string `hcl:"swagger_definition,optional"` + ValidateReq bool `hcl:"validate_request,optional"` + ValidateRes bool `hcl:"validate_response,optional"` } // Merge overrides the left backend configuration and returns a new instance. @@ -62,5 +65,17 @@ func (b *Backend) Merge(other *Backend) (*Backend, []hcl.Body) { result.TTFBTimeout = other.TTFBTimeout } + if other.SwaggerDef != "" { + result.SwaggerDef = other.SwaggerDef + } + + if other.ValidateReq { + result.ValidateReq = other.ValidateReq + } + + if other.ValidateRes { + result.ValidateRes = other.ValidateRes + } + return &result, bodies } diff --git a/config/runtime/handler.go b/config/runtime/handler.go index af6ebf66a..be867dfad 100644 --- a/config/runtime/handler.go +++ b/config/runtime/handler.go @@ -80,6 +80,9 @@ func BuildEntrypointHandlers(conf *config.Gateway, httpConf *HTTPConfig, log *lo Path: beConf.Path, Timeout: t, TTFBTimeout: ttfbt, + SwaggerDef: beConf.SwaggerDef, + ValidateReq: beConf.ValidateReq, + ValidateRes: beConf.ValidateRes, }, log, conf.Context) if err != nil { log.Fatal(err) @@ -193,6 +196,9 @@ func BuildEntrypointHandlers(conf *config.Gateway, httpConf *HTTPConfig, log *lo Path: beConf.Path, Timeout: t, TTFBTimeout: ttfbt, + SwaggerDef: beConf.SwaggerDef, + ValidateReq: beConf.ValidateReq, + ValidateRes: beConf.ValidateRes, }, log, conf.Context) if err != nil { log.Fatal(err) @@ -262,6 +268,9 @@ func BuildEntrypointHandlers(conf *config.Gateway, httpConf *HTTPConfig, log *lo Path: beConf.Path, Timeout: t, TTFBTimeout: ttfbt, + SwaggerDef: beConf.SwaggerDef, + ValidateReq: beConf.ValidateReq, + ValidateRes: beConf.ValidateRes, }, log, conf.Context) if err != nil { log.Fatal(err) @@ -479,6 +488,9 @@ func newInlineBackend(evalCtx *hcl.EvalContext, inlineDef hcl.Body, cors *config Path: beConf.Path, Timeout: t, TTFBTimeout: ttfbt, + SwaggerDef: beConf.SwaggerDef, + ValidateReq: beConf.ValidateReq, + ValidateRes: beConf.ValidateRes, }, log, evalCtx) return proxy, beConf, err } diff --git a/errors/code.go b/errors/code.go index 3fb110997..6f934bf1c 100644 --- a/errors/code.go +++ b/errors/code.go @@ -30,6 +30,12 @@ const ( BasicAuthFailed ) +const ( + UpstreamRequestValidationFailed Code = 6000 + iota + UpstreamResponseValidationFailed + UpstreamResponseBufferingFailed +) + var codes = map[Code]string{ // 1xxx Server: "Server error", @@ -51,6 +57,10 @@ var codes = map[Code]string{ AuthorizationRequired: "Authorization required", AuthorizationFailed: "Authorization failed", BasicAuthFailed: "Unauthorized", + // 6xxx + UpstreamRequestValidationFailed: "Upstream request validation failed", + UpstreamResponseValidationFailed: "Upstream response validation failed", + UpstreamResponseBufferingFailed: "Upstream response buffering failed", } type Code int diff --git a/handler/proxy.go b/handler/proxy.go index bba6f458e..bd2b1386b 100644 --- a/handler/proxy.go +++ b/handler/proxy.go @@ -6,14 +6,17 @@ import ( "errors" "fmt" "io" + "io/ioutil" "math" "net" "net/http" "net/url" + "os" "strconv" "strings" "time" + "github.com/getkin/kin-openapi/openapi3filter" "github.com/hashicorp/hcl/v2" "github.com/sirupsen/logrus" "golang.org/x/net/http/httpguts" @@ -52,7 +55,8 @@ type ProxyOptions struct { ConnectTimeout, Timeout, TTFBTimeout time.Duration Context []hcl.Body BackendName string - Hostname, Origin, Path string + Hostname, Origin, Path, SwaggerDef string + ValidateReq, ValidateRes bool CORS *CORSOptions } @@ -156,6 +160,44 @@ func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { p.upstreamLog.ServeHTTP(rw, req, logging.RoundtripHandlerFunc(p.roundtrip)) } +func (p *Proxy) preparetRequestValidatation(outreq *http.Request) (context.Context, *openapi3filter.Route, *openapi3filter.RequestValidationInput, error) { + if p.options.ValidateReq || p.options.ValidateRes { + dir, err := os.Getwd() + if err != nil { + return nil, nil, nil, err + } + router := openapi3filter.NewRouter().WithSwaggerFromFile(dir + "/" + p.options.SwaggerDef) + validationCtx := context.Background() + route, pathParams, _ := router.FindRoute(outreq.Method, outreq.URL) + + requestValidationInput := &openapi3filter.RequestValidationInput{ + Request: outreq, + PathParams: pathParams, + Route: route, + } + return validationCtx, route, requestValidationInput, nil + } + return nil, nil, nil, nil +} + +func (p *Proxy) prepareResponseValidatation(requestValidationInput *openapi3filter.RequestValidationInput, res *http.Response) (*openapi3filter.ResponseValidationInput, []byte, error) { + if p.options.ValidateRes { + responseValidationInput := &openapi3filter.ResponseValidationInput{ + RequestValidationInput: requestValidationInput, + Status: res.StatusCode, + Header: res.Header, + Options: &openapi3filter.Options{IncludeResponseStatus: true /* undefined response codes are invalid */}, + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return nil, nil, err + } + responseValidationInput.SetBodyBytes(body) + return responseValidationInput, body, nil + } + return nil, nil, nil +} + func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) { ctx := req.Context() if p.options.Timeout > 0 { @@ -203,6 +245,23 @@ func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) { outreq.Header.Set("X-Forwarded-For", clientIP) } + validationCtx, route, requestValidationInput, err := p.preparetRequestValidatation(outreq) + if err != nil { + // this only happens if os.Getwd() fails + // TODO: use error template from parent endpoint>api>server + p.log.WithField("upstream request validation", err).Error() + couperErr.DefaultJSON.ServeError(couperErr.UpstreamRequestValidationFailed).ServeHTTP(rw, req) + return + } + if (p.options.ValidateReq) { + if err := openapi3filter.ValidateRequest(validationCtx, requestValidationInput); err != nil { + // TODO: use error template from parent endpoint>api>server + p.log.WithField("upstream request validation", err).Error() + couperErr.DefaultJSON.ServeError(couperErr.UpstreamRequestValidationFailed).ServeHTTP(rw, req) + return + } + } + res, err := p.transport.RoundTrip(outreq) roundtripInfo := req.Context().Value(request.RoundtripInfo).(*logging.RoundtripInfo) roundtripInfo.BeReq, roundtripInfo.BeResp, roundtripInfo.Err = outreq, res, err @@ -212,6 +271,26 @@ func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) { return } + responseValidationInput, body, err := p.prepareResponseValidatation(requestValidationInput, res) + if err != nil { + // TODO: use error template from parent endpoint>api>server + p.log.WithField("upstream response validation", err).Error() + couperErr.DefaultJSON.ServeError(couperErr.UpstreamResponseBufferingFailed).ServeHTTP(rw, req) + return + } + if responseValidationInput != nil { + if route != nil { + if err := openapi3filter.ValidateResponse(validationCtx, responseValidationInput); err != nil { + // TODO: use error template from parent endpoint>api>server + p.log.WithField("upstream response validation", err).Error() + couperErr.DefaultJSON.ServeError(couperErr.UpstreamResponseValidationFailed).ServeHTTP(rw, req) + return + } + } else { + p.log.Info("response validation enabled, but no route found") + } + } + // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) if res.StatusCode == http.StatusSwitchingProtocols { p.setRoundtripContext(req, res) @@ -242,11 +321,15 @@ func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(res.StatusCode) - _, err = io.Copy(rw, res.Body) - if err != nil { - defer res.Body.Close() - roundtripInfo.Err = err - return + if body != nil { + rw.Write(body) + } else { + _, err = io.Copy(rw, res.Body) + if err != nil { + defer res.Body.Close() + roundtripInfo.Err = err + return + } } res.Body.Close() // close now, instead of defer, to populate res.Trailer