@@ -1168,21 +1168,34 @@ def _get_non_default_generation_parameters(self) -> dict[str, Any]:
 
         return non_default_generation_parameters
 
-    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+    def get_text_config(self, decoder=None, encoder=None) -> "PretrainedConfig":
         """
-        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
-        itself. On specific composite models, it is under a set of valid names.
+        Returns the text config related to the text input (encoder) or text output (decoder) of the model. The
+        `decoder` and `encoder` input arguments can be used to specify which end of the model we are interested in,
+        which is useful on models that have both text input and output modalities.
+
+        There are three possible outcomes of using this method:
+            1. On most models, it returns the original config instance itself.
+            2. On newer (2024+) composite models, it returns the text section of the config, which is nested under a set
+               of valid names.
+            3. On older (2023-) composite models, it discards decoder-only parameters when `encoder=True` and vice-versa.
 
         Args:
-            decoder (`Optional[bool]`, *optional*, defaults to `False`):
+            decoder (`Optional[bool]`, *optional*):
                 If set to `True`, then only search for decoder config names.
+            encoder (`Optional[bool]`, *optional*):
+                If set to `True`, then only search for encoder config names.
         """
+        return_both = decoder == encoder  # both unset or both set -> search all possible names
+
         decoder_possible_text_config_names = ("decoder", "generator", "text_config")
         encoder_possible_text_config_names = ("text_encoder",)
-        if decoder:
+        if return_both:
+            possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
+        elif decoder:
             possible_text_config_names = decoder_possible_text_config_names
         else:
-            possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
+            possible_text_config_names = encoder_possible_text_config_names
 
         valid_text_config_names = []
         for text_config_name in possible_text_config_names:
@@ -1194,12 +1207,27 @@ def get_text_config(self, decoder=False) -> "PretrainedConfig":
         if len(valid_text_config_names) > 1:
             raise ValueError(
                 f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
-                "case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
+                "case, using `get_text_config()` would be ambiguous. Please specify the desired text config directly, "
+                "e.g. `text_config = config.sub_config_name`"
             )
         elif len(valid_text_config_names) == 1:
             config_to_return = getattr(self, valid_text_config_names[0])
         else:
             config_to_return = self
+
+        # handle legacy models with flat config structure, when we only want one of the configs
+        if not return_both and len(valid_text_config_names) == 0 and config_to_return.is_encoder_decoder:
+            config_to_return = copy.deepcopy(config_to_return)
+            prefix_to_discard = "encoder" if decoder else "decoder"
+            for key in config_to_return.to_dict():
+                if key.startswith(prefix_to_discard):
+                    delattr(config_to_return, key)
+            # old encoder/decoder models may use "encoder_layers"/"decoder_layers" instead of "num_hidden_layers"
+            if decoder and hasattr(config_to_return, "decoder_layers"):
+                config_to_return.num_hidden_layers = config_to_return.decoder_layers
+            elif encoder and hasattr(config_to_return, "encoder_layers"):
+                config_to_return.num_hidden_layers = config_to_return.encoder_layers
+
         return config_to_return
 
     @classmethod
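
A minimal usage sketch of the updated signature, assuming a transformers install that already includes this change. The checkpoint names below are assumptions chosen for illustration and are not referenced by the diff; the expected results follow the three documented outcomes.

# Sketch only: illustrates the three outcomes described in the new docstring.
# Checkpoints are illustrative assumptions, not part of this diff.
from transformers import AutoConfig

# 1. Plain text-only model: the method returns the config instance itself.
gpt2_config = AutoConfig.from_pretrained("gpt2")
assert gpt2_config.get_text_config(decoder=True) is gpt2_config

# 2. Newer composite model with a nested text config: the nested section is returned.
llava_config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
assert llava_config.get_text_config(decoder=True) is llava_config.text_config

# 3. Older flat encoder-decoder config: asking for one side drops the other side's
#    keys and maps "decoder_layers"/"encoder_layers" onto "num_hidden_layers".
bart_config = AutoConfig.from_pretrained("facebook/bart-base")
decoder_view = bart_config.get_text_config(decoder=True)
print(decoder_view.num_hidden_layers)  # populated from `decoder_layers`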