Skip to content

Commit

Permalink
Make the CI lint job pass again (#3691)
Browse files Browse the repository at this point in the history
* Fix variable redefinitions

* Ignore trust issues

* Narrow types of dynamically loaded plugins
  • Loading branch information
DoctorJohn authored Nov 7, 2024
1 parent 184bf28 commit 04936fd
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 18 deletions.
10 changes: 6 additions & 4 deletions strawberry/channels/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,17 @@ async def gql_init(self) -> None:
await self.send_json_to(
ConnectionInitMessage(payload=self.connection_params).as_dict()
)
response = await self.receive_json_from()
assert response == ConnectionAckMessage().as_dict()
graphql_transport_ws_response = await self.receive_json_from()
assert graphql_transport_ws_response == ConnectionAckMessage().as_dict()
else:
assert res == (True, GRAPHQL_WS_PROTOCOL)
await self.send_json_to(
GraphQLWSConnectionInitMessage({"type": "connection_init"})
)
response: GraphQLWSConnectionAckMessage = await self.receive_json_from()
assert response["type"] == "connection_ack"
graphql_ws_response: GraphQLWSConnectionAckMessage = (
await self.receive_json_from()
)
assert graphql_ws_response["type"] == "connection_ack"

# Actual `ExecutionResult`` objects are not available client-side, since they
# get transformed into `FormattedExecutionResult` on the wire, but we attempt
Expand Down
14 changes: 9 additions & 5 deletions strawberry/cli/commands/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import importlib
import inspect
from pathlib import Path # noqa: TCH003
from typing import List, Optional, Type
from typing import List, Optional, Type, Union, cast

import rich
import typer
Expand Down Expand Up @@ -61,7 +61,9 @@ def _import_plugin(plugin: str) -> Optional[Type[QueryCodegenPlugin]]:


@functools.lru_cache
def _load_plugin(plugin_path: str) -> Type[QueryCodegenPlugin]:
def _load_plugin(
plugin_path: str,
) -> Union[Type[QueryCodegenPlugin], Type[ConsolePlugin]]:
# try to import plugin_name from current folder
# then try to import from strawberry.codegen.plugins

Expand All @@ -77,7 +79,9 @@ def _load_plugin(plugin_path: str) -> Type[QueryCodegenPlugin]:
return plugin


def _load_plugins(plugin_ids: List[str], query: Path) -> List[QueryCodegenPlugin]:
def _load_plugins(
plugin_ids: List[str], query: Path
) -> List[Union[QueryCodegenPlugin, ConsolePlugin]]:
plugins = []
for ptype_id in plugin_ids:
ptype = _load_plugin(ptype_id)
Expand Down Expand Up @@ -127,11 +131,11 @@ def codegen(

console_plugin_type = _load_plugin(cli_plugin) if cli_plugin else ConsolePlugin
console_plugin = console_plugin_type(output_dir)
assert isinstance(console_plugin, ConsolePlugin)
console_plugin.before_any_start()

for q in query:
plugins = _load_plugins(selected_plugins, q)
console_plugin.query = q # update the query in the console plugin.
plugins = cast(List[QueryCodegenPlugin], _load_plugins(selected_plugins, q))

code_generator = QueryCodegen(
schema_symbol, plugins=plugins, console_plugin=console_plugin
Expand Down
4 changes: 2 additions & 2 deletions strawberry/experimental/pydantic/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _build_dataclass_creation_fields(

return DataclassCreationFields(
name=field.name,
field_type=field_type,
field_type=field_type, # type: ignore
field=strawberry_field,
)

Expand Down Expand Up @@ -198,7 +198,7 @@ def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]:
all_model_fields = [
DataclassCreationFields(
name=field.name,
field_type=field.type,
field_type=field.type, # type: ignore
field=field,
)
for field in extra_fields + private_fields
Expand Down
15 changes: 8 additions & 7 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,16 +169,17 @@ async def handle_async_results(
await self.websocket.send_json(error_message)
else:
self.subscriptions[operation_id] = agen_or_err

async for result in agen_or_err:
await self.send_data(result, operation_id)
complete_message: CompleteMessage = {
"type": "complete",
"id": operation_id,
}
await self.websocket.send_json(complete_message)

await self.websocket.send_json(
CompleteMessage({"type": "complete", "id": operation_id})
)
except asyncio.CancelledError:
complete_message: CompleteMessage = {"type": "complete", "id": operation_id}
await self.websocket.send_json(complete_message)
await self.websocket.send_json(
CompleteMessage({"type": "complete", "id": operation_id})
)

async def cleanup_operation(self, operation_id: str) -> None:
if operation_id in self.subscriptions:
Expand Down

0 comments on commit 04936fd

Please sign in to comment.