Commit 17b822f

fix assignment as well

1 parent 3368392 commit 17b822f

7 files changed: +74 additions, -62 deletions
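The pattern touched across the model files below can be summarized with a small standalone sketch (hypothetical Parent/Child classes, not code from the repository):

class Parent:
    def prepare_inputs_for_generation(self, *args, **kwargs):
        return {"args": args, **kwargs}


class Child(Parent):
    def prepare_inputs_for_generation(self, *args, **kwargs):
        # Before: explicit parent-class call, passing `self` by hand; the commit title
        # ("fix assignment as well") refers to this form appearing on the right-hand
        # side of an assignment:
        #   model_inputs = Parent.prepare_inputs_for_generation(self, *args, **kwargs)
        # After: cooperative call; `self` is bound implicitly and the MRO picks the parent.
        model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
        return model_inputs


assert Child().prepare_inputs_for_generation(1, beam=2) == {"args": (1,), "beam": 2}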

src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py (5 additions & 6 deletions)

@@ -1165,7 +1165,7 @@ def forward(
         )
 
     def _prepare_generation_config(self, *args, **kwargs):
-        generation_config, model_kwargs = GenerationMixin._prepare_generation_config(self, *args, **kwargs)
+        generation_config, model_kwargs = super()._prepare_generation_config(*args, **kwargs)
         # this should be passed to the model kwargs for the input preparation
         model_kwargs["audio_window_size"] = (
             generation_config.audio_window_size if hasattr(generation_config, "audio_window_size") else None
@@ -1178,8 +1178,7 @@ def _prepare_model_inputs(
         bos_token_id: Optional[torch.Tensor] = None,
         model_kwargs: Optional[dict[str, torch.Tensor]] = None,
     ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
-        inputs, input_name, model_kwargs = GenerationMixin._prepare_model_inputs(
-            self,
+        inputs, input_name, model_kwargs = super()._prepare_model_inputs(
             inputs=inputs,
             bos_token_id=bos_token_id,
             model_kwargs=model_kwargs,
@@ -1264,7 +1263,7 @@ def prepare_inputs_for_generation(
         padding_cache: Optional[KyutaiSpeechToTextConv1dPaddingCache] = None,
         **kwargs,
     ):
-        model_inputs = GenerationMixin.prepare_inputs_for_generation(self, *args, **kwargs)
+        model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
 
         if input_values is not None:
             cache_position = model_inputs["cache_position"]
@@ -1311,9 +1310,9 @@ def prepare_inputs_for_generation(
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
         if kwargs.get("output_loading_info", False):
-            model, loading_info = PreTrainedModel.from_pretrained(self, *args, **kwargs)
+            model, loading_info = super().from_pretrained(*args, **kwargs)
         else:
-            model = PreTrainedModel.from_pretrained(self, *args, **kwargs)
+            model = super().from_pretrained(*args, **kwargs)
 
         # copy depth decoder generation conf attr to the depth decoder generation config
         prefix = "codec_"

src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py (3 additions & 3 deletions)

@@ -251,7 +251,7 @@ def __init__(self, config):
         self.embed_tokens = KyutaiSpeechToTextEmbeddings(config)
 
 
-class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMixin, PreTrainedModel):
+class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMixin):
     _keep_in_fp32_modules_strict = ["codec_model"]
 
     def __init__(self, config):
@@ -445,9 +445,9 @@ def prepare_inputs_for_generation(
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
         if kwargs.get("output_loading_info", False):
-            model, loading_info = PreTrainedModel.from_pretrained(self, *args, **kwargs)
+            model, loading_info = PreTrainedModel.from_pretrained(*args, **kwargs)
         else:
-            model = PreTrainedModel.from_pretrained(self, *args, **kwargs)
+            model = PreTrainedModel.from_pretrained(*args, **kwargs)
 
         # copy depth decoder generation conf attr to the depth decoder generation config
         prefix = "codec_"

src/transformers/models/phi3/modeling_phi3.py (1 addition & 2 deletions)

@@ -516,8 +516,7 @@ def prepare_inputs_for_generation(
         if past_length <= self.config.original_max_position_embeddings:
             past_key_values = None
 
-        model_inputs = Phi3PreTrainedModel.prepare_inputs_for_generation(
-            self,
+        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,

src/transformers/models/phimoe/modeling_phimoe.py (1 addition & 1 deletion)

@@ -1334,7 +1334,7 @@ def prepare_inputs_for_generation(
         if past_length <= self.config.original_max_position_embeddings:
             past_key_values = None
 
-        model_inputs = Phi3PreTrainedModel().prepare_inputs_for_generation(
+        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
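The phimoe hunk above also removes an accidental instantiation: Phi3PreTrainedModel().prepare_inputs_for_generation(...) built a fresh object instead of dispatching on the current model. A hedged, self-contained illustration of why that matters (toy classes, not the real ones):

class ToyParent:
    def __init__(self, window=1):
        self.window = window

    def prepare(self, length):
        return length * self.window


class ToyChild(ToyParent):
    def prepare_buggy(self, length):
        # Instantiates a brand-new parent (default window=1) and ignores self's state.
        return ToyParent().prepare(length)

    def prepare_fixed(self, length):
        # Dispatches on this instance via the MRO, keeping self.window.
        return super().prepare(length)


child = ToyChild(window=3)
assert child.prepare_buggy(2) == 2   # state silently dropped
assert child.prepare_fixed(2) == 6   # correct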

src/transformers/models/t5gemma/modeling_t5gemma.py (1 addition & 1 deletion)

@@ -43,7 +43,7 @@
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
 from ...utils.deprecation import deprecate_kwarg
-from ..utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs
 from .configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig
 
 
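The one-character change above fixes the relative-import depth: from a module under models/t5gemma/, two dots resolve to the models package while three dots resolve to the package root, where utils actually lives. A rough, self-contained illustration with a throwaway package layout (names are made up, not the transformers tree):

import sys
import tempfile
from pathlib import Path

# Build a throwaway package: pkg/utils holds the shared helper, pkg/models/t5gemma imports it.
root = Path(tempfile.mkdtemp())
(root / "pkg" / "utils").mkdir(parents=True)
(root / "pkg" / "models" / "t5gemma").mkdir(parents=True)
(root / "pkg" / "__init__.py").write_text("")
(root / "pkg" / "utils" / "__init__.py").write_text("MARKER = 'resolved'\n")
(root / "pkg" / "models" / "__init__.py").write_text("")
# `from ..utils import MARKER` would look for pkg.models.utils and fail with ModuleNotFoundError;
# three dots climb one level higher, to pkg.utils.
(root / "pkg" / "models" / "t5gemma" / "__init__.py").write_text("from ...utils import MARKER\n")

sys.path.insert(0, str(root))
from pkg.models.t5gemma import MARKER

print(MARKER)  # "resolved"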
src/transformers/models/t5gemma/modular_t5gemma.py (1 addition & 6 deletions)

@@ -37,11 +37,11 @@
     TransformersKwargs,
     auto_docstring,
     can_return_tuple,
-    is_torch_flex_attn_available,
     is_torchdynamo_compiling,
     logging,
 )
 from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import OutputRecorder, check_model_inputs
 from ..gemma2.configuration_gemma2 import Gemma2Config
 from ..gemma2.modeling_gemma2 import (
     Gemma2Attention,
@@ -53,16 +53,11 @@
     create_sliding_window_causal_mask,
     eager_attention_forward,
 )
-from ..utils.generic import OutputRecorder, check_model_inputs
 
 
 _CHECKPOINT_FOR_DOC = "google/t5gemma-2b-2b-prefixlm-it"
 
 
-if is_torch_flex_attn_available():
-    pass
-
-
 logger = logging.get_logger(__name__)
 
 
utils/modular_model_converter.py (62 additions & 43 deletions)

@@ -190,8 +190,60 @@ def get_full_attribute_name(node: Union[cst.Attribute, cst.Name]) -> Optional[st
     return None
 
 
-class SuperTransformer(cst.CSTTransformer):
-    METADATA_DEPENDENCIES = (ParentNodeProvider,)
+class ReplaceParentClassCallTransformer(cst.CSTTransformer):
+    """
+    This Transformer is used to replace all calls of the form `module.Class.func(...)` by a call of the form
+    `super().func(...)`.
+    """
+
+    def __init__(self, new_bases: list[str]):
+        self.new_bases = new_bases
+
+    def is_call_to_parent_class(self, node: cst.SimpleStatementLine):
+        """Check whether `node` corresponds to a call to a parent class function, such as `module.Parent.func_name(...)`"""
+        parent_call_node = m.Call(func=m.Attribute(value=m.Name() | m.Attribute()))
+        # It can be used as a return, simple expression, or assignment
+        potential_bodies = [m.Return(parent_call_node) | m.Expr(parent_call_node) | m.Assign(value=parent_call_node)]
+        return m.matches(node, m.SimpleStatementLine(body=potential_bodies))
+
+    def leave_SimpleStatementLine(
+        self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
+    ) -> cst.SimpleStatementLine:
+        """Replace a call of the form `module.Class.func(...)` by a call of the form `super().func(...)`
+        if the `Class` being called is one of the bases."""
+        if self.is_call_to_parent_class(updated_node):
+            expr_node = updated_node.body[0]
+            full_parent_class_name = get_full_attribute_name(expr_node.value.func.value)
+            # Replace only if it's a base, or a few special rules
+            if (
+                full_parent_class_name in self.new_bases
+                or ("nn.Module" in full_parent_class_name and self.new_bases == ["GradientCheckpointingLayer"])
+                or (
+                    full_parent_class_name == "PreTrainedModel"
+                    and any("PreTrainedModel" in base for base in self.new_bases)
+                )
+            ):
+                # Replace `full_parent_class_name.func(...)` with `super().func(...)`
+                attribute_node = expr_node.value.func.with_changes(value=cst.Call(func=cst.Name("super")))
+                # Check if the first argument is 'self', and remove it
+                new_args = (
+                    expr_node.value.args[1:]
+                    if len(expr_node.value.args) > 0 and m.matches(expr_node.value.args[0].value, m.Name("self"))
+                    else expr_node.value.args
+                )
+                call_node = expr_node.value.with_changes(func=attribute_node, args=new_args)
+                new_expr_node = expr_node.with_changes(value=call_node)
+                return updated_node.with_changes(body=[new_expr_node])
+        return updated_node
+
+
+class ReplaceSuperCallTransformer(cst.CSTTransformer):
+    """
+    This Transformer is used to unravel all calls to `super().func(...)` in class methods by the explicit parent's
+    code. It will also in turn replace all calls of the form `module.Class.func(...)` by a call of the form
+    `super().func(...)`. Those calls are used to explicitly skip the unravelling of code, but we should still follow
+    python's standards and use `super().func(...)` instead of `Parent.func(self, ...)`.
+    """
 
     def __init__(
         self,
@@ -205,7 +257,8 @@ def __init__(
         self.modular_methods = modular_methods
         self.all_assign_target = {}
         self.deleted_targets = {}  # child node can delete some arguments
-        self.new_bases = [get_full_attribute_name(base.value) for base in new_bases]
+        new_bases = [get_full_attribute_name(base.value) for base in new_bases]
+        self.parent_class_call_transformer = ReplaceParentClassCallTransformer(new_bases)
 
     def update_body(self, existing_body, new_statements):
         """
@@ -283,45 +336,14 @@ def _fix_init_location(self, new_body):
                 break
         return new_body
 
-    def replace_parent_class_call(self, node: cst.SimpleStatementLine) -> cst.SimpleStatementLine:
-        """Replace a call of the form `module.Class.func(...)` by a call of the form `super().func(...)`
-        if the `Class` being called is one of the bases."""
-        expr_node = node.body[0]
-        full_parent_class_name = get_full_attribute_name(expr_node.value.func.value)
-        # Replace only if it's a base, or if using nn.Module on a GradientCheckpointingLayer
-        if (
-            full_parent_class_name in self.new_bases
-            or ("nn.Module" in full_parent_class_name and self.new_bases == ["GradientCheckpointingLayer"])
-            or (
-                full_parent_class_name == "PreTrainedModel"
-                and any("PreTrainedModel" in base for base in self.new_bases)
-            )
-        ):
-            # Replace `full_parent_class_name.func(...)` with `super().func(...)`
-            attribute_node = expr_node.value.func.with_changes(value=cst.Call(func=cst.Name("super")))
-            call_node = expr_node.value.with_changes(func=attribute_node)
-            # Check if the first argument is 'self', and remove it
-            new_args = (
-                call_node.args[1:]
-                if len(call_node.args) > 0 and m.matches(call_node.args[0].value, m.Name("self"))
-                else call_node.args
-            )
-            new_expr_node = expr_node.with_changes(value=call_node.with_changes(args=new_args))
-            return node.with_changes(body=[new_expr_node])
-        return node
-
     def is_call_to_super(self, node: cst.BaseStatement, func_name: str):
         """Check whether `node` corresponds to a call to `super().func_name(...)`"""
         super_call_node = m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
         return m.matches(node, m.SimpleStatementLine(body=[m.Return(super_call_node) | m.Expr(super_call_node)]))
 
-    def is_call_to_parent_class(self, node: cst.BaseStatement):
-        """Check whether `node` corresponds to a call to a parent class function, such as `module.Parent.func_name(...)`"""
-        parent_call_node = m.Call(func=m.Attribute(value=m.Name() | m.Attribute()))
-        return m.matches(node, m.SimpleStatementLine(body=[m.Return(parent_call_node) | m.Expr(parent_call_node)]))
-
     def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
         func_name = updated_node.name.value
+        self.should_check_statements = False
         if func_name in self.modular_methods:
             actual_body = updated_node.body.body  # first body is an `IndentedBlock` wrapper
             new_body = []
@@ -332,11 +354,9 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
                     new_body = self._fix_init_location(new_body)
                     # Break here as all future statement were already accounted for in `update_body`
                     break
-                elif self.is_call_to_parent_class(base_statement_node):
-                    new_body.append(self.replace_parent_class_call(base_statement_node))
-                else:
-                    new_body.append(base_statement_node)
-
+                # If not a call to super, this will replace all calls of the form `module.Class.func(...)` by a
+                # call of the form `super().func(...)`
+                new_body.append(base_statement_node.visit(self.parent_class_call_transformer))
             return updated_node.with_changes(body=updated_node.body.with_changes(body=new_body))
         return updated_node
 
@@ -1028,9 +1048,8 @@ def replace_class_node(
     # Replace the calls to `super()` of the redefined modular methods with the unrolled code
    result_node = original_modeling_node.with_changes(body=cst.IndentedBlock(body=new_class_body))
    temp_module = cst.Module(body=[result_node])
-    new_module = MetadataWrapper(temp_module)
-    new_replacement_class = new_module.visit(
-        SuperTransformer(temp_module, original_modeling_methods, modular_methods, new_class_bases)
+    new_replacement_class = temp_module.visit(
+        ReplaceSuperCallTransformer(temp_module, original_modeling_methods, modular_methods, new_class_bases)
     )
    new_class_body = new_replacement_class.body[0].body  # get the indented block