Skip to content

Commit

Permalink
Merge pull request #281 from mirumee/277-proof-of-concept
Browse files Browse the repository at this point in the history
Escape field arg names if they are conflicting with generated client method variables
  • Loading branch information
rafalp authored Mar 4, 2024
2 parents c720473 + 7fcfd12 commit 379ba40
Show file tree
Hide file tree
Showing 3 changed files with 542 additions and 38 deletions.
124 changes: 86 additions & 38 deletions ariadne_codegen/client_generators/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def add_method(
arguments, arguments_dict = self.arguments_generator.generate(
definition.variable_definitions
)

variable_names = self.get_variable_names(arguments)

operation_name = definition.name.value if definition.name else ""
if definition.operation == OperationType.SUBSCRIPTION:
if not async_:
Expand All @@ -149,6 +152,7 @@ def add_method(
arguments=arguments,
arguments_dict=arguments_dict,
operation_str=operation_str,
variable_names=variable_names,
)
)
elif async_:
Expand All @@ -159,6 +163,7 @@ def add_method(
arguments_dict=arguments_dict,
operation_str=operation_str,
operation_name=operation_name,
variable_names=variable_names,
)
else:
method_def = self._generate_method(
Expand All @@ -168,6 +173,7 @@ def add_method(
arguments_dict=arguments_dict,
operation_str=operation_str,
operation_name=operation_name,
variable_names=variable_names,
)

method_def.lineno = len(self._class_def.body) + 1
Expand All @@ -181,6 +187,23 @@ def add_method(
generate_import_from(names=[return_type], from_=return_type_module, level=1)
)

def get_variable_names(self, arguments: ast.arguments) -> Dict[str, str]:
mapped_variable_names = [
self._operation_str_variable,
self._variables_dict_variable,
self._response_variable,
self._data_variable,
]
variable_names = {}
argument_names = set(arg.arg for arg in arguments.args)

for variable in mapped_variable_names:
variable_names[variable] = (
f"_{variable}" if variable in argument_names else variable
)

return variable_names

def _add_import(self, import_: Optional[ast.ImportFrom] = None):
if not import_:
return
Expand All @@ -197,6 +220,7 @@ def _generate_subscription_method_def(
arguments: ast.arguments,
arguments_dict: ast.Dict,
operation_str: str,
variable_names: Dict[str, str],
) -> ast.AsyncFunctionDef:
return generate_async_method_definition(
name=name,
Expand All @@ -205,9 +229,11 @@ def _generate_subscription_method_def(
value=generate_name(ASYNC_ITERATOR), slice_=generate_name(return_type)
),
body=[
self._generate_operation_str_assign(operation_str, 1),
self._generate_variables_assign(arguments_dict, 2),
self._generate_async_generator_loop(operation_name, return_type, 3),
self._generate_operation_str_assign(variable_names, operation_str, 1),
self._generate_variables_assign(variable_names, arguments_dict, 2),
self._generate_async_generator_loop(
variable_names, operation_name, return_type, 3
),
],
)

Expand All @@ -219,17 +245,18 @@ def _generate_async_method(
arguments_dict: ast.Dict,
operation_str: str,
operation_name: str,
variable_names: Dict[str, str],
) -> ast.AsyncFunctionDef:
return generate_async_method_definition(
name=name,
arguments=arguments,
return_type=generate_name(return_type),
body=[
self._generate_operation_str_assign(operation_str, 1),
self._generate_variables_assign(arguments_dict, 2),
self._generate_async_response_assign(operation_name, 3),
self._generate_data_retrieval(),
self._generate_return_parsed_obj(return_type),
self._generate_operation_str_assign(variable_names, operation_str, 1),
self._generate_variables_assign(variable_names, arguments_dict, 2),
self._generate_async_response_assign(variable_names, operation_name, 3),
self._generate_data_retrieval(variable_names),
self._generate_return_parsed_obj(variable_names, return_type),
],
)

Expand All @@ -241,25 +268,26 @@ def _generate_method(
arguments_dict: ast.Dict,
operation_str: str,
operation_name: str,
variable_names: Dict[str, str],
) -> ast.FunctionDef:
return generate_method_definition(
name=name,
arguments=arguments,
return_type=generate_name(return_type),
body=[
self._generate_operation_str_assign(operation_str, 1),
self._generate_variables_assign(arguments_dict, 2),
self._generate_response_assign(operation_name, 3),
self._generate_data_retrieval(),
self._generate_return_parsed_obj(return_type),
self._generate_operation_str_assign(variable_names, operation_str, 1),
self._generate_variables_assign(variable_names, arguments_dict, 2),
self._generate_response_assign(variable_names, operation_name, 3),
self._generate_data_retrieval(variable_names),
self._generate_return_parsed_obj(variable_names, return_type),
],
)

def _generate_operation_str_assign(
self, operation_str: str, lineno: int = 1
self, variable_names: Dict[str, str], operation_str: str, lineno: int = 1
) -> ast.Assign:
return generate_assign(
targets=[self._operation_str_variable],
targets=[variable_names[self._operation_str_variable]],
value=generate_call(
func=generate_name(self._gql_func_name),
args=[
Expand All @@ -270,10 +298,10 @@ def _generate_operation_str_assign(
)

def _generate_variables_assign(
self, arguments_dict: ast.Dict, lineno: int = 1
self, variable_names: Dict[str, str], arguments_dict: ast.Dict, lineno: int = 1
) -> ast.AnnAssign:
return generate_ann_assign(
target=self._variables_dict_variable,
target=variable_names[self._variables_dict_variable],
annotation=generate_subscript(
generate_name(DICT),
generate_tuple([generate_name("str"), generate_name("object")]),
Expand All @@ -283,95 +311,115 @@ def _generate_variables_assign(
)

def _generate_async_response_assign(
self, operation_name: str, lineno: int = 1
self, variable_names: Dict[str, str], operation_name: str, lineno: int = 1
) -> ast.Assign:
return generate_assign(
targets=[self._response_variable],
targets=[variable_names[self._response_variable]],
value=generate_await(
self._generate_execute_call(operation_name=operation_name)
self._generate_execute_call(variable_names, operation_name)
),
lineno=lineno,
)

def _generate_response_assign(
self, operation_name: str, lineno: int = 1
self,
variable_names: Dict[str, str],
operation_name: str,
lineno: int = 1,
) -> ast.Assign:
return generate_assign(
targets=[self._response_variable],
value=self._generate_execute_call(operation_name=operation_name),
targets=[variable_names[self._response_variable]],
value=self._generate_execute_call(variable_names, operation_name),
lineno=lineno,
)

def _generate_execute_call(self, operation_name: str) -> ast.Call:
def _generate_execute_call(
self, variable_names: Dict[str, str], operation_name: str
) -> ast.Call:
return generate_call(
func=generate_attribute(generate_name("self"), "execute"),
keywords=[
generate_keyword(
value=generate_name(self._operation_str_variable), arg="query"
value=generate_name(variable_names[self._operation_str_variable]),
arg="query",
),
generate_keyword(
value=generate_constant(operation_name), arg="operation_name"
),
generate_keyword(
value=generate_name(self._variables_dict_variable), arg="variables"
value=generate_name(variable_names[self._variables_dict_variable]),
arg="variables",
),
generate_keyword(value=generate_name(KWARGS_NAMES)),
],
)

def _generate_data_retrieval(self) -> ast.Assign:
def _generate_data_retrieval(self, variable_names: Dict[str, str]) -> ast.Assign:
return generate_assign(
targets=[self._data_variable],
targets=[variable_names[self._data_variable]],
value=generate_call(
func=generate_attribute(value=generate_name("self"), attr="get_data"),
args=[generate_name(self._response_variable)],
args=[generate_name(variable_names[self._response_variable])],
),
)

def _generate_return_parsed_obj(self, return_type: str) -> ast.Return:
def _generate_return_parsed_obj(
self, variable_names: Dict[str, str], return_type: str
) -> ast.Return:
return generate_return(
generate_call(
func=generate_attribute(
generate_name(return_type), MODEL_VALIDATE_METHOD
),
args=[generate_name(self._data_variable)],
args=[generate_name(variable_names[self._data_variable])],
)
)

def _generate_async_generator_loop(
self, operation_name: str, return_type: str, lineno: int = 1
self,
variable_names: Dict[str, str],
operation_name: str,
return_type: str,
lineno: int = 1,
) -> ast.AsyncFor:
return generate_async_for(
target=generate_name(self._data_variable),
target=generate_name(variable_names[self._data_variable]),
iter_=generate_call(
func=generate_attribute(value=generate_name("self"), attr="execute_ws"),
keywords=[
generate_keyword(
value=generate_name(self._operation_str_variable), arg="query"
value=generate_name(
variable_names[self._operation_str_variable]
),
arg="query",
),
generate_keyword(
value=generate_constant(operation_name), arg="operation_name"
),
generate_keyword(
value=generate_name(self._variables_dict_variable),
value=generate_name(
variable_names[self._variables_dict_variable]
),
arg="variables",
),
generate_keyword(value=generate_name(KWARGS_NAMES)),
],
),
body=[self._generate_yield_parsed_obj(return_type)],
body=[self._generate_yield_parsed_obj(variable_names, return_type)],
lineno=lineno,
)

def _generate_yield_parsed_obj(self, return_type: str) -> ast.Expr:
def _generate_yield_parsed_obj(
self, variable_names: Dict[str, str], return_type: str
) -> ast.Expr:
return generate_expr(
generate_yield(
generate_call(
func=generate_attribute(
value=generate_name(return_type),
attr=MODEL_VALIDATE_METHOD,
),
args=[generate_name(self._data_variable)],
args=[generate_name(variable_names[self._data_variable])],
)
)
)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ disable = [
"duplicate-code",
"no-name-in-module",
"too-many-locals",
"too-many-lines",
]

[tool.pytest.ini_options]
Expand Down
Loading

0 comments on commit 379ba40

Please sign in to comment.