Skip to content

Commit

Permalink
fixup rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniel Grossmann-Kavanagh committed Jul 25, 2018
1 parent 4e1c4bf commit 975567a
Show file tree
Hide file tree
Showing 12 changed files with 182 additions and 700 deletions.
21 changes: 8 additions & 13 deletions connexion/apis/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,24 +86,17 @@ def __init__(self, specification, base_path=None, arguments=None,

self.spec_version = self._get_spec_version(self.specification)

self.options = ConnexionOptions(old_style_options, oas_version=self.spec_version)
# options is added last to preserve the highest priority
self.options = self.options.extend(options)

# TODO: Remove this in later versions (Current version is 1.1.9)
if base_path is None and 'base_url' in old_style_options:
base_path = old_style_options['base_url']
logger.warning("Parameter base_url should be no longer used. Use base_path instead.")
self.options = ConnexionOptions(options, oas_version=self.spec_version)

logger.debug('Options Loaded',
extra={'swagger_ui': self.options.openapi_console_ui_available,
'swagger_path': self.options.openapi_console_ui_from_dir,
'swagger_url': self.options.openapi_console_ui_path})

# Avoid validator having ability to modify specification
spec = copy.deepcopy(self.specification)
self._validate_spec(spec)
self.specification = resolve_refs(spec)
self.raw_spec = copy.deepcopy(self.specification)
self._validate_spec(self.specification)
self.specification = resolve_refs(self.specification)

# https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#fixed-fields
# If base_path is not on provided then we try to read it from the swagger.yaml or use / by default
Expand Down Expand Up @@ -229,7 +222,8 @@ def add_operation(self, method, path, swagger_operation, path_parameters):
strict_validation=self.strict_validation,
resolver=self.resolver,
pythonic_params=self.pythonic_params,
uri_parser_class=self.options.uri_parser_class)
uri_parser_class=self.options.uri_parser_class,
pass_context_arg_name=self.pass_context_arg_name)
else:
operation = OpenAPIOperation(self,
method=method,
Expand All @@ -243,7 +237,8 @@ def add_operation(self, method, path, swagger_operation, path_parameters):
strict_validation=self.strict_validation,
resolver=self.resolver,
pythonic_params=self.pythonic_params,
uri_parser_class=self.options.uri_parser_class)
uri_parser_class=self.options.uri_parser_class,
pass_context_arg_name=self.pass_context_arg_name)

self._add_operation_internal(method, path, operation)

Expand Down
2 changes: 1 addition & 1 deletion connexion/apis/aiohttp_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _get_openapi_json(self, req):
return web.Response(
status=200,
content_type='application/json',
body=self.jsonifier.dumps(self.specification)
body=self.jsonifier.dumps(self.raw_spec)
)

def add_swagger_ui(self):
Expand Down
2 changes: 1 addition & 1 deletion connexion/apis/flask_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def add_openapi_json(self):
endpoint_name = "{name}_openapi_json".format(name=self.blueprint.name)
self.blueprint.add_url_rule(self.options.openapi_spec_path,
endpoint_name,
lambda: flask.jsonify(self.specification))
lambda: flask.jsonify(self.raw_spec))

def add_swagger_ui(self):
"""
Expand Down
78 changes: 11 additions & 67 deletions connexion/operations/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from ..decorators.produces import BaseSerializer, Produces
from ..decorators.response import ResponseValidator
from ..decorators.validation import ParameterValidator, RequestBodyValidator
from ..exceptions import InvalidSpecification
from ..utils import all_json, deep_get, is_nullable
from ..utils import all_json, is_nullable

logger = logging.getLogger('connexion.operations.abstract')

Expand Down Expand Up @@ -54,7 +53,8 @@ def __init__(self, api, method, path, operation, resolver,
app_security=None, security_schemes=None,
validate_responses=False, strict_validation=False,
randomize_endpoint=None, validator_map=None,
pythonic_params=False, uri_parser_class=None):
pythonic_params=False, uri_parser_class=None,
pass_context_arg_name=None):
"""
"""
self._api = api
Expand All @@ -68,6 +68,7 @@ def __init__(self, api, method, path, operation, resolver,
self._strict_validation = strict_validation
self._pythonic_params = pythonic_params
self._uri_parser_class = uri_parser_class
self.pass_context_arg_name = pass_context_arg_name
self._randomize_endpoint = randomize_endpoint

self._router_controller = self._operation.get('x-swagger-router-controller')
Expand Down Expand Up @@ -187,18 +188,6 @@ def consumes(self):
"""
return []

@abc.abstractproperty
def _spec_definitions(self):
"""
a nested dictionary that is used by _resolve_reference.
It contains the definitions referenced in the spec.
for example, a spec with "#/components/schemas/Banana"
would have a definitions map that looked like:
{"components": {"schemas": {"Banana": {...}}}}
"""
return {}

@abc.abstractproperty
def body_schema(self):
"""
Expand Down Expand Up @@ -253,57 +242,9 @@ def _validate_defaults(self):
"""

@abc.abstractmethod
def _resolve_reference(self, schema):
"""
replaces schema references like "#/components/schemas/MySchema"
with the contents of that reference.
relies on self._components to be a nested dictionary with the
definitions for all of the components.
See helper methods _check_references and _retrieve_reference
"""

def _check_references(self, schema):
"""
Searches the keys and values of a schema object for json references.
If it finds one, it attempts to locate it and will thrown an exception
if the reference can't be found in the definitions dictionary.
:param schema: The schema object to check
:type schema: dict
:raises InvalidSpecification: raised when a reference isn't found
"""
stack = [schema]
visited = set()
while stack:
schema = stack.pop()
for k, v in schema.items():
if k == "$ref":
if v in visited:
continue
visited.add(v)
stack.append(self._retrieve_reference(v))
elif isinstance(v, (list, tuple)):
continue
elif hasattr(v, "items"):
stack.append(v)

def _retrieve_reference(self, reference):
if not reference.startswith('#/'):
raise InvalidSpecification(
"{method} {path} '$ref' needs to start with '#/'".format(
method=self.method,
path=self.path))
path = reference[2:].split('/')
try:
definition = deep_get(self._spec_definitions, path)
except KeyError:
raise InvalidSpecification(
"{method} {path} $ref '{reference}' not found".format(
reference=reference, method=self.method, path=self.path))

return definition
def with_definitions(self, schema):
"""
"""

def get_mimetype(self):
"""
Expand Down Expand Up @@ -333,7 +274,10 @@ def function(self):
:rtype: types.FunctionType
"""
function = parameter_to_arg(self, self._resolution.function, self.pythonic_params)
function = parameter_to_arg(
self, self._resolution.function, self.pythonic_params,
self.pass_context_arg_name
)
function = self._request_begin_lifecycle_decorator(function)

if self.validate_responses:
Expand Down
102 changes: 18 additions & 84 deletions connexion/operations/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class OpenAPIOperation(AbstractOperation):
def __init__(self, api, method, path, operation, resolver, path_parameters=None,
app_security=None, components=None, validate_responses=False,
strict_validation=False, randomize_endpoint=None, validator_map=None,
pythonic_params=False, uri_parser_class=None):
pythonic_params=False, uri_parser_class=None, pass_context_arg_name=None):
"""
This class uses the OperationID identify the module and function that will handle the operation
Expand Down Expand Up @@ -60,6 +60,9 @@ def __init__(self, api, method, path, operation, resolver, path_parameters=None,
:type pythonic_params: bool
:param uri_parser_class: class to use for uri parseing
:type uri_parser_class: AbstractURIParser
:param pass_context_arg_name: If not None will try to inject the request context to the function using this
name.
:type pass_context_arg_name: str|None
"""
self.components = components or {}

Expand All @@ -84,7 +87,8 @@ def component_get(oas3_name):
randomize_endpoint=randomize_endpoint,
validator_map=validator_map,
pythonic_params=pythonic_params,
uri_parser_class=uri_parser_class
uri_parser_class=uri_parser_class,
pass_context_arg_name=pass_context_arg_name
)

self._definitions_map = {
Expand All @@ -102,47 +106,12 @@ def component_get(oas3_name):
# todo support definition references
# todo support references to application level parameters
self._request_body = operation.get('requestBody')
if self._request_body:
self._request_body = self._resolve_reference(self._request_body)

def resolve_parameters(parameters):
return [self._resolve_reference(p) for p in parameters]

self.parameters = resolve_parameters(operation.get('parameters', []))
self.parameters = operation.get('parameters', [])
if path_parameters:
self.parameters += resolve_parameters(path_parameters)

def resolve_responses(responses):
if not responses:
return responses
responses = deepcopy(responses)
for status_code, resp in responses.items():
# check components/responses
if '$ref' in resp:
ref = self._resolve_reference(resp)
del resp['$ref']
resp = ref

content = resp.get("content", {})
for mimetype, resp in content.items():
# check components/examples
examples = resp.get("examples", {})
examples = {k: self._resolve_reference(v)
for k, v in examples.items()}

example = resp.get("example", {})
ref = self._resolve_reference(example)
if ref:
resp["example"] = ref

schema = resp.get("schema", {})
ref = self._resolve_reference(schema)
if ref:
resp["schema"] = ref

return responses

self._responses = resolve_responses(operation.get('responses', {}))
self.parameters += path_parameters

self._responses = operation.get('responses', {})

# TODO figure out how to support multiple mimetypes
# NOTE we currently just combine all of the possible mimetypes,
Expand Down Expand Up @@ -191,59 +160,24 @@ def _validate_defaults(self):
' type \'{param_type}\''.format(param_name=param_defn['name'],
param_type=param_schema['type']))

def _resolve_reference(self, schema):
schema = deepcopy(schema) # avoid changing the original schema
self._check_references(schema)

# find the object we need to resolve/update if this is not a proper SchemaObject
# e.g a response or parameter object
for obj in schema, schema.get('items'):
reference = obj and obj.get('$ref') # type: str
if reference:
break
if reference:
definition = deepcopy(self._retrieve_reference(reference))
# Update schema
obj.update(definition)
del obj['$ref']

# if the schema includes allOf or oneOf or anyOf
for multi in ['allOf', 'anyOf', 'oneOf']:
upd = []
for s in schema.get(multi, []):
upd.append(self._resolve_reference(s))
if upd:
schema[multi] = upd

# additionalProperties
try:
ap = schema['additionalProperties']
if ap:
schema['additionalProperties'] = self._resolve_reference(ap)
except KeyError:
pass

# if there is a schema object on this param or response, then we just
# need to include the defs and it can be validated by jsonschema
if "$ref" in schema.get("schema", {}):
if self.components:
schema['schema']['components'] = self.components
return schema

def with_definitions(self, schema):
if self.components:
schema['schema']['components'] = self.components
return schema

def response_definition(self, status_code=None, content_type=None):
content_type = content_type or self.get_mimetype()
response_definitions = self._responses
response_definition = response_definitions.get(str(status_code), response_definitions.get("default", {}))
response_definition = self._resolve_reference(response_definition)
return response_definition

def response_schema(self, status_code=None, content_type=None):
response_definition = self.response_definition(status_code, content_type)
content_definition = response_definition.get("content", response_definition)
content_definition = content_definition.get(content_type, content_definition)
return self._resolve_reference(content_definition.get("schema", {}))
if "schema" in content_definition:
return self.with_definitions(content_definition).get("schema", {})
return {}

def example_response(self, code=None, content_type=None):
"""
Expand Down Expand Up @@ -296,7 +230,7 @@ def body_schema(self):
"""
The body schema definition for this operation.
"""
return self._resolve_reference(self.body_definition.get('schema', {}))
return self.body_definition.get('schema', {})

@property
def body_definition(self):
Expand All @@ -313,7 +247,7 @@ def body_definition(self):
'this operation accepts multiple content types, using %s',
self.consumes[0])
res = self._request_body.get('content', {}).get(self.consumes[0], {})
return self._resolve_reference(res)
return self.with_definitions(res)
return {}

def _get_body_argument(self, body, arguments, has_kwargs):
Expand Down
Loading

0 comments on commit 975567a

Please sign in to comment.