Respect dtype of the model when instantiating not working #13076
Comments
Thank you for the great report, @hwijeen. I'm able to reproduce both problems:
Once I get a chance I will work on it and we will sort it out.
Ok, where did you find that? This works just fine:
Oh, I see. But setting it to "auto" does not seem to work. Just for a sanity check, I tried loading my own model whose weights are float16, and the result was the same. It seems that
Why do you think it's float16? The auto-detector checks the first entry:
but we can look at all of them:
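The snippet itself wasn't preserved in this copy of the thread; a check along these lines (the file name is a placeholder) would show the dtype of every entry in a saved checkpoint rather than just the first one:

```python
import torch

# Sketch only: inspect every tensor in a saved state_dict, not just the first
# entry that the "auto" detector looks at. The file name is a placeholder.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
print({k: v.dtype for k, v in state_dict.items() if torch.is_tensor(v)})
# or just the set of dtypes present:
print({v.dtype for v in state_dict.values() if torch.is_tensor(v)})
```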
Also, I think Sam was making many test models. You can see the test that saves as fp16 and then auto-detects it to be fp16: transformers/tests/test_modeling_common.py, lines 1687 to 1692 at c89180a.
I was not sure whether
I tried to load this model in two ways, and only one yields the correct result:

```python
# loads correctly with the specific model class
GPT2LMHeadModel.from_pretrained(".", torch_dtype="auto").dtype
# torch.float16

# but fails with AutoModelForCausalLM
AutoModelForCausalLM.from_pretrained(".", torch_dtype="auto").dtype
# torch.float32
```

The test cases you linked seem to be using specific model classes, so perhaps this is AutoModel's fault?
Yes, clearly
The question is what to do with models that have mixed dtypes - typically a model is either fp16 or fp32. I can see how a custom buffer may be fp32 while the params are in fp16. Could you explain your situation and how mixed your model is?
I am using Megatron-LM by Nvidia. As you may know, this code trains billion-scale language models using various parallelism techniques. One thing to note is that this library does not rely on apex amp to achieve mixed precision training; it has a complicated, self-contained code path to deal with fp16 -- so I would say that models with various data types are not a usual case and are not a high priority. But the
Which of the AutoModel problems are you referring to? If it's the pickle issue, then one needs some kind of
By all means, if you can solve it, it'd be super helpful.
If it's the auto-detection failing because it checks the first key entry, then before solving it, as suggested, we need to discuss what to do if the model has mixed dtypes. I suppose with just fp16/fp32 it obviously should be auto=fp32, but now we are going to have other types like bf16, so hardcoding is going to be an issue.
I'm going to be offline for the next 3 days and can follow up next on Friday.
Running on the official checkpoint:
prints: so there are only fp32 keys in that official checkpoint. But that's just that checkpoint. Which keys do you get when you run the quick check from above (the last line of code with
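The code being referenced wasn't preserved here; a check of this kind, going through transformers, might have looked roughly like the sketch below (the model identifier is a placeholder). As the next comment points out, such a check is misleading, because from_pretrained instantiates the model in fp32 by default:

```python
from transformers import GPT2LMHeadModel

# Sketch only: checking dtypes on a model loaded through transformers.
# Without torch_dtype, from_pretrained builds the model in fp32, so this
# reports fp32 keys regardless of what the checkpoint actually stores.
model = GPT2LMHeadModel.from_pretrained("gpt2")  # placeholder model id
print({v.dtype for v in model.state_dict().values()})
```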
Ah, of course, the above test is wrong, because it relies on transformers, which by default loads in fp32; it needs to be redone based on the checkpoint itself. Here you go:
so it's all fp16, not mixed. But again, this is just this checkpoint.
When I opened the official checkpoint with
In my case, I get a mixture of
As you pointed out, dealing with mixed data types is complicated and needs further discussion. On the other hand, I think
Thanks for the quick workaround!
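The workaround itself wasn't preserved in this extraction; one option consistent with the earlier comments is to load through the specific model class, where torch_dtype="auto" was reported to behave correctly:

```python
from transformers import GPT2LMHeadModel

# Sketch of a possible workaround based on the earlier comments: use the
# concrete model class rather than the Auto class, where "auto" detection
# was reported to work. The "." path is a placeholder for a local checkpoint.
model = GPT2LMHeadModel.from_pretrained(".", torch_dtype="auto")
print(model.dtype)  # reported as torch.float16 earlier in the thread
```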
Right, so my 2nd attempt was potentially wrong too, since the original checkpoint went through a conversion, and I guess it could have ignored the original dtypes and made it all fp16. However, doing it the right way this time, by inspecting the original checkpoint and building on your code:
is still fp16 (for this checkpoint). Perhaps when the model is mixed,
The problem is not
Perhaps let's open a new issue that focuses just on this separate problem, and please tag me, sgugger, and LysandreJik on it. Thank you! You can use the above one-liner to show us the mixed keys your model contains, and then it'd be easier to understand what's going on.
Oh, I double-checked and confirmed that the weights in the Megatron-LM checkpoint are all in fp16. It was the conversion script that made the checkpoint have mixed data types. Specifically, this line produces
So at least in my case, my model does not have mixed data types -- are there any cases where data types are mixed? If not, I think a new issue is not necessary?
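The specific line is not quoted here, but the general failure mode is easy to illustrate: any tensor a conversion script creates without an explicit dtype defaults to float32, so a single added entry can turn an otherwise fp16 checkpoint into a mixed one. A generic sketch (not the actual conversion-script code):

```python
import torch

# Generic illustration, not the actual line from the Megatron-LM conversion
# script: a freshly created tensor defaults to float32, so adding it to an
# otherwise fp16 state_dict yields mixed dtypes.
weights_fp16 = torch.randn(4, 4, dtype=torch.float16)
extra = torch.zeros(4)  # float32 by default
state_dict = {"weight": weights_fp16, "extra": extra}
print({k: v.dtype for k, v in state_dict.items()})
# {'weight': torch.float16, 'extra': torch.float32}
```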
I asked the same question when working on the original feature, and those who followed up said they didn't think they had seen such cases. I can only think of a registered buffer, which can be of whatever dtype and differ from the weights. That said, perhaps down the road we should check that all the weights indeed have the same dtype, so we don't accidentally set a dtype that is not like the rest. But let's worry about it if it becomes a problem.
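A minimal sketch of the kind of safeguard described above, assuming the check would live wherever the checkpoint's dtype is auto-detected (the helper name and file path are hypothetical):

```python
import torch

def detect_checkpoint_dtype(path):
    """Hypothetical helper: detect a checkpoint's dtype, warning on mixes."""
    state_dict = torch.load(path, map_location="cpu")
    dtypes = {v.dtype for v in state_dict.values() if torch.is_tensor(v)}
    if len(dtypes) > 1:
        print(f"warning: checkpoint contains mixed dtypes: {dtypes}")
    # return one of the detected dtypes (arbitrary if the checkpoint is mixed)
    return next(iter(dtypes))

print(detect_checkpoint_dtype("pytorch_model.bin"))  # path is a placeholder
```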
Environment info
transformers version: 4.9.2
Who can help
@stas00, as he is the author of #12316
Information
Model I am using (Bert, XLNet ...):
The problem arises when using:
The tasks I am working on are:
To reproduce
First case:
The above code results in
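The reproduction snippet was not preserved in this extraction; based on the comments above, the first case presumably looked something like this (the local path is illustrative):

```python
from transformers import AutoModelForCausalLM

# Hedged reconstruction from the comments above, not the original snippet:
# an fp16 checkpoint loaded through an Auto class with torch_dtype="auto".
model = AutoModelForCausalLM.from_pretrained(".", torch_dtype="auto")
print(model.dtype)  # reported as torch.float32, although the weights are fp16
```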
Second case:
yields the following error.
Expected behavior
First case
Regarding the first case, setting torch_dtype should work with AutoModel as well as with specific model classes.
Can this be fixed?
It would be convenient for me if we could use a "torch_dtype" key-value pair in config.json, which is not supported in the current version.
Second case
Shouldn't the second case run without any errors?