Encoder-Decoder: OPT as a decoder #23350

Closed

BramVanroy opened this issue May 13, 2023 · 8 comments
Comments

@BramVanroy
Collaborator

System Info

  • transformers version: 4.29.1
  • Platform: Linux-5.14.0-162.6.1.el9_1.0.1.x86_64-x86_64-with-glibc2.34
  • Python version: 3.10.10
  • Huggingface_hub version: 0.14.1
  • Safetensors version: not installed
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)

Who can help?

@ArthurZucker and @younesbelkada, and maybe also @gante for generation in enc/dec scenarios

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch

from transformers import EncoderDecoderModel, OPTConfig, MT5Config, MT5Model, OPTForCausalLM, AutoTokenizer


def init_enc_dec(enc_model_name: str = "google/mt5-small", dec_model_name: str = "facebook/opt-350m"):
    config_encoder = MT5Config.from_pretrained(enc_model_name)
    config_encoder.is_encoder_decoder = False
    config_encoder.add_cross_attention = False
    config_encoder.is_decoder = False
    config_encoder.num_decoder_layers = 0

    config_decoder = OPTConfig.from_pretrained(dec_model_name)
    config_decoder.add_cross_attention = True
    config_decoder.is_decoder = True

    encoder = MT5Model.from_pretrained(enc_model_name, config=config_encoder).get_encoder()
    decoder = OPTForCausalLM.from_pretrained(dec_model_name, config=config_decoder)
    model = EncoderDecoderModel(encoder=encoder, decoder=decoder)

    return model


def main():
    model = init_enc_dec()
    model.eval()
    enc_tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
    dec_tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    with torch.no_grad():
        inputs = enc_tokenizer("I like bananas", return_tensors="pt")
        outputs = model.generate(**inputs)
        print(dec_tokenizer.batch_decode(outputs))  # generate() returns a tensor of token ids, not a dict


if __name__ == '__main__':
    main()

This leads to

Traceback (most recent call last):
  File "/home/local/vanroy/llm-generation/enc_dec.py", line 38, in <module>
    main()
  File "/home/local/vanroy/llm-generation/enc_dec.py", line 33, in main
    outputs = model.generate(**inputs)
  File "/home/local/vanroy/llm-generation/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/local/vanroy/llm-generation/.venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 1515, in generate
    return self.greedy_search(
  File "/home/local/vanroy/llm-generation/.venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 2332, in greedy_search
    outputs = self(
  File "/home/local/vanroy/llm-generation/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/local/vanroy/llm-generation/.venv/lib/python3.10/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py", line 617, in forward
    decoder_outputs = self.decoder(
  File "/home/local/vanroy/llm-generation/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: OPTForCausalLM.forward() got an unexpected keyword argument 'encoder_hidden_states'

Expected behavior

I am trying to use the encoder-decoder functionality, but I am not sure whether I am doing something wrong or whether OPT is simply not compatible with this architecture.
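A quick way to see why the call fails (a sketch of mine, not part of the original report, using only classes that ship with transformers): EncoderDecoderModel forwards encoder_hidden_states to the decoder, and OPTForCausalLM.forward() does not accept that argument, while GPT2LMHeadModel.forward() does.

import inspect

from transformers import GPT2LMHeadModel, OPTForCausalLM

# EncoderDecoderModel passes encoder_hidden_states to the decoder's forward(),
# so a decoder is only usable here if its signature accepts that argument.
for cls in (OPTForCausalLM, GPT2LMHeadModel):
    accepts = "encoder_hidden_states" in inspect.signature(cls.forward).parameters
    print(f"{cls.__name__} accepts encoder_hidden_states: {accepts}")

On transformers 4.29 this prints False for OPTForCausalLM and True for GPT2LMHeadModel, which matches the TypeError above.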

@gante
Member

gante commented May 16, 2023

Hey @BramVanroy 👋

I believe most, if not all, recent decoder-only models are not compatible with EncoderDecoderModel, as they are missing a block like this one in GPT2 (plus related changes, like making encoder_hidden_states an argument).
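For concreteness, here is a small check (my sketch, not from the comment above) of the structural difference: when add_cross_attention=True, each GPT2Block instantiates a crossattention module, whereas OPT's decoder layers never create one, regardless of the config flag.

from transformers import GPT2Config, GPT2LMHeadModel, OPTConfig, OPTForCausalLM

# Tiny randomly initialised models are enough to inspect the module structure.
gpt2 = GPT2LMHeadModel(GPT2Config(n_layer=1, add_cross_attention=True))
opt = OPTForCausalLM(OPTConfig(num_hidden_layers=1, add_cross_attention=True))

print(hasattr(gpt2.transformer.h[0], "crossattention"))        # True
print(hasattr(opt.model.decoder.layers[0], "crossattention"))  # False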

@BramVanroy
Collaborator Author

Thanks for the reply @gante! I had indeed found this difference in the code. Do you know whether there are any plans to make more decoders compatible?

@gante
Member

gante commented May 16, 2023

@BramVanroy not on our end, since decoder-only models have been stealing the spotlight!

We'd be happy to merge the appropriate changes, though.

@BramVanroy
Collaborator Author

Okay, that makes sense. Different priorities! Thanks for the reply, João.

@alvations

Managed to find this issue after https://twitter.com/alvations/status/1763206884205425054 and then tried to do this:

import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    EncoderDecoderModel,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

bert2mistral = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-cased", "mistralai/Mistral-7B-v0.1",
)

I'm not sure whether it would be useful eventually, but it would be nice if the above were possible.
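For comparison, a sketch on my side (not from the thread): the same call goes through when the decoder already implements cross-attention, e.g. GPT-2. Mistral's forward(), like OPT's, does not take encoder_hidden_states, so the bert2mistral pairing would hit the same TypeError as the OPT example above once the model is actually run.

from transformers import EncoderDecoderModel

# from_encoder_decoder_pretrained sets add_cross_attention=True on the decoder
# config; GPT-2 honours it, so cross-attention weights are created (newly
# initialised) and the combined model can be fine-tuned and used for generation.
bert2gpt2 = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-cased", "gpt2"
)
print(bert2gpt2.decoder.config.add_cross_attention)  # True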

@michaelmior

@gante Could you update the link to the block you're referring to, ideally with a permanent link? Thanks!

@gante
Member

gante commented May 29, 2024

(@michaelmior updated with a permalink)

@Tarak200

Can anyone suggest a decoder-only model that is compatible with GraphCodeBERT?
