diff --git a/src/labthings/apispec/plugins.py b/src/labthings/apispec/plugins.py index 6a384d3..4339fdf 100644 --- a/src/labthings/apispec/plugins.py +++ b/src/labthings/apispec/plugins.py @@ -112,11 +112,10 @@ def spec_for_interaction(cls, interaction): ) return d - @classmethod - def spec_for_property(cls, prop): - class_schema = ensure_schema(prop.schema) or {} + def spec_for_property(self, prop): + class_schema = ensure_schema(self.spec, prop.schema) or {} - d = cls.spec_for_interaction(prop) + d = self.spec_for_interaction(prop) # Add in writeproperty methods for method in ("put", "post"): @@ -155,9 +154,11 @@ def spec_for_property(cls, prop): return d def spec_for_action(self, action): - action_input = ensure_schema(action.args, name=f"{action.__name__}InputSchema") + action_input = ensure_schema( + self.spec, action.args, name=f"{action.__name__}InputSchema" + ) action_output = ensure_schema( - action.schema, name=f"{action.__name__}OutputSchema" + self.spec, action.schema, name=f"{action.__name__}OutputSchema" ) # We combine input/output parameters with ActionSchema using an # allOf directive, so we don't end up duplicating the schema diff --git a/src/labthings/apispec/utilities.py b/src/labthings/apispec/utilities.py index 6752119..7d42733 100644 --- a/src/labthings/apispec/utilities.py +++ b/src/labthings/apispec/utilities.py @@ -1,21 +1,27 @@ from inspect import isclass from typing import Dict, Type, Union, cast +from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin -from apispec.ext.marshmallow.field_converter import FieldConverterMixin from marshmallow import Schema from .. import fields -def field2property(field): - """Convert a marshmallow Field to OpenAPI dictionary""" - converter = FieldConverterMixin() - converter.init_attribute_functions() - return converter.field2property(field) +def field2property(spec: APISpec, field: fields.Field): + """Convert a marshmallow Field to OpenAPI dictionary + + We require an initialised APISpec object to use its + converter function - in particular, this will depend + on the OpenAPI version defined in `spec`. We also rely + on the spec having a `MarshmallowPlugin` attached. + """ + plugin = get_marshmallow_plugin(spec) + return plugin.converter.field2property(field) def ensure_schema( + spec: APISpec, schema: Union[ fields.Field, Type[fields.Field], @@ -34,11 +40,14 @@ def ensure_schema( Other Schemas are returned as Marshmallow Schema instances, which will be converted to references by the plugin. + + The first argument must be an initialised APISpec object, as the conversion + of single fields to dictionaries is version-dependent. """ if schema is None: return None if isinstance(schema, fields.Field): - return field2property(schema) + return field2property(spec, schema) elif isinstance(schema, dict): return Schema.from_dict(schema, name=name)() elif isinstance(schema, Schema): @@ -46,7 +55,7 @@ def ensure_schema( if isclass(schema): schema = cast(Type, schema) if issubclass(schema, fields.Field): - return field2property(schema()) + return field2property(spec, schema()) elif issubclass(schema, Schema): return schema() raise TypeError( diff --git a/tests/conftest.py b/tests/conftest.py index 8a2a5e0..b698039 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -213,6 +213,17 @@ def post(self, args): thing.add_view(TestFieldProperty, "/TestFieldProperty") + class TestNullableFieldProperty(PropertyView): + schema = fields.Integer(allow_none=True) + + def get(self): + return "one" + + def post(self, args): + pass + + thing.add_view(TestNullableFieldProperty, "/TestNullableFieldProperty") + class FailAction(ActionView): wait_for = 0.1 diff --git a/tests/test_openapi.py b/tests/test_openapi.py index 71786f5..5a81a5e 100644 --- a/tests/test_openapi.py +++ b/tests/test_openapi.py @@ -54,43 +54,49 @@ def post(self): assert original_input_schema != modified_input_schema -def test_ensure_schema_field_instance(): - ret = utilities.ensure_schema(fields.Integer()) +def test_ensure_schema_field_instance(spec): + ret = utilities.ensure_schema(spec, fields.Integer()) assert ret == {"type": "integer"} -def test_ensure_schema_field_class(): - ret = utilities.ensure_schema(fields.Integer) +def test_ensure_schema_nullable_field_instance(spec): + ret = utilities.ensure_schema(spec, fields.Integer(allow_none=True)) + assert ret == {"type": "integer", "nullable": True} + + +def test_ensure_schema_field_class(spec): + ret = utilities.ensure_schema(spec, fields.Integer) assert ret == {"type": "integer"} -def test_ensure_schema_class(): - ret = utilities.ensure_schema(LogRecordSchema) +def test_ensure_schema_class(spec): + ret = utilities.ensure_schema(spec, LogRecordSchema) assert isinstance(ret, Schema) -def test_ensure_schema_instance(): - ret = utilities.ensure_schema(LogRecordSchema()) +def test_ensure_schema_instance(spec): + ret = utilities.ensure_schema(spec, LogRecordSchema()) assert isinstance(ret, Schema) -def test_ensure_schema_dict(): +def test_ensure_schema_dict(spec): ret = utilities.ensure_schema( + spec, { "count": fields.Integer(), "name": fields.String(), - } + }, ) assert isinstance(ret, Schema) -def test_ensure_schema_none(): - assert utilities.ensure_schema(None) is None +def test_ensure_schema_none(spec): + assert utilities.ensure_schema(spec, None) is None -def test_ensure_schema_error(): +def test_ensure_schema_error(spec): with pytest.raises(TypeError): - utilities.ensure_schema(Exception) + utilities.ensure_schema(spec, Exception) def test_get_marshmallow_plugin(spec):