@@ -201,39 +201,31 @@ def __init__(self, new_bases: list[str]):
201201
202202 def is_call_to_parent_class (self , node : cst .SimpleStatementLine ):
203203 """Check whether `node` corresponds to a call to a parent class function, such as `module.Parent.func_name(...)`"""
204- parent_call_node = m .Call (func = m .Attribute (value = m .Name () | m .Attribute ()))
205- # It can be used as a return, simple expression, or assignment
206- potential_bodies = [m .Return (parent_call_node ) | m .Expr (parent_call_node ) | m .Assign (value = parent_call_node )]
207- return m .matches (node , m .SimpleStatementLine (body = potential_bodies ))
208-
209- def leave_SimpleStatementLine (
210- self , original_node : cst .SimpleStatementLine , updated_node : cst .SimpleStatementLine
211- ) -> cst .SimpleStatementLine :
204+ return m .matches (node , m .Call (func = m .Attribute (value = m .Name () | m .Attribute ())))
205+
206+ def leave_Call (self , original_node : cst .Call , updated_node : cst .Call ) -> cst .Call :
212207 """Replace a call of the form `module.Class.func(...)` by a call of the form `super().func(...)`
213208 if the `Class` being called is one of the bases."""
214209 if self .is_call_to_parent_class (updated_node ):
215- expr_node = updated_node .body [0 ]
216- full_parent_class_name = get_full_attribute_name (expr_node .value .func .value )
210+ full_parent_class_name = get_full_attribute_name (updated_node .func .value )
217211 # Replace only if it's a base, or a few special rules
218212 if (
219213 full_parent_class_name in self .new_bases
220- or ("nn.Module" in full_parent_class_name and self .new_bases == [ "GradientCheckpointingLayer" ] )
214+ or (full_parent_class_name == "nn.Module" and "GradientCheckpointingLayer" in self .new_bases )
221215 or (
222216 full_parent_class_name == "PreTrainedModel"
223217 and any ("PreTrainedModel" in base for base in self .new_bases )
224218 )
225219 ):
226220 # Replace `full_parent_class_name.func(...)` with `super().func(...)`
227- attribute_node = expr_node . value .func .with_changes (value = cst .Call (func = cst .Name ("super" )))
221+ attribute_node = updated_node .func .with_changes (value = cst .Call (func = cst .Name ("super" )))
228222 # Check if the first argument is 'self', and remove it
229223 new_args = (
230- expr_node . value .args [1 :]
231- if len (expr_node . value . args ) > 0 and m .matches (expr_node . value .args [0 ].value , m .Name ("self" ))
232- else expr_node . value .args
224+ updated_node .args [1 :]
225+ if len (updated_node . args ) > 0 and m .matches (updated_node .args [0 ].value , m .Name ("self" ))
226+ else updated_node .args
233227 )
234- call_node = expr_node .value .with_changes (func = attribute_node , args = new_args )
235- new_expr_node = expr_node .with_changes (value = call_node )
236- return updated_node .with_changes (body = [new_expr_node ])
228+ return updated_node .with_changes (func = attribute_node , args = new_args )
237229 return updated_node
238230
239231
0 commit comments