Skip to content

Commit

Permalink
fix(printer): Fix printer ignoring input arguments using snake_case (#…
Browse files Browse the repository at this point in the history
…3780)

* fix(printer): Fix printer ignoring input arguments using snake_case

Fix #3760

* Update strawberry/printer/printer.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Update RELEASE.md

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* one missing propagation

* skip test on gql2 due to different formatting

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
  • Loading branch information
bellini666 and sourcery-ai[bot] authored Feb 13, 2025
1 parent b8bda78 commit 826287f
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 8 deletions.
65 changes: 65 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
Release type: patch

This release fixes an issue where directives with input types using snake_case
would not be printed in the schema.

For example, the following:

```python
@strawberry.input
class FooInput:
hello: str
hello_world: str


@strawberry.schema_directive(locations=[Location.FIELD_DEFINITION])
class FooDirective:
input: FooInput


@strawberry.type
class Query:
@strawberry.field(
directives=[
FooDirective(input=FooInput(hello="hello", hello_world="hello world")),
]
)
def foo(self, info) -> str: ...
```

Would previously print as:

```graphql
directive @fooDirective(
input: FooInput!
optionalInput: FooInput
) on FIELD_DEFINITION

type Query {
foo: String! @fooDirective(input: { hello: "hello" })
}

input FooInput {
hello: String!
hello_world: String!
}
```

Now it will be correctly printed as:

```graphql
directive @fooDirective(
input: FooInput!
optionalInput: FooInput
) on FIELD_DEFINITION

type Query {
foo: String!
@fooDirective(input: { hello: "hello", helloWorld: "hello world" })
}

input FooInput {
hello: String!
hello_world: String!
}
```
51 changes: 43 additions & 8 deletions strawberry/printer/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
TypeVar,
Union,
Expand Down Expand Up @@ -68,40 +69,73 @@ class PrintExtras:


@overload
def _serialize_dataclasses(value: dict[_T, object]) -> dict[_T, object]: ...
def _serialize_dataclasses(
value: dict[_T, object],
*,
name_converter: Callable[[str], str] | None = None,
) -> dict[_T, object]: ...


@overload
def _serialize_dataclasses(
value: Union[list[object], tuple[object]],
*,
name_converter: Callable[[str], str] | None = None,
) -> list[object]: ...


@overload
def _serialize_dataclasses(value: object) -> object: ...
def _serialize_dataclasses(
value: object,
*,
name_converter: Callable[[str], str] | None = None,
) -> object: ...


def _serialize_dataclasses(value):
def _serialize_dataclasses(
value,
*,
name_converter: Callable[[str], str] | None = None,
):
if name_converter is None:
name_converter = lambda x: x # noqa: E731

if dataclasses.is_dataclass(value):
return {k: v for k, v in dataclasses.asdict(value).items() if v is not UNSET} # type: ignore
return {
name_converter(k): v
for k, v in dataclasses.asdict(value).items() # type: ignore
if v is not UNSET
}
if isinstance(value, (list, tuple)):
return [_serialize_dataclasses(v) for v in value]
return [_serialize_dataclasses(v, name_converter=name_converter) for v in value]
if isinstance(value, dict):
return {k: _serialize_dataclasses(v) for k, v in value.items()}
return {
name_converter(k): _serialize_dataclasses(v, name_converter=name_converter)
for k, v in value.items()
}

return value


def print_schema_directive_params(
directive: GraphQLDirective, values: dict[str, Any]
directive: GraphQLDirective,
values: dict[str, Any],
*,
schema: BaseSchema,
) -> str:
params = []
for name, arg in directive.args.items():
value = values.get(name, arg.default_value)
if value is UNSET:
value = None
else:
ast = ast_from_value(_serialize_dataclasses(value), arg.type)
ast = ast_from_value(
_serialize_dataclasses(
value,
name_converter=schema.config.name_converter.apply_naming_config,
),
arg.type,
)
value = ast and f"{name}: {print_ast(ast)}"

if value:
Expand Down Expand Up @@ -129,6 +163,7 @@ def print_schema_directive(
)
for f in strawberry_directive.fields
},
schema=schema,
)

printed_directive = print_directive(gql_directive, schema=schema)
Expand Down
41 changes: 41 additions & 0 deletions tests/test_printer/test_schema_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,3 +749,44 @@ def foo(self, info) -> str: ...
schema = strawberry.Schema(query=Query)

assert print_schema(schema) == textwrap.dedent(expected_output).strip()


@skip_if_gql_32("formatting is different in gql 3.2")
def test_print_directive_with_snake_case_arguments():
@strawberry.input
class FooInput:
hello: str
hello_world: str

@strawberry.schema_directive(locations=[Location.FIELD_DEFINITION])
class FooDirective:
input: FooInput
optional_input: Optional[FooInput] = strawberry.UNSET

@strawberry.type
class Query:
@strawberry.field(
directives=[
FooDirective(input=FooInput(hello="hello", hello_world="hello world"))
]
)
def foo(self, info) -> str: ...

schema = strawberry.Schema(query=Query)

expected_output = """
directive @fooDirective(input: FooInput!, optionalInput: FooInput) on FIELD_DEFINITION
type Query {
foo: String! @fooDirective(input: { hello: "hello", helloWorld: "hello world" })
}
input FooInput {
hello: String!
helloWorld: String!
}
"""

schema = strawberry.Schema(query=Query)

assert print_schema(schema) == textwrap.dedent(expected_output).strip()

0 comments on commit 826287f

Please sign in to comment.