
Improving T5 Docs #16614

Closed

alexcoca opened this issue Apr 5, 2022 · 7 comments

Comments

alexcoca commented Apr 5, 2022

Who can help

@NielsRogge @patrickvonplaten @sgugger

Documentation: @sgugger

Information

Model I am using: T5ForConditionalGeneration

The problem arises when using:

  • my own modified scripts

The task I am working on is:

  • my own task or dataset

To reproduce

#13240 is a really nice PR that adds a lot of clarity to the documentation. However, in the examples provided we read

tokenizer.pad_token = tokenizer.eos_token  # to avoid an error

and I feel more information could be given to explain to users why this is necessary. I am currently attempting batched decoding with T5 and observing very strange outputs, so I'm keen to understand whether this is the cause. I will soon run a test to see whether the strange behaviour is due to batching or not, but it would be great to enhance the docs to explain the error that would otherwise occur and why.
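For reference, this is roughly the batched setup I have in mind (a minimal sketch; the t5-small checkpoint and the toy inputs are placeholders, not my actual task):

from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

batch = [
    "translate English to German: How are you?",
    "translate English to German: I would like to book a table for two.",
]
# padding=True is what makes a pad token necessary in the first place
inputs = tokenizer(batch, padding=True, return_tensors="pt")
output_seqs = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_length=64,
)
print(tokenizer.batch_decode(output_seqs, skip_special_tokens=True))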

Expected behavior

One or two extra sentences to explain why we left-pad encoder input sequences with <EOS> when doing batched decoding for T5.

NielsRogge (Contributor) commented Apr 5, 2022

Hi,

Thanks for your question! To be honest it wasn't clear to me either; I guess it's set because otherwise the tokenizer might complain that no padding token is defined.

I took that snippet from this PR: #7552.

It includes the comment:

# when generating, we will use the logits of right-most token to predict the next token
# so the padding should be on the left

However, that was for a decoder-only model (GPT-2). Not sure whether the same is required for an encoder-decoder one like T5. Maybe @patrickvonplaten can clarify here.
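My guess at the distinction, as a sketch (model names are only examples, and I may be wrong, hence the ping):

from transformers import GPT2Tokenizer, T5Tokenizer

# Decoder-only (GPT-2): generation continues from the last input position,
# so batched inputs are padded on the left and a pad token must be defined.
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_tokenizer.padding_side = "left"
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

# Encoder-decoder (T5): the decoder starts from its own start token, so the
# encoder inputs can keep the default right padding; T5 already defines <pad>.
t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
print(t5_tokenizer.pad_token, t5_tokenizer.padding_side)  # "<pad>" "right"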

alexcoca (Author) commented Apr 6, 2022

Thanks @NielsRogge, this would be very helpful indeed. I did the following checks this morning:

  1. Run my original, non-batched decoder
  2. Run batched decoding, implemented following the docs

The results are vastly different.

For 1 we get:

{
    "1_00000": {
        "0": {
            "utterance": "hi, could you get me a restaurant booking on the 8th please?",
            "Restaurants_2": {
                "predicted_str": " [states] 10:the 8th [intents] i1 [req_slots] <EOS>"
            }
        },
        "1": {
            "utterance": "could you get me a reservation at p.f. chang's in corte madera at afternoon 12?",
            "Restaurants_2": {
                "predicted_str": " [states] 0:corte madera 2:the 8th 9:p.f. chang's 10:afternoon 12 [intents] i1 [req_slots] <EOS>"
            }
        }
    }
}

For 2 we get:

{
    "1_00000": {
        "0": {
            "utterance": "hi, could you get me a restaurant booking on the 8th please?",
            "Restaurants_2": {
                "predicted_str": " [states] 10:the 8th [intents] i1 [req_slots] i1 [req_slots] i1 [req_slots] i1 [req_slots] i1 [req_slots] i1 [req_slots] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents]  <EOS>"
            }
        },
        "1": {
            "utterance": "could you get me a reservation at p.f. chang's in corte madera at afternoon 12?",
            "Restaurants_2": {
                "predicted_str": " [states] 0:corte madera 2:the 8th 9:p.f. chang's 10:afternoon 12 [intents] i1 [req_slots] i1 [req_slots] i1 [req_slots] i1 [req_slots] i1 [req_slots] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] [intents] 0 [intents] i1 i1 i1 i1 i1 i1 i1 i1 i1 i1 i1 i1 i1 i1 i1 i1 i1 i1 i1 i1 i1 i1 i1 i <EOS>"
            }
        }
    }
}

The results above are obtained with batch_size=5, but setting batch_size=1 gives identical results. Therefore my "batched" implementation is somehow subtly different from the one giving correct results.

I'll try to figure out why this occurs. I should state that I can make the code/checkpoints needed to replicate the above issue available to you for debugging. I'll keep you posted.

Note: The results in 2. are after manually post-processing to remove many <EOS> strings. I can post the "raw" version if it is helpful with debugging!

alexcoca (Author) commented Apr 6, 2022

Ok, I debugged my code and the preliminary test passed. The fix was to change the generate API call from

output_seqs = model.generate(
    input_ids=input_ids.to(DEVICE),
    attention_mask=attention_mask.to(DEVICE),
    max_length=args.decoder_max_seq_len,
    use_cache=True,
)

to

output_seqs = model.generate(
    input_ids=input_ids.to(DEVICE),
    attention_mask=attention_mask.to(DEVICE),
    max_length=args.decoder_max_seq_len,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    use_cache=True,
)

I'm not sure whether this is an omission in the docs or why it is a fix; an explanation would be appreciated! Maybe something noteworthy is that the last id in input_ids is the <EOS> token ID when the batch size is 1 (this also holds for my non-batched implementation). The output tensors then undergo the following postprocessing steps:

output_strings = tokenizer.batch_decode(output_seqs)
output_strings = remove_padding(output_strings, tokenizer.pad_token)

where remove_padding is

def remove_padding(output_strings: list[str], pad_token: str) -> list[str]:
    """Strip every occurrence of the pad token, then re-append a single trailing one."""
    padding_free = []
    for s in output_strings:
        pad_token_start = s.find(pad_token)
        while pad_token_start != -1:
            s = f"{s[:pad_token_start]}{s[pad_token_start + len(pad_token):].lstrip()}"
            pad_token_start = s.find(pad_token)
        padding_free.append(f"{s} {pad_token}")
    return padding_free
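A simpler alternative I could probably use instead of remove_padding, assuming skip_special_tokens covers both the pad token and <EOS> for this tokenizer:

output_strings = tokenizer.batch_decode(output_seqs, skip_special_tokens=True)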

By contrast, the implementation that does not use batching uses the call:

output_seqs = model.generate(
    input_ids.to(DEVICE),
    max_length=args.decoder_max_seq_len,
    do_sample=False,
    temperature=1.0,
    use_cache=True,
    num_beams=1,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    early_stopping=True,
)

and postprocessing is simply

output_strings = tokenizer.decode(output_seqs[0]) 

patrickvonplaten (Contributor) commented

Good point @alexcoca,

@NielsRogge @alexcoca Yeah, this looks like it was a bad copy-paste from GPT-2.

Should be corrected here: #16646

alexcoca (Author) commented Apr 7, 2022

Thanks @patrickvonplaten! So from your PR I understand that it is not necessary to set the tokenizer pad token to <EOS> as we do for GPT-2, and that my fix worked because I passed the correct pad_token_id to generate. I'll revert both changes and expect this to work as well.
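Before reverting I'll double-check that generate can pick these ids up from the model config when they are not passed explicitly (a quick sanity-check sketch; 0 and 1 are the values I expect for the standard T5 checkpoints):

# generate falls back to the model config when these arguments are omitted
print(model.config.pad_token_id, model.config.eos_token_id)  # expect: 0 1
print(tokenizer.pad_token_id, tokenizer.eos_token_id)        # should match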

On a different note, I ran a large-scale test on batched inference. I get 70.20432% accuracy when I decode with batch size 16 and 70.19564% when I run my original code. The difference is too small to matter, but it does show that inference in the two cases behaves slightly differently!

patrickvonplaten (Contributor) commented

It could be related to #14859 (comment), actually.

github-actions bot commented May 7, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
