Skip to content

Commit c08d12a

Browse files
committed
Remove redundant type_to_visitor_function entries
1 parent 3872d57 commit c08d12a

File tree

2 files changed

+33
-42
lines changed

2 files changed

+33
-42
lines changed

crates/ruff_python_ast/generate.py

Lines changed: 31 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,23 @@
4040
}
4141

4242

43+
@dataclass
44+
class VisitorInfo:
45+
name: str
46+
accepts_sequence: bool = False
47+
48+
49+
# Map of AST node types to their corresponding visitor information.
50+
# Only visitors that are different from the default `visit_*` method are included.
51+
# These visitors either have a different name or accept a sequence of items.
52+
type_to_visitor_function: dict[str, VisitorInfo] = {
53+
"TypeParams": VisitorInfo("visit_type_params", True),
54+
"Parameters": VisitorInfo("visit_parameters", True),
55+
"Stmt": VisitorInfo("visit_body", True),
56+
"Arguments": VisitorInfo("visit_arguments", True),
57+
}
58+
59+
4360
def rustfmt(code: str) -> str:
4461
return check_output(["rustfmt", "--emit=stdout"], input=code, text=True)
4562

@@ -202,6 +219,7 @@ def extract_type_argument(rust_type_str: str) -> str:
202219
if close_bracket_index == -1 or close_bracket_index <= open_bracket_index:
203220
raise ValueError(f"Brackets are not balanced for type {rust_type_str}")
204221
inner_type = rust_type_str[open_bracket_index + 1 : close_bracket_index].strip()
222+
inner_type = inner_type.replace("crate::", "")
205223
return inner_type
206224

207225

@@ -766,39 +784,6 @@ def write_node(out: list[str], ast: Ast) -> None:
766784
# Source order visitor
767785

768786

769-
@dataclass
770-
class VisitorInfo:
771-
name: str
772-
accepts_sequence: bool = False
773-
774-
775-
# Map of AST node types to their corresponding visitor information
776-
type_to_visitor_function: dict[str, VisitorInfo] = {
777-
"Decorator": VisitorInfo("visit_decorator"),
778-
"Identifier": VisitorInfo("visit_identifier"),
779-
"crate::TypeParams": VisitorInfo("visit_type_params", True),
780-
"crate::Parameters": VisitorInfo("visit_parameters", True),
781-
"Expr": VisitorInfo("visit_expr"),
782-
"Stmt": VisitorInfo("visit_body", True),
783-
"Arguments": VisitorInfo("visit_arguments", True),
784-
"crate::Arguments": VisitorInfo("visit_arguments", True),
785-
"Operator": VisitorInfo("visit_operator"),
786-
"ElifElseClause": VisitorInfo("visit_elif_else_clause"),
787-
"WithItem": VisitorInfo("visit_with_item"),
788-
"MatchCase": VisitorInfo("visit_match_case"),
789-
"ExceptHandler": VisitorInfo("visit_except_handler"),
790-
"Alias": VisitorInfo("visit_alias"),
791-
"UnaryOp": VisitorInfo("visit_unary_op"),
792-
"DictItem": VisitorInfo("visit_dict_item"),
793-
"Comprehension": VisitorInfo("visit_comprehension"),
794-
"CmpOp": VisitorInfo("visit_cmp_op"),
795-
"FStringValue": VisitorInfo("visit_f_string_value"),
796-
"StringLiteralValue": VisitorInfo("visit_string_literal"),
797-
"BytesLiteralValue": VisitorInfo("visit_bytes_literal"),
798-
}
799-
annotation_visitor_function = VisitorInfo("visit_annotation")
800-
801-
802787
def write_source_order(out: list[str], ast: Ast) -> None:
803788
for group in ast.groups:
804789
for node in group.nodes:
@@ -816,24 +801,30 @@ def write_source_order(out: list[str], ast: Ast) -> None:
816801
fields_list += "range: _,\n"
817802

818803
for field in node.fields_in_source_order():
819-
visitor = type_to_visitor_function[field.parsed_ty.inner]
820-
if field.is_annotation:
821-
visitor = annotation_visitor_function
804+
visitor_name = (
805+
type_to_visitor_function.get(
806+
field.parsed_ty.inner, VisitorInfo("")
807+
).name
808+
or f"visit_{to_snake_case(field.parsed_ty.inner)}"
809+
)
810+
visits_sequence = type_to_visitor_function.get(
811+
field.parsed_ty.inner, VisitorInfo("")
812+
).accepts_sequence
822813

823814
if field.parsed_ty.optional:
824815
body += f"""
825816
if let Some({field.name}) = {field.name} {{
826-
visitor.{visitor.name}({field.name});
817+
visitor.{visitor_name}({field.name});
827818
}}\n
828819
"""
829-
elif not visitor.accepts_sequence and field.parsed_ty.seq:
820+
elif not visits_sequence and field.parsed_ty.seq:
830821
body += f"""
831822
for elm in {field.name} {{
832-
visitor.{visitor.name}(elm);
823+
visitor.{visitor_name}(elm);
833824
}}
834825
"""
835826
else:
836-
body += f"visitor.{visitor.name}({field.name});\n"
827+
body += f"visitor.{visitor_name}({field.name});\n"
837828

838829
visitor_arg_name = "visitor"
839830
if len(node.fields_in_source_order()) == 0:

crates/ruff_python_ast/src/generated.rs

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)