Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate: assisted_decoding now accepts arbitrary candidate generators #27750

Merged
merged 6 commits into from
Dec 12, 2023

Conversation

gante
Copy link
Member

@gante gante commented Nov 28, 2023

What does this PR do?

A common trend is starting to pop up: people are experimenting with new strategies to generate candidate sequences, to then run an assisted-generation-like strategy. A key example is the new technique in #27722, which is equal to assisted_decoding except for the candidate generation part. This technique in particular achieves nice speedups in some settings, and doesn't need an assistant model -- a model-free speedup!

To facilitate experimentation and the addition of new candidate generation techniques, this PR abstracts the candidate generation part in assisted_decoding to a new class with a stable API. This was inspired in classes like LogitsProcessor or StoppingCriteria -- components of generate that can easily be replaced. All these changes are backwards compatible! 🤗

Suggested review order:

  1. utils.py, to see the shape of assisted_decoding under the abstracted API
  2. candidate.py, to see the structure of the new base class (and the specific case of the original assisted generation)

The following tests are passing:

  1. RUN_SLOW=1 py.test tests/models/whisper/ -k speculative
  2. py.test tests/ -k test_assisted (which catches mixin and integration tests associated with assisted generation)

Happy to add more tests if needed :)

self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)


def _crop_past_key_values(model, past_key_values, maximum_length):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: these functions were moved here to avoid circular imports

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be good to use the new cache format no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

soon! (we still need to maintain retrocompatibility)

@@ -888,6 +895,29 @@ def _reorder_cache(self, past_key_values, beam_idx):
f" enable beam search for {self.__class__}"
)

def _get_candidate_generator(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function will be expanded as we add more CandidateGenerator :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so will the logic here be something like

  • check params in generation_config (some if else condition)
  • based on params, set candidate_generator

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@apoorvumang exactly!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it, so I'll write a PromptLookupCandidateGenerator that implements CandidateGenerator, and then wire it up in this function

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the plan is to have similar checks to the ones we have for the supported logits processor I guess?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ArthurZucker precisely, we will have flags to control which candidate generation strategies we have in place. I suspect that, because some candidate generation strategies are so cheap (like the one proposed in #27722), assisted generation may become mainstream!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@apoorvumang
Copy link
Contributor

One more place needs to change I think - generation_mode is currently set using _get_generation_mode , where this is the logic:

if assistant_model is not None:
            if generation_mode in ("greedy_search", "sample"):
                generation_mode = GenerationMode.ASSISTED_GENERATION
            else:
                raise ValueError(
                    "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
                    "is only supported with Greedy Search and Sample."
                )

how do u suggest this should change to support prompt lookup decoding? @gante

@gante
Copy link
Member Author

gante commented Nov 30, 2023

@apoorvumang I'd add an or after if assistant_model is not None

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔥 would maybe rename the file with candiate_generators like we have logits_processor but otherwise great!

@@ -888,6 +895,29 @@ def _reorder_cache(self, past_key_values, beam_idx):
f" enable beam search for {self.__class__}"
)

def _get_candidate_generator(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the plan is to have similar checks to the ones we have for the supported logits processor I guess?

src/transformers/generation/utils.py Outdated Show resolved Hide resolved
src/transformers/generation/utils.py Show resolved Hide resolved
src/transformers/generation/utils.py Show resolved Hide resolved
src/transformers/generation/utils.py Show resolved Hide resolved
src/transformers/generation/candidates.py Outdated Show resolved Hide resolved
self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)


def _crop_past_key_values(model, past_key_values, maximum_length):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be good to use the new cache format no?

@gante gante force-pushed the arbitrary_candidate_fn branch from 77a8b67 to a57367b Compare December 11, 2023 19:57
@gante gante merged commit 4b759da into huggingface:main Dec 12, 2023
22 checks passed
@gante gante deleted the arbitrary_candidate_fn branch December 12, 2023 09:26
iantbutler01 pushed a commit to BismuthCloud/transformers that referenced this pull request Dec 16, 2023
…ors (huggingface#27750)

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
staghado pushed a commit to staghado/transformers that referenced this pull request Jan 15, 2024
…ors (huggingface#27750)

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants