Generate: add DeepMind's Speculative sampling in assisted_generation #27270

Conversation

@domgri commented Nov 3, 2023

What does this PR do?

Implements #27186. Still a draft, work in progress.

The implementation is inspired by the original paper and these [1][2] existing implementations.
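For context, a minimal sketch of the core accept/reject rule from the paper; the function name and tensor shapes here are illustrative, not this PR's actual code, and `p`/`q` stand for the target-model and draft-model probabilities over the candidate tokens:

```python
import torch

def speculative_accept(candidate_tokens, p, q):
    """Accept/reject draft tokens per the speculative sampling rule.

    candidate_tokens: (T,) tokens proposed by the draft/assistant model
    p: (T + 1, vocab) target-model probabilities at the candidate positions
       (one extra row for the position after the last draft token)
    q: (T, vocab) draft-model probabilities at the same positions
    Returns the accepted tokens plus one token sampled from the target model.
    """
    out = []
    for t, tok in enumerate(candidate_tokens.tolist()):
        # Accept token t with probability min(1, p(tok) / q(tok)).
        if torch.rand(()) < torch.clamp(p[t, tok] / q[t, tok], max=1.0):
            out.append(tok)
        else:
            # On rejection, resample from the residual max(0, p - q), renormalized;
            # this keeps the overall output distribution equal to the target model's.
            residual = torch.clamp(p[t] - q[t], min=0.0)
            out.append(torch.multinomial(residual / residual.sum(), 1).item())
            return out
    # Every draft token was accepted: take one extra token from the target model
    # at the position after the last accepted one.
    out.append(torch.multinomial(p[len(out)], 1).item())
    return out
```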

Next steps:

  • Resolve the TODOs raised in the code for speculative decoding
  • Verify the implementation
  • Then address the possible changes listed below and fix them

Possible changes from this PR:

  • modifies the implementation of assisted_generation with do_sample=True
  • may affect this blog post
  • may affect the assisted_generation documentation


@amyeroberts (Collaborator)
cc @gante

@gante (Member) commented Nov 7, 2023

Hey @domgri 👋

Thank you for opening the PR! Let me know when you'd like a review 💪

@domgri (Author) commented Nov 12, 2023

Sure, absolutely. Sorry for not responding sooner; some unexpected workload came up. I hope to come back and finish the implementation in a week 🤞

@domgri (Author) commented Nov 16, 2023

Hey, so I gave it a couple of tries to finish the implementation, although with little to no success 😕.

A couple of takeaways that might be useful for anyone continuing or trying to work on this implementation:

  • The initial PR might be useful for the overall vision of how this feature could be implemented (the TODOs mark potential places for modifications).
  • The sampling cases (1, 2) could be improved with something more sophisticated (and possibly already existing functionality).
  • From the second iteration on, the main model's model_inputs.input_ids would not match up with candidate_input_ids (it would usually be shorter, containing only the last several tokens of candidate_input_ids). I suspect the cache and/or **candidate_kwargs had an effect on that, though I could not figure out exactly how and what; see the first sketch after this list.
    model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
  • tmp_result = tmp_max / tmp_max_sum was returning an array of NaNs instead of 0s, possibly due to a faulty max_fn implementation; see the second sketch after this list.
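For reference, the truncation described in the third item is expected behavior once a cache is in play: in transformers, prepare_inputs_for_generation typically slices input_ids down to the tokens not yet covered by past_key_values. A simplified sketch of that common pattern (illustrative, not copied from any particular model; the cache shape assumed here is the usual `(batch, heads, seq, head_dim)`):

```python
def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
    # With a cache, only tokens the model has not seen yet need to be fed in;
    # earlier positions are already represented inside past_key_values.
    if past_key_values is not None:
        past_length = past_key_values[0][0].shape[2]  # cached sequence length
        input_ids = input_ids[:, past_length:]
    return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
```

So candidate_input_ids holding the full sequence while model_inputs.input_ids holds only the last few tokens is consistent with this slicing rather than a bug on its own.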
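On the NaN issue in the last item: the linked reference implementations define max_fn(x) as max(0, x) normalized by its sum, and that division yields NaN whenever the sum is zero (0/0, e.g. when the draft and target distributions coincide). A guarded version, as a sketch assuming the same tensor shapes as those implementations:

```python
import torch

def max_fn(x, eps=1e-9):
    """max(0, x), renormalized to a distribution along the last dim.

    Guards against the 0/0 case that yields NaNs when every entry of
    max(0, x) is zero; such rows stay all-zero instead of becoming NaN.
    """
    x_max = torch.clamp(x, min=0.0)
    x_max_sum = x_max.sum(dim=-1, keepdim=True)
    return torch.where(
        x_max_sum > 0,
        x_max / torch.clamp(x_max_sum, min=eps),
        torch.zeros_like(x_max),
    )
```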

I will close this PR since I am out of capacity to continue working on it right now. Feel free to use this PR as inspiration for the actual implementation. Thanks for the enthusiastic welcome @amyeroberts @gante; my apologies for not delivering much value, and I hope to see someone else step up and contribute more meaningfully 😊.

@domgri closed this Nov 16, 2023
@gante (Member) commented Nov 17, 2023

@domgri no worries! Thank you for giving it a shot 🤗
