
Commit 01c9e1b

[t5gemma] fix get_text_config and related fixes (#40939)
* tmp commit
* t5gemma fixes
1 parent: 0255319

File tree

6 files changed: +24 -28 lines changed


src/transformers/cache_utils.py

Lines changed: 14 additions & 7 deletions
@@ -395,7 +395,12 @@ def update(
         if not self.is_initialized:
             self.lazy_initialization(key_states)
 
-        cache_position = cache_kwargs.get("cache_position")
+        # Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
+        # in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
+        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
+        cache_position = (
+            cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
+        )
 
         cumulative_length = self.cumulative_length
         is_full = cumulative_length >= self.max_cache_len
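
Note on the new fallback: the added lines tolerate both a missing `cache_kwargs` dict and a missing `cache_position` entry, defaulting to positions that cover the full incoming key length. A minimal standalone sketch of that behavior, using made-up tensor shapes rather than a real model:

import torch

# Hypothetical key tensor: [batch, num_heads, seq_len, head_dim]
key_states = torch.randn(1, 8, 16, 64)
cache_kwargs = None  # some older models pass nothing at all on the cross-attention path

# Mirrors the added lines: survive a missing dict or entry, then default to the full key length
cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
if cache_position is None:
    cache_position = torch.arange(key_states.shape[-2], device=key_states.device)

print(cache_position.shape)  # torch.Size([16])
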
@@ -955,17 +960,19 @@ def __init__(
         layers = []
         # If a config is passed, use it to infer the layer types and initialize accordingly
         if config is not None:
-            config = config.get_text_config(decoder=True)
-            sliding_window = getattr(config, "sliding_window", None) or getattr(config, "attention_chunk_size", None)
-            layer_types = getattr(config, "layer_types", None)
+            decoder_config = config.get_text_config(decoder=True)
+            sliding_window = getattr(decoder_config, "sliding_window", None) or getattr(
+                decoder_config, "attention_chunk_size", None
+            )
+            layer_types = getattr(decoder_config, "layer_types", None)
             if layer_types is None:
                 layer_types = [
                     "sliding_attention" if sliding_window is not None else "full_attention"
-                    for _ in range(config.num_hidden_layers)
+                    for _ in range(decoder_config.num_hidden_layers)
                 ]
             # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
-            if hasattr(config, "num_kv_shared_layers"):
-                layer_types = layer_types[: -config.num_kv_shared_layers]
+            if hasattr(decoder_config, "num_kv_shared_layers"):
+                layer_types = layer_types[: -decoder_config.num_kv_shared_layers]
 
         for layer_type in layer_types:
             # From a cache point of view, both sliding and chunked are the same in how they should behave and how many
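
Note on the rename: using `decoder_config` keeps the caller's composite config untouched and routes every attribute lookup through the sub-config returned by `get_text_config(decoder=True)`. A standalone sketch of the resulting layer-type inference, with made-up attribute values standing in for a real decoder config:

from types import SimpleNamespace

# Hypothetical decoder sub-config; in the real code it comes from config.get_text_config(decoder=True)
decoder_config = SimpleNamespace(
    sliding_window=4096,
    attention_chunk_size=None,
    layer_types=None,
    num_hidden_layers=4,
)

sliding_window = getattr(decoder_config, "sliding_window", None) or getattr(
    decoder_config, "attention_chunk_size", None
)
layer_types = getattr(decoder_config, "layer_types", None)
if layer_types is None:
    # No explicit list: every layer is sliding when a window (or chunk size) is configured
    layer_types = [
        "sliding_attention" if sliding_window is not None else "full_attention"
        for _ in range(decoder_config.num_hidden_layers)
    ]
# Shared-KV layers (e.g. Gemma3n) need no cache entries of their own
if hasattr(decoder_config, "num_kv_shared_layers"):
    layer_types = layer_types[: -decoder_config.num_kv_shared_layers]

print(layer_types)  # four 'sliding_attention' entries
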

src/transformers/models/t5gemma/configuration_t5gemma.py

Lines changed: 0 additions & 4 deletions
@@ -324,9 +324,5 @@ def __setattr__(self, key, value):
             setattr(self.decoder, key, value)
         super().__setattr__(key, value)
 
-    def get_text_config(self, *args, **kwargs):
-        # Always return self, regardless of the decoder option.
-        return self
-
 
 __all__ = ["T5GemmaConfig", "T5GemmaModuleConfig"]

src/transformers/models/t5gemma/modular_t5gemma.py

Lines changed: 0 additions & 4 deletions
@@ -339,10 +339,6 @@ def __setattr__(self, key, value):
             setattr(self.decoder, key, value)
         super().__setattr__(key, value)
 
-    def get_text_config(self, *args, **kwargs):
-        # Always return self, regardless of the decoder option.
-        return self
-
 
 class T5GemmaRMSNorm(Gemma2RMSNorm):
     pass
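
Context for the two deletions above: the modular file is the source from which the configuration file is generated, so the override is removed in both places. With it gone, `T5GemmaConfig` presumably falls back to the base `get_text_config`, letting callers such as the cache initialization above reach the decoder sub-config instead of always getting the composite config back. A minimal sketch of the intended call, with the expected result hedged in the comment:

from transformers import T5GemmaConfig

config = T5GemmaConfig()

# With the override removed, decoder=True is expected to resolve to the decoder
# sub-config (a T5GemmaModuleConfig) rather than the composite config itself.
decoder_config = config.get_text_config(decoder=True)
print(type(config).__name__, "->", type(decoder_config).__name__)
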

tests/generation/test_utils.py

Lines changed: 7 additions & 10 deletions
@@ -1209,7 +1209,7 @@ def test_generate_from_inputs_embeds(self, _, num_beams):
 
             # This test is for decoder-only models (encoder-decoder models have native input embeddings support in the
             # decoder)
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 continue
             config.is_decoder = True
 
@@ -1288,7 +1288,7 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
 
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
 
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
 
             model = model_class(config).to(torch_device).eval()
@@ -1439,7 +1439,7 @@ def test_generate_continue_from_inputs_embeds(self):
             if "token_type_ids" in inputs_dict:
                 del inputs_dict["token_type_ids"]
 
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder")
             # TODO (joao, raushan): the correct line below is `if not hasattr(config.get_text_config(), "use_cache")`,
             # but it breaks a few models. Fix and then apply `has_similar_generate_outputs` pattern
@@ -1512,7 +1512,7 @@ def test_generate_with_static_cache(self):
             set_config_for_less_flaky_test(config)
             main_input = inputs_dict[model_class.main_input_name]
 
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
 
             config.is_decoder = True
@@ -1567,10 +1567,7 @@ def test_generate_with_quant_cache(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
 
-            if (
-                config.get_text_config(decoder=True).is_encoder_decoder
-                or not model_class._supports_default_dynamic_cache()
-            ):
+            if config.is_encoder_decoder or not model_class._supports_default_dynamic_cache():
                 self.skipTest(reason="This model does not support the quantized cache format")
 
             config.is_decoder = True
@@ -1670,7 +1667,7 @@ def test_generate_compile_model_forward_fullgraph(self):
             if not has_defined_cache_implementation:
                 decoder_cache = (
                     gen_out.past_key_values.self_attention_cache
-                    if config.get_text_config(decoder=True).is_encoder_decoder
+                    if config.is_encoder_decoder
                     else gen_out.past_key_values
                 )
                 self.assertTrue(isinstance(decoder_cache, DynamicCache))
@@ -1696,7 +1693,7 @@ def test_generate_compile_model_forward_fullgraph(self):
             # sanity checks
             decoder_cache = (
                 gen_out.past_key_values.self_attention_cache
-                if config.get_text_config(decoder=True).is_encoder_decoder
+                if config.is_encoder_decoder
                 else gen_out.past_key_values
             )
             self.assertFalse(isinstance(decoder_cache, DynamicCache))
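
All of the hunks above are the same one-line change: the encoder-decoder skip now reads the flag from the top-level config instead of from `get_text_config(decoder=True)`. A plausible reading, not spelled out in the diff: once `get_text_config(decoder=True)` returns a decoder sub-config for composite models, that sub-config no longer reports `is_encoder_decoder=True`, so only the top-level flag is a reliable signal. A sketch of the assumed behavior:

from transformers import T5GemmaConfig

config = T5GemmaConfig()

# Assumed: the composite config reports the encoder-decoder architecture,
# while its decoder sub-config does not, hence the tests check the top level.
print(config.is_encoder_decoder)                                # expected: True
print(config.get_text_config(decoder=True).is_encoder_decoder)  # expected: False
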

tests/models/gemma3n/test_modeling_gemma3n.py

Lines changed: 2 additions & 2 deletions
@@ -448,7 +448,7 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
 
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
 
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
 
             model = model_class(config).to(torch_device).eval()
@@ -509,7 +509,7 @@ def test_generate_with_static_cache(self):
             set_config_for_less_flaky_test(config)
             main_input = inputs_dict[model_class.main_input_name]
 
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
 
             config.is_decoder = True

tests/test_modeling_common.py

Lines changed: 1 addition & 1 deletion
@@ -4145,7 +4145,7 @@ def test_flex_attention_with_grads(self):
             if key in inputs_dict:
                 dummy_inputs[key] = inputs_dict[key].to(torch_device)
 
-        if config.get_text_config(decoder=True).is_encoder_decoder:
+        if config.is_encoder_decoder:
             dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device)
             dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device)
 
