[models] respect dtype of the model when instantiating it #12316

Merged
merged 20 commits on Jun 29, 2021
Changes from 12 commits
7 changes: 7 additions & 0 deletions src/transformers/configuration_utils.py
@@ -192,6 +192,12 @@ class PretrainedConfig(PushToHubMixin):
- **tie_word_embeddings** (:obj:`bool`, `optional`, defaults to :obj:`True`) -- Whether the model's input and
output word embeddings should be tied. Note that this is only relevant if the model has an output word
embedding layer.
- **torch_dtype** (:obj:`str`, `optional`) -- The :obj:`dtype` of the weights. This attribute is used to
initialize the model to a non-default ``dtype`` (which is normally ``float32``) and thus allow for optimal
storage allocation. For example, if the saved model is ``float16``, we want to load it back using the minimal
amount of memory needed to load ``float16`` weights. Since the config object is stored in plain text, this
attribute contains just the floating type string without the ``torch.`` prefix. For example, for
``torch.float16``, ``torch_dtype`` is the ``"float16"`` string.

TensorFlow specific parameters

@@ -207,6 +213,7 @@ def __init__(self, **kwargs):
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
self.output_attentions = kwargs.pop("output_attentions", False)
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
self.torch_dtype = kwargs.pop("torch_dtype", None) # Only used by PyTorch models
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
self.pruned_heads = kwargs.pop("pruned_heads", {})
self.tie_word_embeddings = kwargs.pop(
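Not part of the diff, but a minimal sketch of the string round-trip this new config attribute relies on (the config is serialized as plain text, so only the dtype's name is stored):

import torch
from transformers import PretrainedConfig

config = PretrainedConfig()
# save_pretrained stores only the suffix of the dtype's repr, e.g. torch.float16 -> "float16"
config.torch_dtype = str(torch.float16).split(".")[1]
assert config.torch_dtype == "float16"
# from_config / from_pretrained can reverse the mapping with getattr
assert getattr(torch, config.torch_dtype) is torch.float16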
7 changes: 7 additions & 0 deletions src/transformers/modeling_flax_utils.py
@@ -112,6 +112,13 @@ def __init__(
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
raise NotImplementedError(f"init method has to be implemented for {self}")

@classmethod
def _from_config(cls, config, **kwargs):
"""
All context managers that the model should be initialized under go here.
"""
return cls(config, **kwargs)

@property
def config(self) -> PretrainedConfig:
return self._config
7 changes: 7 additions & 0 deletions src/transformers/modeling_tf_utils.py
@@ -644,6 +644,13 @@ def __init__(self, config, *inputs, **kwargs):
self.config = config
self.name_or_path = config.name_or_path

@classmethod
def _from_config(cls, config, **kwargs):
"""
All context managers that the model should be initialized under go here.
"""
return cls(config, **kwargs)

@tf.function(
input_signature=[
{
118 changes: 105 additions & 13 deletions src/transformers/modeling_utils.py
@@ -23,7 +23,7 @@
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
from torch import Tensor, device, dtype, nn
from torch import Tensor, device, nn
from torch.nn import CrossEntropyLoss

from .activations import get_activation
@@ -202,7 +202,7 @@ def device(self) -> device:
return get_parameter_device(self)

@property
def dtype(self) -> dtype:
def dtype(self) -> torch.dtype:
"""
:obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
"""
@@ -465,6 +465,66 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
self.config = config
self.name_or_path = config.name_or_path

@classmethod
def _from_config(cls, config, **kwargs):
"""
All context managers that the model should be initialized under go here.
"""

torch_dtype = kwargs.pop("torch_dtype", None)

# override default dtype if needed
dtype_orig = None
if torch_dtype is not None:
dtype_orig = cls._set_default_dtype(torch_dtype)

if is_deepspeed_zero3_enabled():
import deepspeed

logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
# this immediately partitions the model across all gpus, to avoid the overhead in time
# and memory copying it on CPU or each GPU first
with deepspeed.zero.Init(config=deepspeed_config()):
model = cls(config, **kwargs)
else:
model = cls(config, **kwargs)

# restore default dtype if it was modified
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

return model
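For context, a hedged usage sketch (not part of the diff; the checkpoint name is purely illustrative) showing how the torch_dtype kwarg flows from AutoModel.from_config into this _from_config path:

import torch
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("t5-small")  # illustrative checkpoint

# without torch_dtype, the model is created under torch's default dtype (normally float32)
model = AutoModel.from_config(config)
assert model.dtype == torch.float32

# with torch_dtype, _from_config temporarily switches the global default dtype so all
# weights are created directly in the requested precision
model = AutoModel.from_config(config, torch_dtype=torch.float16)
assert model.dtype == torch.float16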

@classmethod
def _set_default_dtype(cls, dtype: torch.dtype) -> torch.dtype:
"""
If a floating ``dtype`` is passed (via the model config or detected from one of the pretrained weights),
``torch.set_default_dtype`` is used to change the global default dtype, which is needed to instantiate the
model under that specific dtype.

Args:
dtype (:obj:`torch.dtype`):
the floating-point ``dtype`` to set as the default

Returns:
:obj:`torch.dtype`: the original default ``dtype``, which can be passed back to ``torch.set_default_dtype``
to restore it once the model has been instantiated.

Note ``set_default_dtype`` currently only works with floating-point types and asserts if, for example,
``torch.int64`` is passed. Therefore a :obj:`ValueError` is raised if a non-float ``dtype`` is passed.
"""
if not dtype.is_floating_point:
raise ValueError(
f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
)

logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
dtype_orig = torch.get_default_dtype()
torch.set_default_dtype(dtype)
return dtype_orig
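As a side note (plain PyTorch behaviour, not part of the diff), the save-and-restore pattern above boils down to the following; the PR applies the same calls with e.g. torch.float16:

import torch

dtype_orig = torch.get_default_dtype()           # normally torch.float32
torch.set_default_dtype(torch.float64)           # float64 used here as an unambiguously supported example
assert torch.empty(2, 2).dtype == torch.float64  # new floating-point tensors follow the default
torch.set_default_dtype(dtype_orig)              # restore, as _from_config / from_pretrained do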

@property
def base_model(self) -> nn.Module:
"""
@@ -864,6 +924,11 @@ def save_pretrained(
# Only save the model itself if we are using distributed training
model_to_save = unwrap_model(self)

# save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
# so in from_config and from_pretrained we reverse this with getattr(torch, "float32")
dtype = get_parameter_dtype(model_to_save)
model_to_save.config.torch_dtype = str(dtype).split(".")[1]

# Attach architecture to the config
model_to_save.config.architectures = [model_to_save.__class__.__name__]

@@ -1049,6 +1114,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
_fast_init = kwargs.pop("_fast_init", True)
torch_dtype = kwargs.pop("torch_dtype", None)
torch_dtype_auto_detect = kwargs.pop("torch_dtype_auto_detect", False)

from_pt = not (from_tf | from_flax)

user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
if from_pipeline is not None:
@@ -1153,6 +1222,34 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else:
resolved_archive_file = None

# load pt weights early so that we know which dtype to init the model under
if from_pt:
if state_dict is None:
try:
state_dict = torch.load(resolved_archive_file, map_location="cpu")
except Exception:
raise OSError(
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
f"at '{resolved_archive_file}'"
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
)

# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
# 2. If torch_dtype_auto_detect is True, we auto-detect it from the loaded state_dict, by checking the first
# entry - we assume all weights are of the same dtype
# we also may have config.torch_dtype but we won't rely on it till v5
dtype_orig = None
if torch_dtype_auto_detect:
if torch_dtype is None:
torch_dtype = next(iter(state_dict.values())).dtype
else:
raise ValueError(
"ambiguous arguments passed non-None ``torch_dtype`` and ``torch_dtype_auto_detect=True`` at the same time"
)
if torch_dtype is not None:
dtype_orig = cls._set_default_dtype(torch_dtype)
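A short, hedged usage sketch of the two new from_pretrained arguments (not part of the diff; it mirrors the tests added below, and the model name is purely illustrative):

import torch
from transformers import T5ForConditionalGeneration

# explicit dtype: the model is instantiated and loaded in fp16, even from an fp32 checkpoint
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
assert model.dtype == torch.float16

# auto-detection: the dtype of the first entry in the checkpoint's state_dict is used
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype_auto_detect=True)

# passing both arguments at the same time raises a ValueError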

config.name_or_path = pretrained_model_name_or_path

# Instantiate model.
@@ -1169,6 +1266,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
with no_init_weights(_enable=_fast_init):
model = cls(config, *model_args, **model_kwargs)

if from_pt:
# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

if from_tf:
if resolved_archive_file.endswith(".index"):
# Load from a TensorFlow 1.X checkpoint - provided by original authors
@@ -1196,17 +1298,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
)
raise
else:
if state_dict is None:
try:
state_dict = torch.load(resolved_archive_file, map_location="cpu")
except Exception:
raise OSError(
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
f"at '{resolved_archive_file}'"
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
)

elif from_pt:
model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model(
model, state_dict, pretrained_model_name_or_path, _fast_init=_fast_init
)
13 changes: 2 additions & 11 deletions src/transformers/models/auto/auto_factory.py
@@ -17,7 +17,6 @@
import types

from ...configuration_utils import PretrainedConfig
from ...deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from ...file_utils import copy_func
from ...utils import logging
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
@@ -367,16 +366,8 @@ def __init__(self, *args, **kwargs):
def from_config(cls, config, **kwargs):
if type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
if is_deepspeed_zero3_enabled():
import deepspeed

logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
# this immediately partitions the model across all gpus, to avoid the overhead in time
# and memory copying it on CPU or each GPU first
with deepspeed.zero.Init(config=deepspeed_config()):
return model_class(config, **kwargs)
else:
return model_class(config, **kwargs)
return model_class._from_config(config, **kwargs)

raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
60 changes: 58 additions & 2 deletions tests/test_modeling_common.py
@@ -25,14 +25,15 @@

from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import is_torch_available, logging
from transformers import AutoModel, is_torch_available, logging
from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
ENDPOINT_STAGING,
PASS,
USER,
CaptureLogger,
TestCasePlus,
is_staging_test,
require_torch,
require_torch_multi_gpu,
@@ -63,6 +64,7 @@
BertModel,
PretrainedConfig,
PreTrainedModel,
T5Config,
T5ForConditionalGeneration,
)

@@ -1553,7 +1555,7 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):


@require_torch
class ModelUtilsTest(unittest.TestCase):
class ModelUtilsTest(TestCasePlus):
@slow
def test_model_from_pretrained(self):
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
Expand Down Expand Up @@ -1586,6 +1588,60 @@ def test_model_from_pretrained_with_different_pretrained_model_name(self):
BertModel.from_pretrained(TINY_T5)
self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)

@require_torch
def test_model_from_config_torch_dtype(self):
# test that the model can be instantiated with dtype of user's choice - as long as it's a
# float dtype. To make it happen config.torch_dtype needs to be set before instantiating the
# model from the config object.

config = T5Config.from_pretrained(TINY_T5)
model = AutoModel.from_config(config)
# XXX: isn't supported
# model = T5ForConditionalGeneration.from_config(config)
self.assertEqual(model.dtype, torch.float32)

model = AutoModel.from_config(config, torch_dtype=torch.float16)
self.assertEqual(model.dtype, torch.float16)

# torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
with self.assertRaises(ValueError):
model = AutoModel.from_config(config, torch_dtype=torch.int64)

@require_torch
def test_model_from_pretrained_torch_dtype(self):
# test that the model can be instantiated with dtype of either
# 1. config.torch_dtype setting in the saved model (priority)
# 2. via autodiscovery by looking at model weights
# so if a model.half() was saved, we want it to be instantiated as such.
model_path = self.get_auto_remove_tmp_dir()

# baseline - we know TINY_T5 is fp32 model
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
self.assertEqual(model.dtype, torch.float32)

# test the default fp32 save_pretrained => from_pretrained cycle
model.save_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
self.assertEqual(model.dtype, torch.float32)
# test with auto-detection
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype_auto_detect=True)
self.assertEqual(model.dtype, torch.float32)

# test forced loading in fp16 (even though the weights are in fp32)
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
self.assertEqual(model.dtype, torch.float16)

# test fp16 save_pretrained, loaded with auto-detection
model = model.half()
model.save_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype_auto_detect=True)
self.assertEqual(model.config.torch_dtype, "float16") # also test `config.torch_dtype` saving
self.assertEqual(model.dtype, torch.float16)

# test fp16 save_pretrained, loaded with the explicit fp16
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
self.assertEqual(model.dtype, torch.float16)


@require_torch
@is_staging_test