
Allow setting different decoder_start_token_ids for each item in a batch in the generate function. #28763

Open
dpernes opened this issue Jan 29, 2024 · 7 comments
Labels: Feature request

dpernes commented Jan 29, 2024

Feature request

@gante
The generate function has a decoder_start_token_id argument that allows specifying the decoder start token when generating from an encoder-decoder model (e.g. mT5). Currently, decoder_start_token_id must be an integer, which means the same start token is used for every element in the batch. I request that you allow specifying a different start token for each element of the batch. For this purpose, decoder_start_token_id would need to accept a tensor of shape (batch_size,).

Motivation

Some multilingual encoder-decoder models use the decoder_start_token_id to indicate the target language. Thus, this change would allow generation into multiple target languages in parallel, as illustrated in the code below.

Your contribution

import re
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

WHITESPACE_HANDLER = lambda k: re.sub(r'\s+', ' ', re.sub(r'\n+', ' ', k.strip()))

article_text = """Videos that say approved vaccines are dangerous and cause autism, cancer or infertility are among those that will be taken down, the company said.  The policy includes the termination of accounts of anti-vaccine influencers.  Tech giants have been criticised for not doing more to counter false health information on their sites.  In July, US President Joe Biden said social media platforms were largely responsible for people's scepticism in getting vaccinated by spreading misinformation, and appealed for them to address the issue.  YouTube, which is owned by Google, said 130,000 videos were removed from its platform since last year, when it implemented a ban on content spreading misinformation about Covid vaccines.  In a blog post, the company said it had seen false claims about Covid jabs "spill over into misinformation about vaccines in general". The new policy covers long-approved vaccines, such as those against measles or hepatitis B.  "We're expanding our medical misinformation policies on YouTube with new guidelines on currently administered vaccines that are approved and confirmed to be safe and effective by local health authorities and the WHO," the post said, referring to the World Health Organization."""

model_name = "csebuetnlp/mT5_m2m_crossSum_enhanced"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

get_lang_id = lambda lang: tokenizer._convert_token_to_id(
    model.config.task_specific_params["langid_map"][lang][1]
)

target_langs = ["portuguese", "spanish"]

input_ids = tokenizer(
    [WHITESPACE_HANDLER(article_text)],
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=512
)["input_ids"]
input_ids = input_ids.expand(len(target_langs), -1)   # shape (num_target_languages, num_input_tokens)

decoder_start_token_id = torch.tensor(
    [get_lang_id(t) for t in target_langs],
    dtype=input_ids.dtype,
    device=input_ids.device
)  # shape (num_target_languages,)

output_ids = model.generate(
    input_ids=input_ids,
    decoder_start_token_id=decoder_start_token_id,
    max_length=84,
    no_repeat_ngram_size=2,
    num_beams=4,
)

summaries = tokenizer.batch_decode(
    output_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)

print(summaries)
ArthurZucker added the Feature request label Jan 30, 2024
gante (Member) commented Feb 5, 2024

cc @zucchini-nlp

zucchini-nlp (Member) commented

@dpernes Hi, if you want to specify a different decoder_start_token_id for each element, you can do so by passing a tensor of shape (batch_size, seq_len). In your case, adding this line before generate is called will solve the issue:

decoder_start_token_id = decoder_start_token_id.unsqueeze(1) # shape (num_target_languages, 1)
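Applied to the example above, the rest of the call stays the same (a sketch reusing the variables from the original snippet):

output_ids = model.generate(
    input_ids=input_ids,                            # (len(target_langs), 512)
    decoder_start_token_id=decoder_start_token_id,  # (len(target_langs), 1): one start token per item
    max_length=84,
    no_repeat_ngram_size=2,
    num_beams=4,
)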

dpernes (Author) commented Feb 6, 2024

Great, thank you @zucchini-nlp! This behavior is not documented, though:

decoder_start_token_id (`int`, *optional*):
            If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.

You may want to change it to something like:

decoder_start_token_id (`Union[int, torch.LongTensor]`, *optional*):
            If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. Optionally, use a `torch.LongTensor` of shape `(batch_size, sequence_length)` to specify a prompt for the decoder.

But why isn't this the same as passing decoder_input_ids to generate? I tried passing the same tensor as decoder_input_ids instead of decoder_start_token_id and the results do not match.

zucchini-nlp (Member) commented

Thanks, I added a PR extending the docs.

Regarding your question, there is a subtle difference between them. The decoder_start_token_id is used as the very first token in generation (the BOS token in most cases), whereas decoder_input_ids start or continue the sequence from there. In most cases you do not provide decoder_input_ids yourself when calling generate, so they are filled with decoder_start_token_id and generation starts from BOS.

The general format is [decoder_start_token_id, decoder_input_ids], and generate automatically fills in decoder_start_token_id from the config if you do not provide it.
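To illustrate the relationship (a minimal sketch; the token ids and shapes below are made up for illustration, not from a real model):

import torch

# hypothetical per-item start tokens (e.g. language tags), shape (batch_size, 1)
decoder_start_token_id = torch.tensor([[250004], [250005]])
# hypothetical user-provided decoder prompt, shape (batch_size, 2)
decoder_input_ids = torch.tensor([[17, 23], [42, 7]])

# conceptually, generation decodes from the concatenation of the two:
effective_decoder_input = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1)
print(effective_decoder_input.shape)  # torch.Size([2, 3])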

tehranixyz commented

Hi,
Is there any way to specify decoder_start_token_id during training as well?
Like this:

outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    labels=batch["labels"],
    decoder_start_token_id=decoder_start_token_id,
)
loss = outputs.loss

Each batch may require a different decoder_start_token_id during training, because each batch has a specific input language and output language, and the output language varies from batch to batch.
Changing model.config.decoder_start_token_id for each batch doesn't seem to be a good approach; in particular, it seems to cause a lot of inconsistency when using the Accelerator with DeepSpeed.

zucchini-nlp (Member) commented

Hey @tehranixyz, you do not need to specify decoder_start_token_id while training. All you need to do is prepare the decoder_input_ids and pass them to the forward call. We use the start token from the model config only when we do not find decoder_input_ids from the user (see the code snippet for preparing decoder input ids from labels).
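For illustration, a minimal sketch of that label-shifting step, modeled on the shift_tokens_right helpers found in several transformers models (written out by hand here so the start token can vary per batch; shift_labels_right and start_token_id_for_this_batch are hypothetical names):

import torch

def shift_labels_right(labels, pad_token_id, decoder_start_token_id):
    # Build decoder inputs [start, y_0, y_1, ...] from labels [y_0, y_1, ..., y_n]
    decoder_input_ids = labels.new_zeros(labels.shape)
    decoder_input_ids[:, 1:] = labels[:, :-1].clone()
    decoder_input_ids[:, 0] = decoder_start_token_id
    # replace ignored label positions (-100) with the pad token
    decoder_input_ids.masked_fill_(decoder_input_ids == -100, pad_token_id)
    return decoder_input_ids

# usage in the training loop, with a per-batch start token:
outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    decoder_input_ids=shift_labels_right(
        batch["labels"], tokenizer.pad_token_id, start_token_id_for_this_batch
    ),
    labels=batch["labels"],
)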

tehranixyz commented

Gotcha!
I was a bit confused by the warning saying "The decoder_input_ids are now created based on the labels, no need to pass them yourself anymore." when using EncoderDecoderModel.
So in my case, I guess, as you said, I have to prepare decoder_input_ids myself by shifting the labels and adding the appropriate start token at the beginning.
Many thanks!
