-
Notifications
You must be signed in to change notification settings - Fork 28.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pass device in Logits Processor's init #29804
Pass device in Logits Processor's init #29804
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall notes before going to details:
- In the processors that take
eos_token_id
as input: see Generate: consistently handle special tokens as tensors #29788. In this PR, the special tokens are treated as tensors by default, solving most of the needed changes. I would rebase this PR onmain
after that PR is merged, as some of the changes here will become redundant :) - On the processors that don't need to use
device
, such asTemperatureLogitsWarper
-- let's not add unused arguments. Clean interfaces are important 🧼 (unless there are significant benefits from standardizing them) - Let's not throw a warning when the device is not passed and tensors are initialized on CPU. A
.to
operation is not that expensive :)
|
Not stale |
This PR now can be reviewed. Rebased main and updated the changes. All the tests from |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you for improving generate
:D
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
@gante Ah I forgot whisper is encoder-decoder. Oke, now it infers device from one of the inputs passed by the user. |
How could the bot come 🤣 anyways on it! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall LGTM, not sure input_ids device is the always the best, and we need a small test to see which feature is enable by this potentially!
if device is None: | ||
device = "cpu" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd argue that we can just set it to "cpu" in the arg no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is mostly for users who use/pass LogitsProcessor as a standalone kwarg, because 'generate()' takes care that device is not None.
I think we should raise warning for BC saying users to pass-in the device, but let's ask @gante if he's okay with it. If I am not misunderstanding, we shouldn't raise warnings 🤔
Let's not throw a warning when the device is not passed and tensors are initialized on CPU. A .to operation is not that expensive :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, don't think it's a problem to silently do this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Down to just default to CPU which was already the behaviour by default before this PR no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ahh my bad, didn't read carefully the first comment. Setting in the arg as default is better, right
My concern is that before this PR, we were placing these on scores.device
during "_ [call]_ " , but anyway I still get lost at when to do BC deprecation and when to not do 😄
self.eos_token_id = self.eos_token_id.to(scores.device) |
src/transformers/generation/utils.py
Outdated
@@ -1700,7 +1737,7 @@ def generate( | |||
encoder_input_ids=inputs_tensor, | |||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, | |||
logits_processor=logits_processor, | |||
device=inputs_tensor.device, | |||
device=input_ids.device, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is this required ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right! I thought that it was me who changed to inputs_tensor
and was trying to revert 😆 I'll revert it back, no difference whichever tensor we use here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should be use self.device? or lm_head.device? (which is not always there but still)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we need to make sure dive placement on multi GPU works, might already be tested !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
got it. Any how LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you rebase your branch ? (format changes seems unrelated?)
Oke, rebased main and the unnecessary formatting is removed. Will merge as I guess we don't need to add warnings :) |
What does this PR do?
This PR adds the ability to pass in device when initializing
LogitsProcessors
and is one more step towardscompile
compatibility.Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@gante