
Commit 9568b50

[generate] handle support for cache classes when num enc layers != num dec layers (#40277)
* handle support for cache classes when num enc layers != num dec layers
* handle overwrites
* one more corner case
* Update src/transformers/generation/utils.py
* Update src/transformers/generation/utils.py
* Apply suggestions from code review
* handle corner case :o
1 parent 7f38068 commit 9568b50

File tree: 9 files changed (+89 -24 lines)

src/transformers/configuration_utils.py

Lines changed: 35 additions & 7 deletions
@@ -1168,21 +1168,34 @@ def _get_non_default_generation_parameters(self) -> dict[str, Any]:
 
         return non_default_generation_parameters
 
-    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+    def get_text_config(self, decoder=None, encoder=None) -> "PretrainedConfig":
         """
-        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
-        itself. On specific composite models, it is under a set of valid names.
+        Returns the text config related to the text input (encoder) or text output (decoder) of the model. The
+        `decoder` and `encoder` input arguments can be used to specify which end of the model we are interested in,
+        which is useful on models that have both text input and output modalities.
+
+        There are three possible outcomes of using this method:
+            1. On most models, it returns the original config instance itself.
+            2. On newer (2024+) composite models, it returns the text section of the config, which is nested under a set
+               of valid names.
+            3. On older (2023-) composite models, it discards decoder-only parameters when `encoder=True` and vice-versa.
 
         Args:
-            decoder (`Optional[bool]`, *optional*, defaults to `False`):
+            decoder (`Optional[bool]`, *optional*):
                 If set to `True`, then only search for decoder config names.
+            encoder (`Optional[bool]`, *optional*):
+                If set to `True`, then only search for encoder config names.
         """
+        return_both = decoder == encoder  # both unset or both set -> search all possible names
+
         decoder_possible_text_config_names = ("decoder", "generator", "text_config")
         encoder_possible_text_config_names = ("text_encoder",)
-        if decoder:
+        if return_both:
+            possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
+        elif decoder:
             possible_text_config_names = decoder_possible_text_config_names
         else:
-            possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
+            possible_text_config_names = encoder_possible_text_config_names
 
         valid_text_config_names = []
         for text_config_name in possible_text_config_names:
@@ -1194,12 +1207,27 @@ def get_text_config(self, decoder=False) -> "PretrainedConfig":
         if len(valid_text_config_names) > 1:
             raise ValueError(
                 f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
-                "case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
+                "case, using `get_text_config()` would be ambiguous. Please specify the desired text config directly, "
+                "e.g. `text_config = config.sub_config_name`"
             )
         elif len(valid_text_config_names) == 1:
             config_to_return = getattr(self, valid_text_config_names[0])
         else:
             config_to_return = self
+
+        # handle legacy models with flat config structure, when we only want one of the configs
+        if not return_both and len(valid_text_config_names) == 0 and config_to_return.is_encoder_decoder:
+            config_to_return = copy.deepcopy(config_to_return)
+            prefix_to_discard = "encoder" if decoder else "decoder"
+            for key in config_to_return.to_dict():
+                if key.startswith(prefix_to_discard):
+                    delattr(config_to_return, key)
+            # old encoder/decoder models may use "encoder_layers"/"decoder_layers" instead of "num_hidden_layers"
+            if decoder and hasattr(config_to_return, "decoder_layers"):
+                config_to_return.num_hidden_layers = config_to_return.decoder_layers
+            elif encoder and hasattr(config_to_return, "encoder_layers"):
+                config_to_return.num_hidden_layers = config_to_return.encoder_layers
+
         return config_to_return
 
     @classmethod
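
For reference, a quick illustration of the three outcomes documented above, modelled on the new test added at the end of this commit (the checkpoint names come from that test; the assertions only hold if those configs look the way the test expects):

    from transformers import AutoConfig

    # Outcome 1: decoder-only model -> the original config instance is returned
    llama_config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
    assert llama_config.get_text_config(decoder=True) == llama_config

    # Outcome 3: legacy flat encoder-decoder config -> per-side copies with remapped layer counts
    bart_config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-bart")
    decoder_view = bart_config.get_text_config(decoder=True)
    encoder_view = bart_config.get_text_config(encoder=True)
    assert decoder_view.num_hidden_layers == bart_config.decoder_layers
    assert encoder_view.num_hidden_layers == bart_config.encoder_layers
    assert getattr(decoder_view, "encoder_layers", None) is None  # encoder-side keys are discarded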

src/transformers/generation/utils.py

Lines changed: 12 additions & 5 deletions
@@ -1844,12 +1844,19 @@ def _get_cache(self, cache_implementation: str, batch_size: int, max_cache_len:
         )
 
         if need_new_cache:
-            cache_kwargs = {"config": self.config, "max_cache_len": max_cache_len, "offloading": offload_cache}
-            self._cache = StaticCache(**cache_kwargs)
+            self_attention_cache_kwargs = {
+                "config": self.config.get_text_config(decoder=True),
+                "max_cache_len": max_cache_len,
+                "offloading": offload_cache,
+            }
+            self._cache = StaticCache(**self_attention_cache_kwargs)
             if requires_cross_attention_cache:
-                encoder_kwargs = cache_kwargs.copy()
-                encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1]
-                self._cache = EncoderDecoderCache(self._cache, StaticCache(**encoder_kwargs))
+                cross_attention_cache_kwargs = {
+                    "config": self.config.get_text_config(encoder=True),
+                    "max_cache_len": model_kwargs["encoder_outputs"][0].shape[1],
+                    "offloading": offload_cache,
+                }
+                self._cache = EncoderDecoderCache(self._cache, StaticCache(**cross_attention_cache_kwargs))
         else:
             self._cache.reset()
         return self._cache
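
In practice the change is invisible to callers: requesting a static cache through `generate` now sizes the self-attention cache from the decoder config and the cross-attention cache from the encoder config. A minimal end-to-end sketch, assuming a seq2seq checkpoint with static-cache support (the T5 checkpoint below is illustrative; the fix matters most when encoder and decoder layer counts differ):

    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    model_id = "google-t5/t5-small"  # illustrative checkpoint; assumed to support static caches
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

    inputs = tokenizer("translate English to German: The cache is static.", return_tensors="pt")
    with torch.no_grad():
        # `cache_implementation="static"` routes through `_get_cache`, which now builds the
        # self-attention StaticCache from the decoder config and the cross-attention
        # StaticCache from the encoder config
        output_ids = model.generate(**inputs, cache_implementation="static", max_new_tokens=20)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))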

src/transformers/models/colqwen2/configuration_colqwen2.py

Lines changed: 2 additions & 2 deletions
@@ -87,8 +87,8 @@ def __init__(
         self.initializer_range = initializer_range
         super().__init__(**kwargs)
 
-    def get_text_config(self, decoder=False) -> PretrainedConfig:
-        return self.vlm_config.get_text_config(decoder=decoder)
+    def get_text_config(self, *args, **kwargs) -> PretrainedConfig:
+        return self.vlm_config.get_text_config(*args, **kwargs)
 
 
 __all__ = ["ColQwen2Config"]

src/transformers/models/dia/configuration_dia.py

Lines changed: 1 addition & 1 deletion
@@ -368,7 +368,7 @@ def __init__(
             **kwargs,
         )
 
-    def get_text_config(self, decoder=False):
+    def get_text_config(self, *args, **kwargs):
         """Defaulting to audio config as it's the decoder in this case which is usually the text backbone"""
         return self.decoder_config

src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py

Lines changed: 2 additions & 2 deletions
@@ -1073,7 +1073,7 @@ def __init__(
 
         super().__init__(**kwargs)
 
-    def get_text_config(self, decoder=False):
+    def get_text_config(self, *args, **kwargs):
         """
         Returns the config that is meant to be used with text IO. On most models, it is the original config instance
         itself. On specific composite models, it is under a set of valid names.
@@ -1085,7 +1085,7 @@ def get_text_config(self, decoder=False):
         # Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model
         # except for Qwen yet. This has to be generalized if more deeply nested configs are
         # added. NOTE: currently method used only by vLLM
-        return self.thinker_config.get_text_config()
+        return self.thinker_config.get_text_config(*args, **kwargs)
 
 
 __all__ = ["Qwen2_5OmniConfig", "Qwen2_5OmniThinkerConfig", "Qwen2_5OmniTalkerConfig", "Qwen2_5OmniToken2WavConfig"]

src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py

Lines changed: 2 additions & 2 deletions
@@ -1108,7 +1108,7 @@ def __init__(
 
         super().__init__(**kwargs)
 
-    def get_text_config(self, decoder=False):
+    def get_text_config(self, *args, **kwargs):
         """
         Returns the config that is meant to be used with text IO. On most models, it is the original config instance
         itself. On specific composite models, it is under a set of valid names.
@@ -1120,7 +1120,7 @@ def get_text_config(self, decoder=False):
         # Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model
         # except for Qwen yet. This has to be generalized if more deeply nested configs are
         # added. NOTE: currently method used only by vLLM
-        return self.thinker_config.get_text_config()
+        return self.thinker_config.get_text_config(*args, **kwargs)
 
 
 class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel):

src/transformers/models/t5gemma/configuration_t5gemma.py

Lines changed: 1 addition & 2 deletions
@@ -323,9 +323,8 @@ def __setattr__(self, key, value):
             setattr(self.decoder, key, value)
         super().__setattr__(key, value)
 
-    def get_text_config(self, decoder=False):
+    def get_text_config(self, *args, **kwargs):
         # Always return self, regardless of the decoder option.
-        del decoder
         return self
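
The T5Gemma override keeps its previous behaviour while now accepting the new keyword arguments. A small sketch of what callers see, assuming `T5GemmaConfig` can be built with its defaults:

    from transformers import T5GemmaConfig

    config = T5GemmaConfig()
    # regardless of the requested side, the full (nested) config is returned
    assert config.get_text_config(decoder=True) is config
    assert config.get_text_config(encoder=True) is config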

src/transformers/models/t5gemma/modular_t5gemma.py

Lines changed: 1 addition & 2 deletions
@@ -213,9 +213,8 @@ def __setattr__(self, key, value):
             setattr(self.decoder, key, value)
         super().__setattr__(key, value)
 
-    def get_text_config(self, decoder=False):
+    def get_text_config(self, *args, **kwargs):
         # Always return self, regardless of the decoder option.
-        del decoder
         return self

tests/utils/test_configuration_utils.py

Lines changed: 33 additions & 1 deletion
@@ -24,7 +24,7 @@
 from huggingface_hub import HfFolder
 from requests.exceptions import HTTPError
 
-from transformers import AutoConfig, BertConfig, GPT2Config
+from transformers import AutoConfig, BertConfig, Florence2Config, GPT2Config
 from transformers.configuration_utils import PretrainedConfig
 from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test
 
@@ -300,3 +300,35 @@ def test_loading_config_do_not_raise_future_warnings(self):
         with warnings.catch_warnings():
             warnings.simplefilter("error")
             PretrainedConfig.from_pretrained("bert-base-uncased")
+
+    def test_get_text_config(self):
+        """Tests the `get_text_config` method."""
+        # 1. model with only text input -> returns the original config instance
+        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+        self.assertEqual(config.get_text_config(), config)
+        self.assertEqual(config.get_text_config(decoder=True), config)
+
+        # 2. composite model (VLM) -> returns the text component
+        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlavaForConditionalGeneration")
+        self.assertEqual(config.get_text_config(), config.text_config)
+        self.assertEqual(config.get_text_config(decoder=True), config.text_config)
+
+        # 3. ! corner case! : composite model whose sub-config is an old composite model (should behave as above)
+        config = Florence2Config()
+        self.assertEqual(config.get_text_config(), config.text_config)
+        self.assertEqual(config.get_text_config(decoder=True), config.text_config)
+
+        # 4. old composite model -> may remove components based on the `decoder` or `encoder` argument
+        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-bart")
+        self.assertEqual(config.get_text_config(), config)
+        # both encoder_layers and decoder_layers exist
+        self.assertTrue(getattr(config, "encoder_layers", None) is not None)
+        self.assertTrue(getattr(config, "decoder_layers", None) is not None)
+        decoder_config = config.get_text_config(decoder=True)
+        self.assertNotEqual(decoder_config, config)
+        self.assertEqual(decoder_config.num_hidden_layers, config.decoder_layers)
+        self.assertTrue(getattr(decoder_config, "encoder_layers", None) is None)  # encoder_layers is removed
+        encoder_config = config.get_text_config(encoder=True)
+        self.assertNotEqual(encoder_config, config)
+        self.assertEqual(encoder_config.num_hidden_layers, config.encoder_layers)
+        self.assertTrue(getattr(encoder_config, "decoder_layers", None) is None)  # decoder_layers is removed
