
unexpected keyword argument 'lm_labels' when using BertModel as Decoder with EncoderDecoderModel #4960

Closed
utkd opened this issue Jun 12, 2020 · 2 comments


utkd commented Jun 12, 2020

The BertModel.forward() method does not accept lm_labels or masked_lm_labels arguments. Yet the EncoderDecoderModel.forward() method calls its decoder's forward() method with those arguments, which raises a TypeError when a BertModel is used as the decoder.

Am I using the BertModel incorrectly? I can get rid of the error by modifying EncoderDecoderModel so that it does not pass those arguments to the decoder (a sketch of an equivalent workaround follows the code below).

Exact Error:

File "/Users/utkarsh/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/utkarsh/Projects/ai4code/transformers/bert2bert/models.py", line 12, in forward
    dec_out, dec_cls, enc_out, enc_cls = self.bertmodel(input_ids=inputs, attention_mask=input_masks, decoder_input_ids=targets, decoder_attention_mask=target_masks)
  File "/Users/utkarsh/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/utkarsh/anaconda3/envs/py37/lib/python3.7/site-packages/transformers/modeling_encoder_decoder.py", line 283, in forward
    **kwargs_decoder,
  File "/Users/utkarsh/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
TypeError: forward() got an unexpected keyword argument 'lm_labels'

Relevant part of the code:

from transformers import BertConfig, BertModel, EncoderDecoderModel

encoder = BertModel(enc_config)  # enc_config: a BertConfig built elsewhere
dec_config = BertConfig(..., is_decoder=True)
decoder = BertModel(dec_config)
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)

...
dec_out, dec_cls, enc_out, enc_cls = model(input_ids=inputs, attention_mask=input_masks, decoder_input_ids=targets, decoder_attention_mask=target_masks)
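
For reference, a minimal sketch of an equivalent workaround: instead of patching EncoderDecoderModel itself, the decoder can be wrapped so the label kwargs are dropped before they reach BertModel.forward(). The wrapper class name here is my own, not part of the library:

class BertDecoderNoLabels(BertModel):
    # Drop the label kwargs that EncoderDecoderModel forwards to its
    # decoder; the bare BertModel has no LM head and cannot use them.
    def forward(self, *args, lm_labels=None, masked_lm_labels=None, **kwargs):
        return super().forward(*args, **kwargs)

decoder = BertDecoderNoLabels(dec_config)
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)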


gustavscholin commented Jun 14, 2020

I'm facing the same problem. Since #4874, it seems the argument should be just labels instead of lm_labels. According to the documentation it should compute a masked language modeling loss, but from my debugging it seems it actually computes a next-word-prediction loss.
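
For illustration, a minimal sketch of the renamed call, assuming the post-#4874 signature and a tuple return with the loss first when labels are given (reusing targets as labels is just for this example):

outputs = model(
    input_ids=inputs,
    attention_mask=input_masks,
    decoder_input_ids=targets,
    decoder_attention_mask=target_masks,
    labels=targets,  # was lm_labels before #4874
)
loss = outputs[0]  # from debugging, this looks like a next-token loss, not masked-LM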


stale bot commented Aug 14, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the wontfix label Aug 14, 2020
stale bot closed this as completed Aug 22, 2020