Skip to content

Commit 6c6c550

Browse files
authored
Support resolver schema in Annotated return type (#73)
1 parent 14254a0 commit 6c6c550

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

apischema/graphql/schema.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,10 @@ class ObjectField:
136136
description: Optional[str] = field_(init=False, default=None)
137137

138138
def __post_init__(self, schema: Optional[Schema]):
139+
if get_origin(self.type) == Annotated:
140+
for annotation in reversed(get_args(self.type)[1:]):
141+
if isinstance(annotation, Mapping) and SCHEMA_METADATA in annotation:
142+
schema = merge_schema(annotation[SCHEMA_METADATA], schema)
139143
if schema is not None and schema.annotations is not None:
140144
object.__setattr__(self, "description", schema.annotations.description)
141145
if schema.annotations.deprecated is True:
@@ -203,7 +207,7 @@ def annotated(
203207
raise ValueError("Annotated schema_ref can only be str")
204208
self._ref = self._ref or ref
205209
if isinstance(annotation, Mapping) and SCHEMA_METADATA in annotation:
206-
self._schema = merge_schema(annotation[SCHEMA_METADATA], self._schema)
210+
self._schema = merge_schema(self._schema, annotation[SCHEMA_METADATA])
207211
return self.visit_with_schema(tp, self._ref, self._schema)
208212

209213
def any(self) -> Thunk[graphql.GraphQLType]:
@@ -222,9 +226,7 @@ def _object_field(self, field: Field, field_type: AnyType) -> ObjectField:
222226
alias=get_alias(field),
223227
conversions=get_field_conversions(field, self.operation),
224228
default=graphql.Undefined if is_required(field) else get_default(field),
225-
schema=merge_schema(
226-
annotated_schema(field_type), field.metadata.get(SCHEMA_METADATA)
227-
),
229+
schema=field.metadata.get(SCHEMA_METADATA),
228230
)
229231

230232
def _visit_merged(
@@ -288,7 +290,6 @@ def named_tuple(
288290
field_name,
289291
field_type,
290292
default=defaults.get(field_name, graphql.Undefined),
291-
schema=annotated_schema(field_type),
292293
)
293294
for field_name, field_type in types.items()
294295
]
@@ -449,11 +450,9 @@ def object(
449450
def typed_dict(
450451
self, cls: Type, keys: Mapping[str, AnyType], total: bool
451452
) -> Thunk[graphql.GraphQLType]:
452-
fields = [
453-
ObjectField(name, type, schema=annotated_schema(type))
454-
for name, type in keys.items()
455-
]
456-
return self.object(cls, fields)
453+
return self.object(
454+
cls, [ObjectField(name, type) for name, type in keys.items()]
455+
)
457456

458457
def _union_result(
459458
self, results: Iterable[Thunk[graphql.GraphQLType]]

0 commit comments

Comments
 (0)