Skip to content

Commit 47e878e

Browse files
committed
more general
1 parent 17b822f commit 47e878e

File tree

2 files changed

+11
-19
lines changed

2 files changed

+11
-19
lines changed

src/transformers/models/sam2/image_processing_sam2_fast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ def _preprocess(
510510
return_tensors: Optional[Union[str, TensorType]],
511511
**kwargs,
512512
) -> "torch.Tensor":
513-
return BaseImageProcessorFast._preprocess(self, images, return_tensors=return_tensors, **kwargs).pixel_values
513+
return super()._preprocess(images, return_tensors=return_tensors, **kwargs).pixel_values
514514

515515
def generate_crop_boxes(
516516
self,

utils/modular_model_converter.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -201,39 +201,31 @@ def __init__(self, new_bases: list[str]):
201201

202202
def is_call_to_parent_class(self, node: cst.SimpleStatementLine):
203203
"""Check whether `node` corresponds to a call to a parent class function, such as `module.Parent.func_name(...)`"""
204-
parent_call_node = m.Call(func=m.Attribute(value=m.Name() | m.Attribute()))
205-
# It can be used as a return, simple expression, or assignment
206-
potential_bodies = [m.Return(parent_call_node) | m.Expr(parent_call_node) | m.Assign(value=parent_call_node)]
207-
return m.matches(node, m.SimpleStatementLine(body=potential_bodies))
208-
209-
def leave_SimpleStatementLine(
210-
self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
211-
) -> cst.SimpleStatementLine:
204+
return m.matches(node, m.Call(func=m.Attribute(value=m.Name() | m.Attribute())))
205+
206+
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
212207
"""Replace a call of the form `module.Class.func(...)` by a call of the form `super().func(...)`
213208
if the `Class` being called is one of the bases."""
214209
if self.is_call_to_parent_class(updated_node):
215-
expr_node = updated_node.body[0]
216-
full_parent_class_name = get_full_attribute_name(expr_node.value.func.value)
210+
full_parent_class_name = get_full_attribute_name(updated_node.func.value)
217211
# Replace only if it's a base, or a few special rules
218212
if (
219213
full_parent_class_name in self.new_bases
220-
or ("nn.Module" in full_parent_class_name and self.new_bases == ["GradientCheckpointingLayer"])
214+
or (full_parent_class_name == "nn.Module" and "GradientCheckpointingLayer" in self.new_bases)
221215
or (
222216
full_parent_class_name == "PreTrainedModel"
223217
and any("PreTrainedModel" in base for base in self.new_bases)
224218
)
225219
):
226220
# Replace `full_parent_class_name.func(...)` with `super().func(...)`
227-
attribute_node = expr_node.value.func.with_changes(value=cst.Call(func=cst.Name("super")))
221+
attribute_node = updated_node.func.with_changes(value=cst.Call(func=cst.Name("super")))
228222
# Check if the first argument is 'self', and remove it
229223
new_args = (
230-
expr_node.value.args[1:]
231-
if len(expr_node.value.args) > 0 and m.matches(expr_node.value.args[0].value, m.Name("self"))
232-
else expr_node.value.args
224+
updated_node.args[1:]
225+
if len(updated_node.args) > 0 and m.matches(updated_node.args[0].value, m.Name("self"))
226+
else updated_node.args
233227
)
234-
call_node = expr_node.value.with_changes(func=attribute_node, args=new_args)
235-
new_expr_node = expr_node.with_changes(value=call_node)
236-
return updated_node.with_changes(body=[new_expr_node])
228+
return updated_node.with_changes(func=attribute_node, args=new_args)
237229
return updated_node
238230

239231

0 commit comments

Comments
 (0)