Skip to content

Commit

Permalink
[SpeechEncoderDecoder] Fix from pretrained (huggingface#15043)
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten authored and Steven committed Jan 6, 2022
1 parent 68b95f0 commit e6eef10
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def from_encoder_decoder_pretrained(
)

if "config" not in kwargs_encoder:
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path, **kwargs_encoder)
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
logger.info(
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
Expand All @@ -391,7 +391,7 @@ def from_encoder_decoder_pretrained(

kwargs_encoder["config"] = encoder_config

encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args)
encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)

decoder = kwargs_decoder.pop("model", None)
if decoder is None:
Expand All @@ -402,7 +402,7 @@ def from_encoder_decoder_pretrained(
)

if "config" not in kwargs_decoder:
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
Expand All @@ -424,7 +424,7 @@ def from_encoder_decoder_pretrained(
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
)

decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path)
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

# instantiate config with corresponding kwargs
config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
Expand Down

0 comments on commit e6eef10

Please sign in to comment.