
fix AutoModel.from_pretrained(..., torch_dtype=...) #13209

Merged: 5 commits into huggingface:master from stas00:pretrained_torch_dtype on Aug 24, 2021

Conversation

stas00 (Contributor) commented on Aug 21, 2021

This PR fixes one of the two issues reported in #13076:

python -c "import torch; from transformers import AutoModel; AutoModel.from_pretrained('sshleifer/tiny-gpt2', torch_dtype=torch.float16)"
2021-08-20 18:45:07.802651: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/mnt/nvme1/code/huggingface/transformers-master/src/transformers/models/auto/auto_factory.py", line 382, in from_pretrained
    config, kwargs = AutoConfig.from_pretrained(
  File "/mnt/nvme1/code/huggingface/transformers-master/src/transformers/models/auto/configuration_auto.py", line 511, in from_pretrained
    return config_class.from_dict(config_dict, **kwargs)
  File "/mnt/nvme1/code/huggingface/transformers-master/src/transformers/configuration_utils.py", line 581, in from_dict
    logger.info(f"Model config {config}")
  File "/mnt/nvme1/code/huggingface/transformers-master/src/transformers/configuration_utils.py", line 613, in __repr__
    return f"{self.__class__.__name__} {self.to_json_string()}"
  File "/mnt/nvme1/code/huggingface/transformers-master/src/transformers/configuration_utils.py", line 677, in to_json_string
    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
  File "/home/stas/anaconda3/envs/py38-pt19/lib/python3.8/json/__init__.py", line 234, in dumps
    return cls(
  File "/home/stas/anaconda3/envs/py38-pt19/lib/python3.8/json/encoder.py", line 201, in encode
    chunks = list(chunks)
  File "/home/stas/anaconda3/envs/py38-pt19/lib/python3.8/json/encoder.py", line 431, in _iterencode
    yield from _iterencode_dict(o, _current_indent_level)
  File "/home/stas/anaconda3/envs/py38-pt19/lib/python3.8/json/encoder.py", line 405, in _iterencode_dict
    yield from chunks
  File "/home/stas/anaconda3/envs/py38-pt19/lib/python3.8/json/encoder.py", line 438, in _iterencode
    o = _default(o)
  File "/home/stas/anaconda3/envs/py38-pt19/lib/python3.8/json/encoder.py", line 179, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type dtype is not JSON serializable

Additionally, it makes the config object convert the short "float32" string into torch.float32 at object-creation time.

Note: I had to change from_dict a bit to preserve the torch_dtype arg in AutoModel.from_pretrained(..., torch_dtype=...); without this change, from_pretrained was ignoring the argument.

As a reminder, the issue is that we decided to store torch_dtype in the config object but to ignore it at load time for now, which this PR also documents.
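
To illustrate the serialization side, here is a rough standalone sketch (not the actual diff) of the idea: torch.dtype values get written to the config JSON as plain strings, so json.dumps no longer chokes on them.

    import json
    import torch

    def dtype_to_str(torch_dtype):
        # illustrative helper: torch.float16 -> "float16"
        return str(torch_dtype).split(".")[1]

    config_dict = {"model_type": "gpt2", "torch_dtype": torch.float16}
    config_dict["torch_dtype"] = dtype_to_str(config_dict["torch_dtype"])
    print(json.dumps(config_dict, indent=2, sort_keys=True))  # now serializes fine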

Of course, tests added.

Thank you.

Fixes: #13076

(Note: 2 separate issues were reported there, but it looks like only this one is a real issue, so linking to close it with this PR.)

@sgugger, @LysandreJik

stas00 mentioned this pull request on Aug 21, 2021
stas00 (Contributor, Author) commented on Aug 21, 2021

Note: I first tried a simple monkey-patching approach, but it doesn't work with C-extension types, which torch.dtype is:

    if config.torch_dtype is not None:
        # in v5 convert the stored str to torch.dtype
        import torch

        if not hasattr(torch.dtype, "to_json_string"):
            import builtins

            # try to give torch.dtype a to_json_string method so the config
            # can be JSON-serialized without special-casing
            setattr(torch.dtype, "to_json_string", builtins.str)

got:

setattr(torch.dtype, "to_json_string", builtins.str)
TypeError: can't set attributes of built-in/extension type 'torch.dtype'
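
For completeness, the failure is easy to reproduce standalone; attributes simply cannot be added to built-in/extension types such as torch.dtype:

    import torch

    try:
        torch.dtype.to_json_string = str
    except TypeError as e:
        # can't set attributes of built-in/extension type 'torch.dtype'
        print(e)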

stas00 changed the title from "fix AutoModel.from_pretrained(..., torch_dtype=...)" to "[WIP] fix AutoModel.from_pretrained(..., torch_dtype=...)" on Aug 21, 2021
Comment on lines +280 to +283
    if is_torch_available():
        import torch

        self.torch_dtype = getattr(torch, self.torch_dtype)
stas00 (Contributor, Author) commented on Aug 21, 2021

OK, so torch is not always available when config.torch_dtype is not None - so config.torch_dtype isn't always a torch.dtype.

I'm thinking perhaps this whole approach needs to be redone to only use the "float32", "float16", etc. strings everywhere, including the torch_dtype arg of from_pretrained and from_config, and to convert to torch.dtype only at the point where it's used, when the model is loaded (see the sketch after this comment).

That way torch_dtype doesn't need any special handling at the config level.

I hope this is recent/experimental enough that it's OK to break the API.

Actually, if we do that, why even bother with the torch_ in torch_dtype and not just rename it to dtype - perhaps non-pt frameworks could tap into it as well? After all, fp16 data saved by torch isn't any different from flax or tf, no?
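
To make the proposal concrete, a minimal sketch (hypothetical, not what this PR implements) of the string-only approach: the config always stores "auto"/"float16"/"float32" strings, and the conversion to torch.dtype happens only at model-load time.

    import torch

    config_torch_dtype = "float16"  # what the config would store: always a plain string (or "auto"/None)
    if config_torch_dtype not in (None, "auto"):
        # conversion happens only here, at the point the model is actually loaded
        torch_dtype = getattr(torch, config_torch_dtype)  # -> torch.float16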

A Member replied:

I can't think of a scenario where one would want one dtype for one framework and another dtype for another - so changing it to dtype sounds good to me.

stas00 (Contributor, Author) replied on Aug 23, 2021

So the main concern is back-compat of the API arg torch_dtype - if it's OK to break it, then I propose that both the config attribute and the arg in from_pretrained and from_config become just dtype, as a string: "auto", "float32", "float16", etc.

And then, in the case of torch, we convert it to the right torch.dtype on the fly. Perhaps flax/tf could use this too down the road.

Sylvain is not here for another week. Do you both support this breaking API change, @LysandreJik and @patrickvonplaten?

So instead of:

            torch_dtype (:obj:`str` or :obj:`torch.dtype`, `optional`):
                Override the default ``torch.dtype`` and load the model under this dtype. If ``"auto"`` is passed the
                dtype will be automatically derived from the model's weights.

It will be:

            dtype (:obj:`str`, `optional`):
                Override the default ``dtype`` and load the model under this dtype (e.g., ``"float16"``).
                If ``"auto"`` is passed the dtype will be automatically derived from the model's weights.

stas00 (Contributor, Author) replied on Aug 24, 2021

On the other hand, we already have the dtype attribute in modeling_utils, which returns a torch.dtype:

https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_utils.py#L205

So my proposal might be confusing.

Should I call it dtype_str perhaps?

patrickvonplaten (Contributor) replied on Aug 24, 2021

If no version with torch_dtype has been released yet, I'm fine with changing it to dtype. However, note that in Flax we already have a dtype variable that is used to define the dtype the matmul operations are run in, rather than the dtype of the actual weights. In Flax we would like to take this design: #13098, as outlined by @patil-suraj.

A Contributor replied:

@stas00 regarding PR #13098 - the idea of that PR is exactly to disentangle parameter dtype from matmul/computation dtype. In Flax it's common practice that the dtype parameter defines the matmul/computation dtype, not the parameter dtype; see: https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.Dense.html#flax.linen.Dense.dtype

So for Flax I don't really think it would make sense to use a config.dtype to define the weights' dtype, as it would be quite confusing with Flax's computation dtype parameter. An illustration of the Flax convention follows below.
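
For readers unfamiliar with Flax, a small hedged sketch of the convention being referenced (not code from this PR): the dtype argument of flax.linen.Dense controls the dtype the computation runs in, not the dtype the parameters are stored in.

    import jax.numpy as jnp
    import flax.linen as nn

    # matmuls inside this layer run in bf16; this says nothing about how
    # the kernel/bias parameters themselves are stored
    layer = nn.Dense(features=16, dtype=jnp.bfloat16)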

stas00 (Contributor, Author) replied on Aug 24, 2021

> I would feel more comfortable to merge this PR without breaking changes and keeping the parameter called torch_dtype.

Works for me. Although this API is somewhat awkward at the moment due to the inconsistent types of values in the config file and in the function - the former is a string, the latter a torch.dtype. Perhaps I can change these to support both the string "float32" and torch.dtype in the same function param (see the sketch below).

> Think it's worth to have a separate discussion here regarding a framework-agnostic dtype parameter for all PyTorch, TensorFlow, and Flax once @sgugger is back.

Agreed!
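
A hedged sketch of what accepting both forms in the same parameter could look like (the helper name is illustrative, not the merged implementation):

    import torch

    def normalize_torch_dtype(torch_dtype):
        # accept either the string form stored in the config ("float16")
        # or a torch.dtype passed programmatically; leave "auto"/None alone
        if isinstance(torch_dtype, str) and torch_dtype != "auto":
            return getattr(torch, torch_dtype)
        return torch_dtype

    print(normalize_torch_dtype("float16"))      # torch.float16
    print(normalize_torch_dtype(torch.float16))  # torch.float16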

stas00 (Contributor, Author) replied on Aug 24, 2021

> In Flax, it's common practice that the dtype parameter defines the matmul/computation dtype, not the parameter dtype.

So it's somewhat similar to the dtype arg of the new pytorch autocast feature then, correct? (Before it was hardcoded to fp16, but now it has a dtype arg to support bf16 too; see the example below.)

p.s. it's currently called fast_dtype but will be renamed shortly to dtype in pt-nightly.
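
For reference, in later PyTorch releases this landed as the dtype argument of torch.autocast (not part of this PR); a small illustration, assuming PyTorch >= 1.10:

    import torch

    a = torch.randn(4, 4)
    b = torch.randn(4, 4)
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        c = a @ b  # the matmul runs in bf16 inside this block
    print(c.dtype)  # torch.bfloat16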

LysandreJik (Member) replied:

Sounds good to me to keep torch_dtype and have a separate discussion for the framework-agnostic dtype parameter, to which torch_dtype could be an alias to prevent breaking changes to the existing API.

stas00 (Contributor, Author) replied:

Sounds like a plan, @LysandreJik

Issue created: #13246

stas00 changed the title from "[WIP] fix AutoModel.from_pretrained(..., torch_dtype=...)" to "fix AutoModel.from_pretrained(..., torch_dtype=...)" on Aug 21, 2021
patrickvonplaten (Contributor) left a comment

This looks good to me!

Would actually be nice to have this merged soon to allow GPT-J to be loaded with AutoModelForCausalLM: #13022

LysandreJik (Member) left a comment

LGTM, thanks for working on it @stas00


patrickvonplaten merged commit 5c6eca7 into huggingface:master on Aug 24, 2021
stas00 deleted the pretrained_torch_dtype branch on August 24, 2021 at 14:27
Development

Successfully merging this pull request may close these issues.

respect dtype of the the model when instiating not working
3 participants