fix AutoModel.from_pretrained(..., torch_dtype=...) #13209
Conversation
Note, I first tried a simple monkeypatching method, but it doesn't work with C extensions.
    if is_torch_available():
        import torch

        self.torch_dtype = getattr(torch, self.torch_dtype)
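For illustration, a minimal sketch (not the PR's exact code) of what this conversion does, assuming the config serializes the dtype as a short string:

    import torch

    # config.json stores the dtype as a short string such as "float16";
    # getattr on the torch module turns it back into a torch.dtype object.
    torch_dtype = "float16"
    resolved = getattr(torch, torch_dtype)
    assert resolved is torch.float16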
OK, so torch is not always available when `config.torch_dtype` is not None - so now `config.torch_dtype` isn't always of `torch.dtype`.

I'm thinking perhaps this whole approach needs to be redone and only use the "float32", "float16", etc. strings everywhere, including the `torch_dtype` arg in `from_pretrained` and `from_config`, and only convert to `torch.dtype` at the point it's used when the model is loaded.

That way `torch_dtype` doesn't need special handling at the config level.
I hope this is recent/experimental enough that it's ok that we break the API.
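A rough sketch of that proposal, using a hypothetical `resolve_dtype` helper: keep the value as a plain string everywhere and only map it to a `torch.dtype` at the point where the weights are loaded.

    import torch

    # Hypothetical helper for the proposal above: the config and the
    # from_pretrained()/from_config() args hold only strings like "float16"
    # or "auto"; the torch.dtype is resolved just before loading the weights.
    def resolve_dtype(dtype_str):
        if dtype_str is None or dtype_str == "auto":
            # "auto" is resolved later from the checkpoint's weights
            return dtype_str
        return getattr(torch, dtype_str)

    print(resolve_dtype("float16"))  # torch.float16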
Actually, if we do that, why even bother with the `torch_` in `torch_dtype` and not just rename it to `dtype` - perhaps non-pt frameworks could tap into it as well? After all, fp16-saved data by torch isn't any different from flax or tf, no?
I can't think of a scenario where one would want one dtype for one framework and another dtype for another - so changing it to `dtype` sounds good to me.
So the main concern is back-compat for the API arg `torch_dtype` - if it's OK to break it, then I propose both the config and the arg in `from_pretrained` and `from_config` to be just `dtype` as a string: "auto", "float32", "float16", etc. Then in the case of torch we convert it to the right `torch.dtype` on the fly. Perhaps flax/tf could use this too down the road.

Sylvain is not here for another week. Do you both support this breaking API change, @LysandreJik and @patrickvonplaten?
So instead of:

    torch_dtype (:obj:`str` or :obj:`torch.dtype`, `optional`):
        Override the default ``torch.dtype`` and load the model under this dtype. If ``"auto"`` is passed the
        dtype will be automatically derived from the model's weights.

It will be:

    dtype (:obj:`str`, `optional`):
        Override the default ``dtype`` and load the model under this dtype (e.g., ``"float16"``). If ``"auto"``
        is passed the dtype will be automatically derived from the model's weights.
On the other hand, we already have the `dtype` attribute in modeling_utils, which returns a `torch.dtype`, so my proposal might be confusing. Should I call it `dtype_str` perhaps?
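For reference, a quick illustration of the existing attribute being referred to (bert-base-uncased is just an example checkpoint): `model.dtype` already returns a `torch.dtype` derived from the parameters, so a string-valued `dtype` in the config would be a different kind of value behind a similar name.

    import torch
    from transformers import AutoModel

    # The existing model-level attribute returns an actual torch.dtype,
    # whereas the proposed config field would hold a string like "float16".
    model = AutoModel.from_pretrained("bert-base-uncased")
    print(model.dtype)        # torch.float32
    print(type(model.dtype))  # <class 'torch.dtype'>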
If no version with `torch_dtype` has been released yet, I'm fine with changing it to `dtype`. However, note that in Flax we already have a `dtype` variable that is used to define the dtype the matmul operations are run in, rather than the dtype of the actual weights. In Flax we would like to take this design: #13098, as outlined by @patil-suraj.
@stas00 regarding PR #13098 - the idea of that PR is exactly to disentangle parameter dtype from matmul/computation dtype. In Flax, it's common practice that the `dtype` parameter defines the matmul/computation dtype (see: https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.Dense.html#flax.linen.Dense.dtype), not the parameter dtype.

So for Flax, I don't really think it would make sense to use a `config.dtype` to define the weights dtype, as it would be quite confusing with Flax's computation dtype parameter.
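A minimal sketch of the Flax convention described here, assuming a recent flax/jax install: the module-level `dtype` controls the computation dtype, while the parameters themselves are still created in float32 by the default initializers.

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    # `dtype` on the module sets the matmul/computation dtype, not the
    # dtype of the stored parameters.
    layer = nn.Dense(features=4, dtype=jnp.bfloat16)
    params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
    print(params["params"]["kernel"].dtype)  # float32 - parameter dtype
    out = layer.apply(params, jnp.ones((1, 8)))
    print(out.dtype)                         # bfloat16 - computation dtype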
I would feel more comfortable merging this PR without breaking changes and keeping the parameter called `torch_dtype`.

Works for me. Although this API is somewhat awkward at the moment due to the inconsistent type of values in the config file and the function - the former a string, the latter a `torch.dtype`. Perhaps I can change these to support both the string "float32" and `torch.dtype` in the same param of the function.
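A hypothetical sketch of that normalization: accept either the string stored in the config or an actual `torch.dtype` in the same argument and canonicalize it before use.

    import torch

    # Hypothetical helper: accept "float16", torch.float16, "auto" or None
    # in the same torch_dtype argument and return a canonical value.
    def normalize_torch_dtype(torch_dtype):
        if torch_dtype is None or torch_dtype == "auto" or isinstance(torch_dtype, torch.dtype):
            return torch_dtype
        if isinstance(torch_dtype, str):
            return getattr(torch, torch_dtype)
        raise ValueError(f"Unsupported torch_dtype value: {torch_dtype!r}")

    assert normalize_torch_dtype("float16") is torch.float16
    assert normalize_torch_dtype(torch.float16) is torch.float16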
I think it's worth having a separate discussion here regarding a framework-agnostic `dtype` parameter for all of PyTorch, TensorFlow, and Flax once @sgugger is back.
Agreed!
> In Flax, it's common practice that the `dtype` parameter defines the matmul/computation dtype (see: https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.Dense.html#flax.linen.Dense.dtype), not the parameter dtype.

So it's somewhat similar to the `dtype` arg in the new pytorch `autocast` feature then, correct? (Before, it was a hardcoded fp16, but now it has a `dtype` arg to support bf16 too.)

p.s. it's currently called `fast_dtype` but will be renamed shortly to `dtype` in pt-nightly.
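For comparison, a small sketch using the autocast API as it exists in current PyTorch releases (where the argument is already named `dtype`):

    import torch

    # autocast takes a dtype argument, so the computation dtype can be
    # bfloat16 instead of the previously hardcoded fp16.
    model = torch.nn.Linear(8, 8)
    x = torch.randn(2, 8)
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        y = model(x)
    print(y.dtype)  # torch.bfloat16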
Sounds good to me to keep `torch_dtype` and have a separate discussion for the framework-agnostic `dtype` parameter, to which `torch_dtype` could be an alias to prevent breaking changes to the existing API.
Sounds like a plan, @LysandreJik
Issue created: #13246
This looks good to me! Would actually be nice to have this merged soon to allow GPT-J to be loaded with `AutoModelForCausalLM`: #13022
LGTM, thanks for working on it @stas00
This PR fixes one of the 2 issues reported in #13076.

Additionally, it corrects the config object to convert the short "float32" string into `torch.float32` at object creation time.

Note, I had to change `from_dict` a bit to preserve the `torch_dtype` arg in `AutoModel.from_pretrained(..., torch_dtype=...)`, as without this change `from_pretrained` was ignoring this argument.

To remind, the issue is that we decided to store `torch_dtype` in the config object, but ignore it for now at load time - which this PR also documents.

Of course, tests added.
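For reference, a sketch of the call this PR fixes (bert-base-uncased is just an example checkpoint):

    import torch
    from transformers import AutoModel

    # torch_dtype is now passed through the Auto* classes instead of being
    # silently dropped, so the weights are loaded in the requested dtype.
    model = AutoModel.from_pretrained("bert-base-uncased", torch_dtype=torch.float16)
    print(model.dtype)  # torch.float16

    # "auto" derives the dtype from the checkpoint's saved weights instead.
    model = AutoModel.from_pretrained("bert-base-uncased", torch_dtype="auto")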
Thank you.
Fixes: #13076
(note: 2 separate issues were reported there but it looks like only this is the real issue, so linking to close it with this PR)
@sgugger, @LysandreJik