@@ -190,8 +190,60 @@ def get_full_attribute_name(node: Union[cst.Attribute, cst.Name]) -> Optional[st
190190 return None
191191
192192
193- class SuperTransformer (cst .CSTTransformer ):
194- METADATA_DEPENDENCIES = (ParentNodeProvider ,)
193+ class ReplaceParentClassCallTransformer (cst .CSTTransformer ):
194+ """
195+ This Transformer is used to replace all calls of the form `module.Class.func(...)` by a call of the form
196+ `super().func(...)`.
197+ """
198+
199+ def __init__ (self , new_bases : list [str ]):
200+ self .new_bases = new_bases
201+
202+ def is_call_to_parent_class (self , node : cst .SimpleStatementLine ):
203+ """Check whether `node` corresponds to a call to a parent class function, such as `module.Parent.func_name(...)`"""
204+ parent_call_node = m .Call (func = m .Attribute (value = m .Name () | m .Attribute ()))
205+ # It can be used as a return, simple expression, or assignment
206+ potential_bodies = [m .Return (parent_call_node ) | m .Expr (parent_call_node ) | m .Assign (value = parent_call_node )]
207+ return m .matches (node , m .SimpleStatementLine (body = potential_bodies ))
208+
209+ def leave_SimpleStatementLine (
210+ self , original_node : cst .SimpleStatementLine , updated_node : cst .SimpleStatementLine
211+ ) -> cst .SimpleStatementLine :
212+ """Replace a call of the form `module.Class.func(...)` by a call of the form `super().func(...)`
213+ if the `Class` being called is one of the bases."""
214+ if self .is_call_to_parent_class (updated_node ):
215+ expr_node = updated_node .body [0 ]
216+ full_parent_class_name = get_full_attribute_name (expr_node .value .func .value )
217+ # Replace only if it's a base, or a few special rules
218+ if (
219+ full_parent_class_name in self .new_bases
220+ or ("nn.Module" in full_parent_class_name and self .new_bases == ["GradientCheckpointingLayer" ])
221+ or (
222+ full_parent_class_name == "PreTrainedModel"
223+ and any ("PreTrainedModel" in base for base in self .new_bases )
224+ )
225+ ):
226+ # Replace `full_parent_class_name.func(...)` with `super().func(...)`
227+ attribute_node = expr_node .value .func .with_changes (value = cst .Call (func = cst .Name ("super" )))
228+ # Check if the first argument is 'self', and remove it
229+ new_args = (
230+ expr_node .value .args [1 :]
231+ if len (expr_node .value .args ) > 0 and m .matches (expr_node .value .args [0 ].value , m .Name ("self" ))
232+ else expr_node .value .args
233+ )
234+ call_node = expr_node .value .with_changes (func = attribute_node , args = new_args )
235+ new_expr_node = expr_node .with_changes (value = call_node )
236+ return updated_node .with_changes (body = [new_expr_node ])
237+ return updated_node
238+
239+
240+ class ReplaceSuperCallTransformer (cst .CSTTransformer ):
241+ """
242+ This Transformer is used to unravel all calls to `super().func(...)` in class methods by the explicit parent's
243+ code. It will also in turn replace all calls of the form `module.Class.func(...)` by a call of the form
244+ `super().func(...)`. Those calls are used to explicitly skip the unravelling of code, but we should still follow
245+ python's standards and use `super().func(...)` instead of `Parent.func(self, ...)`.
246+ """
195247
196248 def __init__ (
197249 self ,
@@ -205,7 +257,8 @@ def __init__(
205257 self .modular_methods = modular_methods
206258 self .all_assign_target = {}
207259 self .deleted_targets = {} # child node can delete some arguments
208- self .new_bases = [get_full_attribute_name (base .value ) for base in new_bases ]
260+ new_bases = [get_full_attribute_name (base .value ) for base in new_bases ]
261+ self .parent_class_call_transformer = ReplaceParentClassCallTransformer (new_bases )
209262
210263 def update_body (self , existing_body , new_statements ):
211264 """
@@ -283,45 +336,14 @@ def _fix_init_location(self, new_body):
283336 break
284337 return new_body
285338
286- def replace_parent_class_call (self , node : cst .SimpleStatementLine ) -> cst .SimpleStatementLine :
287- """Replace a call of the form `module.Class.func(...)` by a call of the form `super().func(...)`
288- if the `Class` being called is one of the bases."""
289- expr_node = node .body [0 ]
290- full_parent_class_name = get_full_attribute_name (expr_node .value .func .value )
291- # Replace only if it's a base, or if using nn.Module on a GradientCheckpointingLayer
292- if (
293- full_parent_class_name in self .new_bases
294- or ("nn.Module" in full_parent_class_name and self .new_bases == ["GradientCheckpointingLayer" ])
295- or (
296- full_parent_class_name == "PreTrainedModel"
297- and any ("PreTrainedModel" in base for base in self .new_bases )
298- )
299- ):
300- # Replace `full_parent_class_name.func(...)` with `super().func(...)`
301- attribute_node = expr_node .value .func .with_changes (value = cst .Call (func = cst .Name ("super" )))
302- call_node = expr_node .value .with_changes (func = attribute_node )
303- # Check if the first argument is 'self', and remove it
304- new_args = (
305- call_node .args [1 :]
306- if len (call_node .args ) > 0 and m .matches (call_node .args [0 ].value , m .Name ("self" ))
307- else call_node .args
308- )
309- new_expr_node = expr_node .with_changes (value = call_node .with_changes (args = new_args ))
310- return node .with_changes (body = [new_expr_node ])
311- return node
312-
313339 def is_call_to_super (self , node : cst .BaseStatement , func_name : str ):
314340 """Check whether `node` corresponds to a call to `super().func_name(...)`"""
315341 super_call_node = m .Call (func = m .Attribute (value = m .Call (func = m .Name ("super" )), attr = m .Name (func_name )))
316342 return m .matches (node , m .SimpleStatementLine (body = [m .Return (super_call_node ) | m .Expr (super_call_node )]))
317343
318- def is_call_to_parent_class (self , node : cst .BaseStatement ):
319- """Check whether `node` corresponds to a call to a parent class function, such as `module.Parent.func_name(...)`"""
320- parent_call_node = m .Call (func = m .Attribute (value = m .Name () | m .Attribute ()))
321- return m .matches (node , m .SimpleStatementLine (body = [m .Return (parent_call_node ) | m .Expr (parent_call_node )]))
322-
323344 def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef :
324345 func_name = updated_node .name .value
346+ self .should_check_statements = False
325347 if func_name in self .modular_methods :
326348 actual_body = updated_node .body .body # first body is an `IndentedBlock` wrapper
327349 new_body = []
@@ -332,11 +354,9 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
332354 new_body = self ._fix_init_location (new_body )
333355 # Break here as all future statement were already accounted for in `update_body`
334356 break
335- elif self .is_call_to_parent_class (base_statement_node ):
336- new_body .append (self .replace_parent_class_call (base_statement_node ))
337- else :
338- new_body .append (base_statement_node )
339-
357+ # If not a call to super, this will replace all calls of the form `module.Class.func(...)` by a
358+ # call of the form `super().func(...)
359+ new_body .append (base_statement_node .visit (self .parent_class_call_transformer ))
340360 return updated_node .with_changes (body = updated_node .body .with_changes (body = new_body ))
341361 return updated_node
342362
@@ -1028,9 +1048,8 @@ def replace_class_node(
10281048 # Replace the calls to `super()` of the redefined modular methods with the unrolled code
10291049 result_node = original_modeling_node .with_changes (body = cst .IndentedBlock (body = new_class_body ))
10301050 temp_module = cst .Module (body = [result_node ])
1031- new_module = MetadataWrapper (temp_module )
1032- new_replacement_class = new_module .visit (
1033- SuperTransformer (temp_module , original_modeling_methods , modular_methods , new_class_bases )
1051+ new_replacement_class = temp_module .visit (
1052+ ReplaceSuperCallTransformer (temp_module , original_modeling_methods , modular_methods , new_class_bases )
10341053 )
10351054 new_class_body = new_replacement_class .body [0 ].body # get the indented block
10361055
0 commit comments