4040}
4141
4242
43+ @dataclass
44+ class VisitorInfo :
45+ name : str
46+ accepts_sequence : bool = False
47+
48+
49+ # Map of AST node types to their corresponding visitor information.
50+ # Only visitors that are different from the default `visit_*` method are included.
51+ # These visitors either have a different name or accept a sequence of items.
52+ type_to_visitor_function : dict [str , VisitorInfo ] = {
53+ "TypeParams" : VisitorInfo ("visit_type_params" , True ),
54+ "Parameters" : VisitorInfo ("visit_parameters" , True ),
55+ "Stmt" : VisitorInfo ("visit_body" , True ),
56+ "Arguments" : VisitorInfo ("visit_arguments" , True ),
57+ }
58+
59+
4360def rustfmt (code : str ) -> str :
4461 return check_output (["rustfmt" , "--emit=stdout" ], input = code , text = True )
4562
@@ -202,6 +219,7 @@ def extract_type_argument(rust_type_str: str) -> str:
202219 if close_bracket_index == - 1 or close_bracket_index <= open_bracket_index :
203220 raise ValueError (f"Brackets are not balanced for type { rust_type_str } " )
204221 inner_type = rust_type_str [open_bracket_index + 1 : close_bracket_index ].strip ()
222+ inner_type = inner_type .replace ("crate::" , "" )
205223 return inner_type
206224
207225
@@ -766,39 +784,6 @@ def write_node(out: list[str], ast: Ast) -> None:
766784# Source order visitor
767785
768786
769- @dataclass
770- class VisitorInfo :
771- name : str
772- accepts_sequence : bool = False
773-
774-
775- # Map of AST node types to their corresponding visitor information
776- type_to_visitor_function : dict [str , VisitorInfo ] = {
777- "Decorator" : VisitorInfo ("visit_decorator" ),
778- "Identifier" : VisitorInfo ("visit_identifier" ),
779- "crate::TypeParams" : VisitorInfo ("visit_type_params" , True ),
780- "crate::Parameters" : VisitorInfo ("visit_parameters" , True ),
781- "Expr" : VisitorInfo ("visit_expr" ),
782- "Stmt" : VisitorInfo ("visit_body" , True ),
783- "Arguments" : VisitorInfo ("visit_arguments" , True ),
784- "crate::Arguments" : VisitorInfo ("visit_arguments" , True ),
785- "Operator" : VisitorInfo ("visit_operator" ),
786- "ElifElseClause" : VisitorInfo ("visit_elif_else_clause" ),
787- "WithItem" : VisitorInfo ("visit_with_item" ),
788- "MatchCase" : VisitorInfo ("visit_match_case" ),
789- "ExceptHandler" : VisitorInfo ("visit_except_handler" ),
790- "Alias" : VisitorInfo ("visit_alias" ),
791- "UnaryOp" : VisitorInfo ("visit_unary_op" ),
792- "DictItem" : VisitorInfo ("visit_dict_item" ),
793- "Comprehension" : VisitorInfo ("visit_comprehension" ),
794- "CmpOp" : VisitorInfo ("visit_cmp_op" ),
795- "FStringValue" : VisitorInfo ("visit_f_string_value" ),
796- "StringLiteralValue" : VisitorInfo ("visit_string_literal" ),
797- "BytesLiteralValue" : VisitorInfo ("visit_bytes_literal" ),
798- }
799- annotation_visitor_function = VisitorInfo ("visit_annotation" )
800-
801-
802787def write_source_order (out : list [str ], ast : Ast ) -> None :
803788 for group in ast .groups :
804789 for node in group .nodes :
@@ -816,24 +801,30 @@ def write_source_order(out: list[str], ast: Ast) -> None:
816801 fields_list += "range: _,\n "
817802
818803 for field in node .fields_in_source_order ():
819- visitor = type_to_visitor_function [field .parsed_ty .inner ]
820- if field .is_annotation :
821- visitor = annotation_visitor_function
804+ visitor_name = (
805+ type_to_visitor_function .get (
806+ field .parsed_ty .inner , VisitorInfo ("" )
807+ ).name
808+ or f"visit_{ to_snake_case (field .parsed_ty .inner )} "
809+ )
810+ visits_sequence = type_to_visitor_function .get (
811+ field .parsed_ty .inner , VisitorInfo ("" )
812+ ).accepts_sequence
822813
823814 if field .parsed_ty .optional :
824815 body += f"""
825816 if let Some({ field .name } ) = { field .name } {{
826- visitor.{ visitor . name } ({ field .name } );
817+ visitor.{ visitor_name } ({ field .name } );
827818 }}\n
828819 """
829- elif not visitor . accepts_sequence and field .parsed_ty .seq :
820+ elif not visits_sequence and field .parsed_ty .seq :
830821 body += f"""
831822 for elm in { field .name } {{
832- visitor.{ visitor . name } (elm);
823+ visitor.{ visitor_name } (elm);
833824 }}
834825 """
835826 else :
836- body += f"visitor.{ visitor . name } ({ field .name } );\n "
827+ body += f"visitor.{ visitor_name } ({ field .name } );\n "
837828
838829 visitor_arg_name = "visitor"
839830 if len (node .fields_in_source_order ()) == 0 :
0 commit comments