Generate: fix generation with inputs_embeds when input_ids=None for llama and gemma #29821

Closed
njhill wants to merge 2 commits

Conversation

njhill (Contributor) commented Mar 23, 2024

The changes in #29467 break generation with inputs_embeds when input_ids is None since they expect input_ids to be non-None even for the prefill forward without past_key_values.
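For illustration, a minimal sketch of the failing call pattern (the checkpoint name and prompt are placeholders, not taken from this PR):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Build embeddings directly so that generate() receives no input_ids.
    input_ids = tokenizer("Hello, world", return_tensors="pt").input_ids
    inputs_embeds = model.get_input_embeddings()(input_ids)

    # input_ids is deliberately omitted (None inside generate); after #29467,
    # prepare_inputs_for_generation expects input_ids to be non-None on the
    # prefill forward (no past_key_values yet), so this call breaks.
    outputs = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=10)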

@gante

gante (Member) commented Mar 27, 2024

Hi @njhill 👋

Can you share an example of failure? We have a test for generation with inputs_embeds (which is passing on e.g. Llama), so our test suite is likely incomplete :)

njhill (Contributor, Author) commented Mar 27, 2024

Thanks @gante, the failure happens specifically when using inputs_embeds with input_ids passed as None. I think both are passed to generate() in the current test. I'll update the PR title to clarify this.
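For concreteness, the two call patterns being distinguished here might look like this (a hedged sketch, reusing the inputs_embeds setup from the example above):

    # Pattern 1: both input_ids and inputs_embeds are passed to generate().
    model.generate(input_ids, inputs_embeds=inputs_embeds, max_new_tokens=10)

    # Pattern 2: only inputs_embeds is passed (input_ids is None) -- the
    # case reported as failing here.
    model.generate(inputs_embeds=inputs_embeds, max_new_tokens=10)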

njhill changed the title from "Generate: fix generation with inputs_embeds for llama and gemma" to "Generate: fix generation with inputs_embeds when input_ids=None for llama and gemma" on Mar 27, 2024
gante (Member) commented Mar 28, 2024

@njhill uhmmm the test checks that combination as well 🤔

    # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
    outputs_from_embeds_wo_ids = model.generate(
        inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1]
    )
    self.assertListEqual(
        outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
        outputs_from_embeds_wo_ids.tolist(),
    )

(you can run the test with py.test tests/models/llama/test_modeling_llama.py -k test_generate_from_inputs_embeds_decoder_only)

This means that there is probably something else going on, which could be interesting to pin down :)

njhill (Contributor, Author) commented Mar 28, 2024

@gante ah, apologies for not looking at those closely enough, and thank you for the tip on how to run the test. Let me dig deeper to see what's going on here.

github-actions bot commented

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot closed this Apr 30, 2024