Skip to content

Commit

Permalink
simplify spec reference resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniel Grossmann-Kavanagh committed Jun 18, 2018
1 parent e86c6ec commit f8458c4
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 29 deletions.
41 changes: 14 additions & 27 deletions connexion/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .decorators.validation import (ParameterValidator, RequestBodyValidator,
TypeValidationError)
from .exceptions import InvalidSpecification
from .utils import all_json, is_nullable
from .utils import all_json, deep_get, is_nullable

logger = logging.getLogger('connexion.operation')

Expand Down Expand Up @@ -218,19 +218,21 @@ def __init__(self, api, method, path, operation, resolver, app_produces, app_con
# openapi3
self.components = components or {}

def component_get(oas3_name):
return self.components.get(oas3_name, {})
component_get = lambda oas3_name: self.components.get(oas3_name, {})

self.security_definitions = self.security_definitions or component_get('securitySchemes')
self.parameter_definitions = self.parameter_definitions or component_get('parameters')
self.response_definitions = self.response_definitions or component_get('responses')

self.definitions_map = {
'components.schemas': self.components.get('schemas', {}),
'components.requestBodies': self.components.get('requestBodies', {}),
'components.parameters': self.components.get('parameters', {}),
'components.securitySchemes': self.components.get('securitySchemes', {}),
'components.responses': self.components.get('responses', {}),
'components': {
'schemas': self.components.get('schemas', {}),
'requestBodies': self.components.get('requestBodies', {}),
'parameters': self.components.get('parameters', {}),
'securitySchemes': self.components.get('securitySchemes', {}),
'responses': self.components.get('responses', {}),
'headers': self.components.get('headers', {}),
},
'definitions': self.definitions,
'parameters': self.parameter_definitions,
'responses': self.response_definitions
Expand Down Expand Up @@ -362,28 +364,13 @@ def _retrieve_reference(self, reference):
if not reference.startswith('#/'):
raise InvalidSpecification(
"{method} {path} '$ref' needs to start with '#/'".format(**vars(self)))
path = reference.split('/')
definition_type = ".".join(path[1:-1])
path = reference[2:].split('/')
try:
definitions = self.definitions_map[definition_type]
definition = deep_get(self.definitions_map, path)
except KeyError:
ref_possible = ', '.join(self.definitions_map.keys())
raise InvalidSpecification(
"{method} {path} $ref \"{reference}\" needs to point to one of: "
"{ref_possible}".format(
method=self.method,
path=self.path,
reference=reference,
ref_possible=ref_possible
))
definition_name = path[-1]
try:
# Get sub definition
definition = deepcopy(definitions[definition_name])
except KeyError:
raise InvalidSpecification(
"{method} {path} Definition '{definition_name}' not found".format(
definition_name=definition_name, method=self.method, path=self.path))
"{method} {path} $ref '{reference}' not found".format(
reference=reference, method=self.method, path=self.path))

return definition

Expand Down
9 changes: 9 additions & 0 deletions connexion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ def deep_getattr(obj, attr):
return functools.reduce(getattr, attr.split('.'), obj)


def deep_get(obj, keys):
"""
Recurses through a nested object get a leaf value.
"""
if not keys:
return obj
return deep_get(obj[keys[0]], keys[1:])


def get_function_from_name(function_name):
"""
Tries to get function by fully qualified name (e.g. "mymodule.myobj.myfunc")
Expand Down
4 changes: 2 additions & 2 deletions tests/test_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,8 @@ def test_non_existent_reference(api):
operation.body_schema

exception = exc_info.value
assert str(exception) == "<InvalidSpecification: GET endpoint Definition 'new_stack' not found>"
assert repr(exception) == "<InvalidSpecification: GET endpoint Definition 'new_stack' not found>"
assert str(exception).startswith("<InvalidSpecification: GET endpoint $ref")
assert repr(exception).startswith("<InvalidSpecification: GET endpoint $ref")


def test_multi_body(api):
Expand Down

0 comments on commit f8458c4

Please sign in to comment.