-
Notifications
You must be signed in to change notification settings - Fork 27.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generate: fix candidate device placement #28493
Conversation
candidate_input_ids, candidate_logits = candidate_generator.get_candidates( | ||
input_ids.to(candidate_generator.assistant_model.device) | ||
) | ||
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in a nutshell, input_ids
should only be moved on some candidate_generator
classes -- the move was delegated to the relevant classes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for taking care of it
* fix candidate device * this line shouldn't have been in
* fix candidate device * this line shouldn't have been in
* fix candidate device * this line shouldn't have been in
* fix candidate device * this line shouldn't have been in
* fix candidate device * this line shouldn't have been in
What does this PR do?
#27775 was merged, and the branch was not synced with #27995 (already on
main
) -- the two branches together result in CI failures. Fortunately, the fix is simple :)