Skip to content

Commit

Permalink
cache modified swagger
Browse files Browse the repository at this point in the history
  • Loading branch information
Johannes Koch committed May 14, 2021
1 parent 85adeb0 commit 7f3fa49
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 4 deletions.
2 changes: 1 addition & 1 deletion handler/transport/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (b *Backend) RoundTrip(req *http.Request) (*http.Response, error) {
}

if b.openAPIValidator != nil {
if err = b.openAPIValidator.ValidateRequest(req); err != nil {
if err = b.openAPIValidator.ValidateRequest(req, tc.hash(), tc.Origin); err != nil {
return nil, errors.BackendValidation.Label(b.name).With(err)
}
}
Expand Down
50 changes: 47 additions & 3 deletions handler/validation/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"io"
"io/ioutil"
"net/http"
"net/url"
"sync"

"github.com/avenga/couper/config/request"

Expand All @@ -15,6 +17,8 @@ import (
"github.com/avenga/couper/eval"
)

var swaggers sync.Map

type OpenAPI struct {
options *OpenAPIOptions
requestValidationInput *openapi3filter.RequestValidationInput
Expand All @@ -29,18 +33,58 @@ func NewOpenAPI(opts *OpenAPIOptions) *OpenAPI {
}
}

func (v *OpenAPI) getModifiedSwagger(key, origin string) (*openapi3.Swagger, error) {
swagger, exists := swaggers.Load(key)
if !exists {
clonedSwagger := cloneSwagger(v.options.swagger)

var newServers []string
for _, s := range clonedSwagger.Servers {
su, err := url.Parse(s.URL)
if err != nil {
return nil, err
}
if !su.IsAbs() {
newServers = append(newServers, origin+s.URL)
}
}
for _, ns := range newServers {
clonedSwagger.AddServer(&openapi3.Server{URL: ns})
}

swaggers.Store(key, clonedSwagger)
swagger = clonedSwagger
}

if s, ok := swagger.(*openapi3.Swagger); ok {
return s, nil
}

err := fmt.Errorf("request validation: swagger wrong type: %v", swagger)
return nil, err
}

func cloneSwagger(s *openapi3.Swagger) *openapi3.Swagger {
sw := *s
// this is not a deep clone; we only want to add servers
sw.Servers = s.Servers[:]
return &sw
}

func (v *OpenAPI) ValidateRequest(req *http.Request) error {
clonedSwagger := cloneSwagger(v.options.swagger)
func (v *OpenAPI) ValidateRequest(req *http.Request, key, origin string) error {
swagger, err := v.getModifiedSwagger(key, origin)
if err != nil {
if ctx, ok := req.Context().Value(request.OpenAPI).(*OpenAPIContext); ok {
ctx.errors = append(ctx.errors, err)
}
if !v.options.ignoreRequestViolations {
return err
}
return nil
}

router := openapi3filter.NewRouter()
if err := router.AddSwagger(clonedSwagger); err != nil {
if err = router.AddSwagger(swagger); err != nil {
return err
}

Expand Down

0 comments on commit 7f3fa49

Please sign in to comment.