respect dtype of the model when instantiating not working #13076

Closed
hwijeen opened this issue Aug 11, 2021 · 16 comments · Fixed by #13209
@hwijeen (Contributor) commented Aug 11, 2021

Environment info

  • transformers version: 4.9.2
  • Platform: Linux-4.18.0-25-generic-x86_64-with-glibc2.10
  • Python version: 3.8.5
  • PyTorch version (GPU?): 1.8.0a0+52ea372 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?: No

Who can help

@stas00, as he is the author of #12316

Information

Model I am using (Bert, XLNet ...):

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQuAD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

First case:

import torch
from transformers import AutoModel

AutoModel.from_pretrained("my_path", torch_dtype=torch.float16)

The above code results in

/opt/conda/envs/ml/lib/python3.7/site-packages/transformers/models/auto/auto_factory.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    377         if not isinstance(config, PretrainedConfig):
    378             config, kwargs = AutoConfig.from_pretrained(
--> 379                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
    380             )
    381

/opt/conda/envs/ml/lib/python3.7/site-packages/transformers/models/auto/configuration_auto.py in from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
    451         if "model_type" in config_dict:
    452             config_class = CONFIG_MAPPING[config_dict["model_type"]]
--> 453             return config_class.from_dict(config_dict, **kwargs)
    454         else:
    455             # Fallback: use pattern matching on the string.

/opt/conda/envs/ml/lib/python3.7/site-packages/transformers/configuration_utils.py in from_dict(cls, config_dict, **kwargs)
    579             kwargs.pop(key, None)
    580
--> 581         logger.info(f"Model config {config}")
    582         if return_unused_kwargs:
    583             return config, kwargs

/opt/conda/envs/ml/lib/python3.7/site-packages/transformers/configuration_utils.py in __repr__(self)
    611
    612     def __repr__(self):
--> 613         return f"{self.__class__.__name__} {self.to_json_string()}"
    614
    615     def to_diff_dict(self) -> Dict[str, Any]:

/opt/conda/envs/ml/lib/python3.7/site-packages/transformers/configuration_utils.py in to_json_string(self, use_diff)
    675         else:
    676             config_dict = self.to_dict()
--> 677         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
    678
    679     def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):

/opt/conda/envs/ml/lib/python3.7/json/__init__.py in dumps(obj, skipkeys, ensure_ascii, check_circular, allow_nan, cls, indent, separators, default, sort_keys, **kw)
    236         check_circular=check_circular, allow_nan=allow_nan, indent=indent,
    237         separators=separators, default=default, sort_keys=sort_keys,
--> 238         **kw).encode(obj)
    239
    240

/opt/conda/envs/ml/lib/python3.7/json/encoder.py in encode(self, o)
    199         chunks = self.iterencode(o, _one_shot=True)
    200         if not isinstance(chunks, (list, tuple)):
--> 201             chunks = list(chunks)
    202         return ''.join(chunks)
    203

/opt/conda/envs/ml/lib/python3.7/json/encoder.py in _iterencode(o, _current_indent_level)
    429             yield from _iterencode_list(o, _current_indent_level)
    430         elif isinstance(o, dict):
--> 431             yield from _iterencode_dict(o, _current_indent_level)
    432         else:
    433             if markers is not None:

/opt/conda/envs/ml/lib/python3.7/json/encoder.py in _iterencode_dict(dct, _current_indent_level)
    403                 else:
    404                     chunks = _iterencode(value, _current_indent_level)
--> 405                 yield from chunks
    406         if newline_indent is not None:
    407             _current_indent_level -= 1

/opt/conda/envs/ml/lib/python3.7/json/encoder.py in _iterencode(o, _current_indent_level)
    436                     raise ValueError("Circular reference detected")
    437                 markers[markerid] = o
--> 438             o = _default(o)
    439             yield from _iterencode(o, _current_indent_level)
    440             if markers is not None:

/opt/conda/envs/ml/lib/python3.7/json/encoder.py in default(self, o)
    177
    178         """
--> 179         raise TypeError(f'Object of type {o.__class__.__name__} '
    180                         f'is not JSON serializable')
    181

TypeError: Object of type dtype is not JSON serializable
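For reference, the failure is reproducible without transformers at all; a minimal sketch using only the standard library and torch:

import json
import torch

# json has no encoder for torch.dtype, so a config dict that still holds the
# raw dtype object fails exactly like the traceback above.
try:
    json.dumps({"torch_dtype": torch.float16}, indent=2, sort_keys=True)
except TypeError as e:
    print(e)  # Object of type dtype is not JSON serializable

# Serializing a string form instead works fine.
print(json.dumps({"torch_dtype": str(torch.float16)}))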

Second case:

from transformers import GPT2LMHeadModel

m = GPT2LMHeadModel.from_pretrained(model_path, torch_dtype_auto_detect=True)

yields the following error.

/opt/conda/envs/ml/lib/python3.7/site-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
   1319         else:
   1320             with no_init_weights(_enable=_fast_init):
-> 1321                 model = cls(config, *model_args, **model_kwargs)
   1322
   1323         if from_pt:

TypeError: __init__() got an unexpected keyword argument 'torch_dtype_auto_detect'

Expected behavior

First case
Regarding the first case, setting torch_dtype should work with AutoModel just as it does with specific model classes.
Can this be fixed?
It would also be convenient for me if we could use a "torch_dtype" key-value pair in config.json, which is not supported in the current version.
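For illustration, the behavior being asked for might look like this (a sketch only; "my_path" is the same placeholder as above, and the config.json entry is the proposed addition, not something the current version honors through AutoModel):

# proposed: config.json saved next to the weights carries the dtype, e.g.
#   { ..., "torch_dtype": "float16" }
import torch
from transformers import AutoModel

# torch_dtype="auto" would then pick the dtype up from the config/weights,
# for AutoModel just as for the concrete model classes:
model = AutoModel.from_pretrained("my_path", torch_dtype="auto")
assert model.dtype == torch.float16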

Second case
Shouldn't the second case run without any errors?

@hwijeen changed the title from "respect dtype of the model when instantiating with AutoModel" to "respect dtype of the model when instantiating not working" on Aug 11, 2021
@stas00 (Contributor) commented Aug 11, 2021

Thank you for the great report, @hwijeen

I'm able to reproduce both problems:

python -c "from transformers import GPT2LMHeadModel; GPT2LMHeadModel.from_pretrained('sshleifer/tiny-gpt2', torch_dtype_auto_detect=True)"          
python -c "import torch; from transformers import AutoModel; AutoModel.from_pretrained('sshleifer/tiny-gpt2', torch_dtype=torch.float16)"

Once I get a chance I will work on it and we will sort it out.

stas00 self-assigned this on Aug 11, 2021
@stas00 (Contributor) commented Aug 11, 2021

ok, where did you find torch_dtype_auto_detect? The documented syntax is: torch_dtype='auto' for auto detection. Perhaps you were looking at the original proposal discussion before the API was selected?

This works just fine:

python -c "from transformers import AutoModel; AutoModel.from_pretrained('sshleifer/tiny-gpt2', torch_dtype='auto')"

@hwijeen (Contributor, Author) commented Aug 12, 2021

Oh, I see. torch_dtype is the right keyword.

But setting it to "auto" does not seem to work:
python -c "from transformers import AutoModel; m=AutoModel.from_pretrained('sshleifer/tiny-gpt2', torch_dtype='auto');print(m.dtype)" # This gives torch.float32.

Just for a sanity check, I tried loading my own model whose weights are float16, and the result was the same.
python -c "from transformers import AutoModel; m=AutoModel.from_pretrained(my_path, torch_dtype='auto');print(m.dtype)" # This gives torch.float32!

It seems that torch_dtype='auto' is not working as expected?

@stas00 (Contributor) commented Aug 12, 2021

why do you think it's float16?

the auto-detector checks the first entry:

$ wget https://huggingface.co/sshleifer/tiny-gpt2/resolve/main/pytorch_model.bin
$ python -c "import torch; sd=torch.load('pytorch_model.bin'); print(next(iter(sd.values())).dtype)"
torch.float32

but we can look at all of them:

python -c "import torch; sd=torch.load('pytorch_model.bin'); print([v.dtype for v in sd.values()])"
[torch.float32, torch.float32, torch.float32, torch.float32, torch.uint8, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.uint8, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32]

Also, I think Sam made many of his test models half(), perhaps just not this one? Try some of his other tiny test models?

You can see the test that saves as fp16 and then auto-detects it to be fp16:

# test fp16 save_pretrained, loaded with auto-detection
model = model.half()
model.save_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
self.assertEqual(model.config.torch_dtype, "float16") # tests `config.torch_dtype` saving
self.assertEqual(model.dtype, torch.float16)

@hwijeen (Contributor, Author) commented Aug 12, 2021

I was not sure whether sshleifer/tiny-gpt2 uses float16 or not, which is why I tried with my own model (Megatron-LM), which (mostly) has float16.

python -c "import torch; sd=torch.load('pytorch_model.bin'); print([v.dtype for v in sd.values()])"
[torch.float16, torch.float16, torch.float16, torch.float16, torch.uint8, torch.float32, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.uint8, torch.float32, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.uint8, torch.float32, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.uint8, torch.float32, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.uint8, torch.float32, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.uint8, torch.float32, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.uint8, torch.float32, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.uint8, torch.float32, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.uint8, torch.float32, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.uint8, torch.float32, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.uint8, torch.float32, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.uint8, torch.float32, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16]

I tried loading this model in two ways, and only one yields the correct result:

# load correctly with specific model class
GPT2LMHeadModel.from_pretrained(".", torch_dtype="auto").dtype
torch.float16

# but fails with AutoModelForCausalLM
AutoModelForCausalLM.from_pretrained(".", torch_dtype="auto").dtype
torch.float32

The test cases you linked seem to be using specific model classes, so perhaps this is AutoModel's fault?

@stas00 (Contributor) commented Aug 12, 2021

Yes, clearly AutoModel goes through a different path and needs to be better tested and fixed.

I tried with my own model (megatronLM) which (mostly) has float16.

The question is what to do with models that have mixed dtypes - typically a model is either fp16 or fp32. I can see how a custom buffer may be of fp32 while the params are in fp16.

Could you explain your situation and how mixed your model is?

@hwijeen (Contributor, Author) commented Aug 17, 2021

I am using Megatron-LM by Nvidia. As you may know, this code trains billion-scale language models using various parallelism techniques. One thing to note is that this library does not rely on apex amp to achieve mixed-precision training; it has complicated, self-contained code to deal with fp16 -- so I would say that models with mixed data types are not the usual case and are not a high priority.

But the AutoModel problem shown above looks like an urgent issue to me. Are you planning to work on this in the near future?
(I would also be happy to look into the problem if you could share some hints.)

@stas00 (Contributor) commented Aug 17, 2021

But the AutoModel problem shown above looks like an urgent issue to me

Which of the AutoModel problems are you referring to?

If it's the pickle issue, then one needs some kind of to_json workaround for the torch.dtype class. It should be easy to just comment out that code as well, if it gets in the way and it's urgent as you say. Until it's resolved.
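A minimal sketch of what such a to_json workaround could look like, assuming we simply stringify the dtype right before serialization (illustrative only, not necessarily what the eventual fix in #13209 does):

import torch

def json_safe(value):
    """Turn a torch.dtype into a JSON-friendly string; leave other values untouched."""
    if isinstance(value, torch.dtype):
        return str(value).replace("torch.", "")  # torch.float16 -> "float16"
    return value

# e.g. applied to the config dict right before json.dumps in to_json_string():
config_dict = {"model_type": "gpt2", "torch_dtype": torch.float16}
config_dict = {k: json_safe(v) for k, v in config_dict.items()}
# json.dumps(config_dict, indent=2, sort_keys=True) now succeeds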

By all means if you can solve it, it'd be super helpful.

If it's the auto-detection failing because it checks the first key entry, then before solving it, as suggested we need to discuss what to do if the model has mixed dtypes. I suppose with just fp16/fp32 it obviously should be auto=fp32, but now we are going to have other types like bf16, so hardcoding is going to be an issue.
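To make the mixed-dtype question concrete, auto-detection could, for instance, tally every dtype in the checkpoint instead of trusting the first key and pick the most common floating-point one. A rough sketch of that idea (an illustration of the trade-off, not an adopted proposal):

from collections import Counter
import torch

def detect_checkpoint_dtype(checkpoint_path):
    sd = torch.load(checkpoint_path, map_location="cpu")
    # Count only floating-point entries; buffers like uint8 masks are ignored.
    counts = Counter(v.dtype for v in sd.values() if v.is_floating_point())
    # e.g. Counter({torch.float16: 150, torch.float32: 12}) -> torch.float16
    return counts.most_common(1)[0][0]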

I'm going to be offline for the next 3 days and can follow up next on Friday.

@stas00 (Contributor) commented Aug 17, 2021

I am using Megatron-LM by Nvidia. As you may know, this code trains billion-scale language models using various parallelism techniques. One thing to note is that this library does not rely on apex amp to achieve mixed-precision training; it has complicated, self-contained code to deal with fp16 -- so I would say that models with mixed data types are not the usual case and are not a high priority.

Running on the official checkpoint:

wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.1_uncased/zip -O checkpoint.zip
python3 /hf/transformers-master/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py checkpoint.zip

python -c "from transformers import MegatronBertForMaskedLM; m = MegatronBertForMaskedLM.from_pretrained('.'); d = {p.dtype():1 for p in m.parameters() }; print(d.keys())"

prints: dict_keys([torch.float32])

so there are only fp32 keys in that official checkpoint. But that's just that checkpoint.

Which keys do you get when you run the quick check above (the last line of code, with from_pretrained('.') adjusted to point to your model)?

@stas00 (Contributor) commented Aug 17, 2021

Ah, of course, the above test is wrong because it relies on transformers, which by default loads in fp32; it needs to be redone based on the checkpoint itself. Here you go:

python -c "import torch; sd=torch.load('pytorch_model.bin');  d = {p.dtype:1 for p in sd.values() }; print(d.keys())"
dict_keys([torch.float16])

So it's all fp16, not mixed. But again, this is just this checkpoint.

@hwijeen (Contributor, Author) commented Aug 17, 2021

so there are only fp32 keys in that official checkpoint. But that's just that checkpoint.

When I opened the official checkpoint with torch.load, it seems like it mostly has float16.

wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.1_uncased/zip -O checkpoint.zip
unzip checkpoint.zip
python -c "import torch; sd = torch.load('model_optim_rng.pt', map_location='cpu'); print([v.dtype for v in sd['model']['language_model']['transformer'].values()])`

[torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, 
torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float16]

In my case, I get a mixture of float32, float16, and uint8. Most of the params are float16, with masked_bias being float32 and bias being uint8. I am not 100% sure, but I guess this has to do with a Megatron version issue.

@hwijeen (Contributor, Author) commented Aug 17, 2021

As you pointed out, dealing with mixed data type is complicated and needs further discussion.

On the other hand, I think AutoModel's pickle issue is orthogonal to this, and I will look into it when I have time (perhaps this weekend) and get back to you if I find a solution :)

If it's the pickle issue, then one needs some kind of to_json workaround for the torch.dtype class. It should be easy to just comment out that code as well, if it gets in the way and it's urgent as you say. Until it's resolved.

Thanks for the quick workaround!

@stas00 (Contributor) commented Aug 17, 2021

Right, so my 2nd attempt was potentially wrong too, since the original checkpoint went through a conversion and I guess it could have ignored the original dtypes and made it all fp16.

However, hopefully doing it the right way this time by inspecting the original checkpoint, based on your code:

python -c "import torch; sd=torch.load('release/mp_rank_00/model_optim_rng.pt');  d = {p.dtype:1 for p in sd['model']['language_model']['transformer'].values() }; print(d.keys())"
dict_keys([torch.float16])

is still fp16 (for this checkpoint).

Perhaps when the model is mixed, from_pretrained() should assert and tell the user to choose one?

The problem is not transformers but torch, which loads the weights under a fixed dtype. Unless we change the dtype context for each key, perhaps?
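A rough sketch of that per-key idea, purely illustrative (restore_checkpoint_dtypes is a hypothetical helper, not an existing transformers API): after the model has been materialized in a single dtype, each parameter and buffer could be cast back to whatever dtype its checkpoint entry stored.

import torch

def restore_checkpoint_dtypes(model, state_dict):
    """Cast each parameter/buffer back to the dtype it has in the checkpoint."""
    for name, param in model.named_parameters():
        if name in state_dict and param.dtype != state_dict[name].dtype:
            param.data = param.data.to(state_dict[name].dtype)
    for name, buf in model.named_buffers():
        if name in state_dict and buf.dtype != state_dict[name].dtype:
            buf.data = buf.data.to(state_dict[name].dtype)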

@stas00 (Contributor) commented Aug 17, 2021

As you pointed out, dealing with mixed data type is complicated and needs further discussion.

Perhaps let's open a new issue that focuses just on this, and please tag me, sgugger, and LysandreJik on it. Thank you!

You can use the one-liner above to show us the mixed keys your model contains, and then it'd be easier to understand what's going on.

@hwijeen (Contributor, Author) commented Aug 20, 2021

Right, so my 2nd attempt was potentially wrong too, since the original checkpoint went through a conversion and I guess it could have ignored the original dtypes and made it all fp16.

Oh, I double-checked and confirmed that the weights in the Megatron-LM checkpoint are all fp16. It was the conversion script that gave the checkpoint mixed data types. Specifically, this line produces uint8 and this line float32. I'll open a new issue to address this.

So at least in my case, my model does not have mixed data types -- are there any cases where data types are mixed? If not, I think a new issue is not necessary?

@stas00 (Contributor) commented Aug 20, 2021

So at least in my case, my model does not have mixed data types -- are there any cases where data types are mixed? If not, I think a new issue is not necessary?

I asked the same question when working on the original feature, and those who followed up said they didn't think they had seen such cases.

I can only think of a registered buffer, which can be of any dtype and differ from the weights.

That said, perhaps down the road we should check that all the weights indeed have the same dtype, so we don't accidentally set a dtype that doesn't match the rest. But let's worry about it if it becomes a problem.
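Such a check could be as small as the following sketch, which looks only at floating-point entries (so registered buffers in uint8 and the like are ignored):

import torch

def assert_uniform_float_dtype(state_dict):
    """Raise if the floating-point weights in a checkpoint do not share one dtype."""
    dtypes = {v.dtype for v in state_dict.values() if v.is_floating_point()}
    if len(dtypes) > 1:
        raise ValueError(f"Checkpoint mixes floating-point dtypes: {sorted(map(str, dtypes))}")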
