Skip to content

Commit

Permalink
extracted OpenAPI validator (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
Johannes Koch authored and Marcel Ludwig committed Dec 4, 2020
1 parent 95767fc commit 810e9ed
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 92 deletions.
84 changes: 84 additions & 0 deletions handler/openapi_validator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package handler

import (
"context"
"io/ioutil"
"net/http"
"os"

"github.com/avenga/couper/config"
"github.com/getkin/kin-openapi/openapi3filter"
)

type OpenAPIValidatorFactory struct {
router *openapi3filter.Router
ignoreRequestViolations bool
ignoreResponseViolations bool
}

func NewOpenAPIValidatorFactory(openapi *config.OpenAPI) (*OpenAPIValidatorFactory, error) {
if openapi == nil {
return nil, nil
}
dir, err := os.Getwd()
if err != nil {
return nil, err
}
return &OpenAPIValidatorFactory{
router: openapi3filter.NewRouter().WithSwaggerFromFile(dir + "/" + openapi.File),
ignoreRequestViolations: openapi.IgnoreRequestViolations,
ignoreResponseViolations: openapi.IgnoreResponseViolations,
}, nil
}

func (f *OpenAPIValidatorFactory) NewOpenAPIValidator() *OpenAPIValidator {
return &OpenAPIValidator{
factory: f,
validationCtx: context.Background(),
}
}

type OpenAPIValidator struct {
factory *OpenAPIValidatorFactory
route *openapi3filter.Route
requestValidationInput *openapi3filter.RequestValidationInput
validationCtx context.Context
Body []byte
}

func (v *OpenAPIValidator) ValidateRequest(req *http.Request) (bool, error) {
route, pathParams, _ := v.factory.router.FindRoute(req.Method, req.URL)
v.route = route

v.requestValidationInput = &openapi3filter.RequestValidationInput{
Request: req,
PathParams: pathParams,
Route: route,
}

return v.factory.ignoreRequestViolations, openapi3filter.ValidateRequest(v.validationCtx, v.requestValidationInput)
}

func (v *OpenAPIValidator) ValidateResponse(res *http.Response) (bool, error) {
if v.route == nil {
return v.factory.ignoreResponseViolations, nil
}
responseValidationInput := &openapi3filter.ResponseValidationInput{
RequestValidationInput: v.requestValidationInput,
Status: res.StatusCode,
Header: res.Header,
Options: &openapi3filter.Options{IncludeResponseStatus: true /* undefined response codes are invalid */},
}
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return v.factory.ignoreResponseViolations, err
}
responseValidationInput.SetBodyBytes(body)
v.Body = body

err = openapi3filter.ValidateResponse(v.validationCtx, responseValidationInput)
if err != nil {
return v.factory.ignoreResponseViolations, err
}
return v.factory.ignoreResponseViolations, nil
}
97 changes: 14 additions & 83 deletions handler/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"math"
"net"
"net/http"
Expand All @@ -19,7 +18,6 @@ import (
"sync"
"time"

"github.com/getkin/kin-openapi/openapi3filter"
"github.com/hashicorp/hcl/v2"
"github.com/sirupsen/logrus"
"golang.org/x/net/http/httpguts"
Expand Down Expand Up @@ -108,23 +106,6 @@ func (c *CORSOptions) AllowsOrigin(origin string) bool {
return false
}

type OpenAPIOptions struct {
File string
IgnoreRequestViolations bool
IgnoreResponseViolations bool
}

func NewOpenAPIOptions(openapi *config.OpenAPI) *OpenAPIOptions {
if openapi == nil {
return nil
}
return &OpenAPIOptions{
File: openapi.File,
IgnoreRequestViolations: openapi.IgnoreRequestViolations,
IgnoreResponseViolations: openapi.IgnoreResponseViolations,
}
}

func NewProxy(options *ProxyOptions, log *logrus.Entry, srvOpts *server.Options, evalCtx *hcl.EvalContext) (http.Handler, error) {
logConf := *logging.DefaultConfig
logConf.TypeFieldKey = "couper_backend"
Expand Down Expand Up @@ -188,44 +169,6 @@ func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
p.upstreamLog.ServeHTTP(rw, req, logging.RoundtripHandlerFunc(p.roundtrip), startTime)
}

func (p *Proxy) prepareRequestValidation(outreq *http.Request) (context.Context, *openapi3filter.Route, *openapi3filter.RequestValidationInput, error) {
if p.options.OpenAPI != nil {
dir, err := os.Getwd()
if err != nil {
return nil, nil, nil, err
}
router := openapi3filter.NewRouter().WithSwaggerFromFile(dir + "/" + p.options.OpenAPI.File)
validationCtx := context.Background()
route, pathParams, _ := router.FindRoute(outreq.Method, outreq.URL)

requestValidationInput := &openapi3filter.RequestValidationInput{
Request: outreq,
PathParams: pathParams,
Route: route,
}
return validationCtx, route, requestValidationInput, nil
}
return nil, nil, nil, nil
}

func (p *Proxy) prepareResponseValidation(requestValidationInput *openapi3filter.RequestValidationInput, res *http.Response) (*openapi3filter.ResponseValidationInput, []byte, error) {
if p.options.OpenAPI != nil {
responseValidationInput := &openapi3filter.ResponseValidationInput{
RequestValidationInput: requestValidationInput,
Status: res.StatusCode,
Header: res.Header,
Options: &openapi3filter.Options{IncludeResponseStatus: true /* undefined response codes are invalid */},
}
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, nil, err
}
responseValidationInput.SetBodyBytes(body)
return responseValidationInput, body, nil
}
return nil, nil, nil
}

func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) {
ctx := req.Context()
if p.options.Timeout > 0 {
Expand Down Expand Up @@ -280,19 +223,14 @@ func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) {

roundtripInfo := req.Context().Value(request.RoundtripInfo).(*logging.RoundtripInfo)

validationCtx, route, requestValidationInput, err := p.prepareRequestValidation(outreq)
if err != nil {
// this only happens if os.Getwd() fails
// TODO: use error template from parent endpoint>api>server
roundtripInfo.Err = err
couperErr.DefaultJSON.ServeError(couperErr.UpstreamRequestValidationFailed).ServeHTTP(rw, req)
return
}
if requestValidationInput != nil {
if err := openapi3filter.ValidateRequest(validationCtx, requestValidationInput); err != nil {
// TODO: use error template from parent endpoint>api>server
var openapiValidator *OpenAPIValidator
if p.options.OpenAPI != nil {
openapiValidator = p.options.OpenAPI.NewOpenAPIValidator()
ignoreRequestViolations, err := openapiValidator.ValidateRequest(outreq)
if err != nil {
roundtripInfo.Err = err
if !p.options.OpenAPI.IgnoreRequestViolations {
if !ignoreRequestViolations {
// TODO: use error template from parent endpoint>api>server
couperErr.DefaultJSON.ServeError(couperErr.UpstreamRequestValidationFailed).ServeHTTP(rw, req)
return
}
Expand Down Expand Up @@ -330,19 +268,12 @@ func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) {
res.Body = eval.NewReadCloser(src, res.Body)
}

responseValidationInput, body, err := p.prepareResponseValidation(requestValidationInput, res)
if err != nil {
// this only happens if response body buffering fails
// TODO: use error template from parent endpoint>api>server
roundtripInfo.Err = err
couperErr.DefaultJSON.ServeError(couperErr.UpstreamResponseBufferingFailed).ServeHTTP(rw, req)
return
}
if responseValidationInput != nil && route != nil {
if err := openapi3filter.ValidateResponse(validationCtx, responseValidationInput); err != nil {
// TODO: use error template from parent endpoint>api>server
if openapiValidator != nil {
ignoreResponseViolations, err := openapiValidator.ValidateResponse(res)
if err != nil {
roundtripInfo.Err = err
if !p.options.OpenAPI.IgnoreResponseViolations {
if !ignoreResponseViolations {
// TODO: use error template from parent endpoint>api>server
couperErr.DefaultJSON.ServeError(couperErr.UpstreamResponseValidationFailed).ServeHTTP(rw, req)
return
}
Expand Down Expand Up @@ -372,8 +303,8 @@ func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) {

rw.WriteHeader(res.StatusCode)

if body != nil {
rw.Write(body)
if openapiValidator != nil && openapiValidator.Body != nil {
rw.Write(openapiValidator.Body)
} else {
_, err = io.Copy(rw, res.Body)
if err != nil {
Expand Down
9 changes: 7 additions & 2 deletions handler/proxy_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type ProxyOptions struct {
Context []hcl.Body
BackendName string
CORS *CORSOptions
OpenAPI *OpenAPIOptions
OpenAPI *OpenAPIValidatorFactory
RequestBodyLimit int64
}

Expand Down Expand Up @@ -43,12 +43,17 @@ func NewProxyOptions(conf *config.Backend, corsOpts *CORSOptions, remainCtx []hc
cors = &CORSOptions{}
}

openapiValidatorFactory, err := NewOpenAPIValidatorFactory(conf.OpenAPI)
if err != nil {
return nil, err
}

return &ProxyOptions{
BackendName: conf.Name,
CORS: cors,
ConnectTimeout: connectD,
Context: remainCtx,
OpenAPI: NewOpenAPIOptions(conf.OpenAPI),
OpenAPI: openapiValidatorFactory,
RequestBodyLimit: bodyLimit,
TTFBTimeout: ttfbD,
Timeout: totalD,
Expand Down
19 changes: 12 additions & 7 deletions handler/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/sirupsen/logrus"
logrustest "github.com/sirupsen/logrus/hooks/test"

"github.com/avenga/couper/config"
"github.com/avenga/couper/config/request"
"github.com/avenga/couper/errors"
"github.com/avenga/couper/eval"
Expand Down Expand Up @@ -430,47 +431,47 @@ paths:

tests := []struct {
name string
openapiOptions *handler.OpenAPIOptions
openapi *config.OpenAPI
requestMethod string
requestPath string
expectedStatusCode int
expectedLogMessage string
}{
{
"valid request / valid response",
&handler.OpenAPIOptions{File: "testdata/upstream.yaml"},
&config.OpenAPI{File: "testdata/upstream.yaml"},
http.MethodGet,
"/get",
http.StatusOK,
"",
},
{
"invalid request",
&handler.OpenAPIOptions{File: "testdata/upstream.yaml"},
&config.OpenAPI{File: "testdata/upstream.yaml"},
http.MethodPost,
"/get",
http.StatusBadRequest,
"invalid route",
},
{
"invalid request, IgnoreRequestViolations",
&handler.OpenAPIOptions{File: "testdata/upstream.yaml", IgnoreRequestViolations: true},
&config.OpenAPI{File: "testdata/upstream.yaml", IgnoreRequestViolations: true},
http.MethodPost,
"/get",
http.StatusOK,
"invalid route",
},
{
"invalid response",
&handler.OpenAPIOptions{File: "testdata/upstream.yaml"},
&config.OpenAPI{File: "testdata/upstream.yaml"},
http.MethodGet,
"/get?404",
http.StatusBadGateway,
"status is not supported",
},
{
"invalid response, IgnoreResponseViolations",
&handler.OpenAPIOptions{File: "testdata/upstream.yaml", IgnoreResponseViolations: true},
&config.OpenAPI{File: "testdata/upstream.yaml", IgnoreResponseViolations: true},
http.MethodGet,
"/get?404",
http.StatusNotFound,
Expand All @@ -481,7 +482,11 @@ paths:
for _, tt := range tests {
t.Run(tt.name, func(subT *testing.T) {
logger, hook := logrustest.NewNullLogger()
p, err := handler.NewProxy(&handler.ProxyOptions{Origin: origin.URL, OpenAPI: tt.openapiOptions}, logger.WithContext(context.Background()), nil, eval.NewENVContext(nil))
openapiValidatorFactory, err := handler.NewOpenAPIValidatorFactory(tt.openapi)
if err != nil {
subT.Fatal(err)
}
p, err := handler.NewProxy(&handler.ProxyOptions{Origin: origin.URL, OpenAPI: openapiValidatorFactory}, logger.WithContext(context.Background()), nil, eval.NewENVContext(nil))
if err != nil {
subT.Fatal(err)
}
Expand Down

0 comments on commit 810e9ed

Please sign in to comment.