fix AutoModel.from_pretrained(..., torch_dtype=...) #13209

Merged, 5 commits, Aug 24, 2021
28 changes: 27 additions & 1 deletion src/transformers/configuration_utils.py
@@ -30,6 +30,7 @@
hf_bucket_url,
is_offline_mode,
is_remote_url,
is_torch_available,
)
from .utils import logging

@@ -207,6 +208,9 @@ class PretrainedConfig(PushToHubMixin):
this attribute contains just the floating type string without the ``torch.`` prefix. For example, for
``torch.float16`` ``torch_dtype`` is the ``"float16"`` string.

This attribute is currently not being used during model loading time, but this may change in future
versions. We can already start preparing for the future by saving the dtype with save_pretrained.

TensorFlow specific parameters

- **use_bfloat16** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should use
@@ -270,6 +274,14 @@ def __init__(self, **kwargs):
else:
self.num_labels = kwargs.pop("num_labels", 2)

if self.torch_dtype is not None and isinstance(self.torch_dtype, str):
# we will start using self.torch_dtype in v5, but to be consistent with
# from_pretrained's torch_dtype arg convert it to an actual torch.dtype object
if is_torch_available():
import torch

self.torch_dtype = getattr(torch, self.torch_dtype)
Comment on lines +280 to +283
@stas00 (Contributor Author), Aug 21, 2021:
OK, so torch is not always available when config.torch_dtype is not None - so now config.torch_dtype isn't always a torch.dtype.

I'm thinking perhaps this whole approach needs to be redone and only use the "float32", "float16", etc. strings everywhere, including the torch_dtype arg in from_pretrained and from_config args. And only convert to torch.dtype at the point it's used when the model is loaded.

That way torch_dtype doesn't need to have a special handling at config level.

I hope this is recent/experimental enough that it's ok that we break the API.

Actually, if we do that, why even bother with torch_ in torch_dtype and not just rename it to dtype - perhaps non-pt frameworks could tap into it as well? After all, fp16 data saved by torch isn't any different from Flax or TF, no?
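
The late-conversion idea sketched in this comment might look roughly like the following (illustrative only; `resolve_dtype` is a hypothetical helper, not part of the PR):

```python
import torch

def resolve_dtype(dtype_str):
    # Keep only plain strings such as "float16" in the config, and convert to
    # an actual torch.dtype only at the point where the model is loaded.
    if dtype_str is None or dtype_str == "auto":
        return dtype_str  # "auto" is resolved later, e.g. from the checkpoint weights
    return getattr(torch, dtype_str)  # "float16" -> torch.float16
```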

Member:

I can't think of a scenario where one would want one dtype for one framework and another dtype for another - so changing it to dtype sounds good to me.

@stas00 (Contributor Author), Aug 23, 2021:

So the main concern is the back-compat of the API arg torch_dtype - if it's OK to break it, then I propose that both the config and the arg in from_pretrained and from_config be just dtype as a string: "auto", "float32", "float16", etc.

And then in the case of torch we convert it to the right torch.dtype on the fly. Perhaps Flax/TF could use this too down the road.

Sylvain is not here for another week. Do you both support this breaking API change, @LysandreJik and @patrickvonplaten?

So instead of:

            torch_dtype (:obj:`str` or :obj:`torch.dtype`, `optional`):
                Override the default ``torch.dtype`` and load the model under this dtype. If ``"auto"`` is passed the
                dtype will be automatically derived from the model's weights.

It will be:

            dtype (:obj:`str`, `optional`):
                Override the default ``dtype`` and load the model under this dtype (e.g., ``"float16"``). If
                ``"auto"`` is passed the dtype will be automatically derived from the model's weights.
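
The framework-neutral appeal of plain strings is that each backend can resolve them on its own; a small illustrative sketch (not code from the PR):

```python
import numpy as np
import torch

dtype_str = "float16"  # what the config would store under this proposal

torch_dtype = getattr(torch, dtype_str)  # torch.float16
numpy_dtype = np.dtype(dtype_str)        # dtype('float16'), also consumable by TF/Flax

print(torch_dtype, numpy_dtype)
```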

@stas00 (Contributor Author), Aug 24, 2021:

On the other hand we already have the dtype attribute in modeling_utils, which returns torch.dtype

https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_utils.py#L205

So my proposal might be confusing.

Should I call it dtype_str perhaps?
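
For context, the existing modeling_utils attribute behaves roughly like the following (a loose sketch of the idea, not the actual implementation):

```python
import torch
from torch import nn

def parameter_dtype(model: nn.Module) -> torch.dtype:
    # Roughly what `model.dtype` reports: the dtype of the model's
    # floating-point parameters (falling back to any parameter's dtype).
    last_dtype = None
    for param in model.parameters():
        last_dtype = param.dtype
        if param.is_floating_point():
            return param.dtype
    return last_dtype
```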

@patrickvonplaten (Contributor), Aug 24, 2021:

If no version with torch_dtype has been released yet, I'm fine with changing it to dtype. However, note that in Flax we already have a dtype variable that defines the dtype the matmul operations are run in, rather than the dtype of the actual weights. In Flax we would like to adopt the design in #13098, as outlined by @patil-suraj.

Contributor:

@stas00 regarding the PR #13098: the idea of that PR is exactly to disentangle parameter dtype from matmul/computation dtype. In Flax, it's common practice that the dtype parameter defines the matmul/computation dtype (see https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.Dense.html#flax.linen.Dense.dtype), not the parameter dtype.

So for Flax, I don't really think it would make sense to use a config.dtype to define weights dtype as it would be quite confusing with Flax's computation dtype parameter.
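
For reference, a minimal Flax sketch of that distinction (assuming current flax.linen defaults, where parameters are initialized in float32):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# `dtype` controls the computation (matmul) dtype, not the parameter dtype.
layer = nn.Dense(features=4, dtype=jnp.bfloat16)
params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
print(jax.tree_util.tree_map(lambda p: p.dtype, params))  # float32 parameters
out = layer.apply(params, jnp.ones((1, 8)))
print(out.dtype)  # bfloat16 activations
```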

@stas00 (Contributor Author), Aug 24, 2021:

> I would feel more comfortable to merge this PR without breaking changes and keeping the parameter called torch_dtype.

Works for me. Although this API is somewhat awkward at the moment due to the inconsistent type of values in the config file and the function: the former is a string, the latter a torch.dtype. Perhaps I can change these to support both the string "float32" and torch.dtype in the same function parameter.

> Think it's worth to have a separate discussion here regarding a framework-agnostic dtype parameter for all PyTorch, Tensorflow, and Flax once @sgugger is back.

Agreed!
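
Supporting both forms in the same parameter could look something like this (a sketch of the idea; `normalize_torch_dtype` is a hypothetical helper, not part of the PR):

```python
from typing import Union

import torch

def normalize_torch_dtype(value: Union[str, torch.dtype, None]) -> Union[str, torch.dtype, None]:
    # Accept an actual torch.dtype, a string like "float16", or the special
    # values None / "auto", which are passed through untouched.
    if value is None or value == "auto" or isinstance(value, torch.dtype):
        return value
    if isinstance(value, str) and isinstance(getattr(torch, value, None), torch.dtype):
        return getattr(torch, value)  # "float16" -> torch.float16
    raise ValueError(f"Unexpected torch_dtype value: {value!r}")
```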

@stas00 (Contributor Author), Aug 24, 2021:

> In Flax, it's common practice that the dtype parameter defines the matmul/computation dtype (see https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.Dense.html#flax.linen.Dense.dtype), not the parameter dtype.

So it's somewhat similar to the dtype arg in the new PyTorch autocast feature then, correct? (Before, it was hardcoded to fp16, but now it has a dtype arg to support bf16 too.)

P.S. It's currently called fast_dtype but will be renamed shortly to dtype in pt-nightly.
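
For comparison, the autocast API that eventually shipped takes a dtype argument; this sketch assumes a CUDA device and PyTorch >= 1.10, which postdates this discussion:

```python
import torch

model = torch.nn.Linear(8, 8).cuda()
x = torch.randn(2, 8, device="cuda")

# The computation dtype is controlled per region; the weights stay float32.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = model(x)

print(model.weight.dtype)  # torch.float32
print(y.dtype)             # torch.bfloat16
```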

Member:

Sounds good to me to keep torch_dtype and have a separate discussion for the framework-agnostic dtype parameter, to which torch_dtype could be an alias to prevent breaking changes to the existing API.

Contributor Author:

Sounds like a plan, @LysandreJik

Issue created: #13246


# Tokenizer arguments TODO: eventually tokenizer and models should share the same config
self.tokenizer_class = kwargs.pop("tokenizer_class", None)
self.prefix = kwargs.pop("prefix", None)
@@ -574,7 +586,8 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
if key != "torch_dtype":
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)

@@ -640,6 +653,8 @@ def to_diff_dict(self) -> Dict[str, Any]:
):
serializable_config_dict[key] = value

self.dict_torch_dtype_to_str(serializable_config_dict)

return serializable_config_dict

def to_dict(self) -> Dict[str, Any]:
@@ -656,6 +671,8 @@ def to_dict(self) -> Dict[str, Any]:
# Transformers version when serializing the model
output["transformers_version"] = __version__

self.dict_torch_dtype_to_str(output)

return output

def to_json_string(self, use_diff: bool = True) -> str:
@@ -738,6 +755,15 @@ def update_from_string(self, update_str: str):

setattr(self, k, v)

def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
"""
Checks whether the passed dictionary has a `torch_dtype` key and if it's not None, converts torch.dtype to a
string of just the type. For example, :obj:`torch.float32` gets converted into the `"float32"` string, which
can then be stored in the JSON format.
"""
if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]


PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
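Taken together, the config changes above make the dtype round-trip through save/load roughly as follows (a sketch, not code from the PR; it assumes a local save path and access to the t5-small checkpoint):

```python
import torch
from transformers import T5ForConditionalGeneration

# Save an fp16 model; the config's torch_dtype is serialized as the string "float16".
model = T5ForConditionalGeneration.from_pretrained("t5-small").half()
model.save_pretrained("t5-small-fp16")

# Reload and let the dtype be picked up automatically from the weights.
reloaded = T5ForConditionalGeneration.from_pretrained("t5-small-fp16", torch_dtype="auto")
print(reloaded.dtype)               # torch.float16
print(reloaded.config.torch_dtype)  # torch.float16 in memory, "float16" in config.json
```
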
22 changes: 19 additions & 3 deletions tests/test_modeling_common.py
@@ -16,6 +16,7 @@
import copy
import gc
import inspect
import json
import os.path
import random
import tempfile
@@ -1663,9 +1664,11 @@ def test_model_from_config_torch_dtype(self):
@require_torch
def test_model_from_pretrained_torch_dtype(self):
# test that the model can be instantiated with dtype of either
# 1. config.torch_dtype setting in the saved model (priority)
# 2. via autodiscovery by looking at model weights
# 1. explicit from_pretrained's torch_dtype argument
# 2. via autodiscovery by looking at model weights (torch_dtype="auto")
# so if a model.half() was saved, we want it to be instantiated as such.
#
# test an explicit model class, but also AutoModel separately as the latter goes through a different code path
model_path = self.get_auto_remove_tmp_dir()

# baseline - we know TINY_T5 is fp32 model
@@ -1688,13 +1691,26 @@ def test_model_from_pretrained_torch_dtype(self):
model = model.half()
model.save_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
self.assertEqual(model.config.torch_dtype, "float16") # tests `config.torch_dtype` saving
self.assertEqual(model.config.torch_dtype, torch.float16)
self.assertEqual(model.dtype, torch.float16)

# tests `config.torch_dtype` saving
with open(f"{model_path}/config.json") as f:
config_dict = json.load(f)
self.assertEqual(config_dict["torch_dtype"], "float16")

# test fp16 save_pretrained, loaded with the explicit fp16
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
self.assertEqual(model.dtype, torch.float16)

# test AutoModel separately as it goes through a different path
# test auto-detection
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="auto")
self.assertEqual(model.dtype, torch.float32)
# test forcing an explicit dtype
model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
self.assertEqual(model.dtype, torch.float16)


@require_torch
@is_staging_test
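
In short, the behavior the new tests lock in can be summarized as follows (a sketch based on the test code above; the checkpoint name is an assumption standing in for the tests' TINY_T5, a tiny fp32 T5 model):

```python
import torch
from transformers import AutoModel

tiny_t5 = "patrickvonplaten/t5-tiny-random"  # assumption: tiny fp32 T5 checkpoint like TINY_T5

# Explicit dtype override through AutoModel, the path this PR fixes.
model = AutoModel.from_pretrained(tiny_t5, torch_dtype=torch.float16)
assert model.dtype == torch.float16

# Auto-detection of the dtype from the saved weights.
model = AutoModel.from_pretrained(tiny_t5, torch_dtype="auto")
assert model.dtype == torch.float32
```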