[models] respect dtype of the model when instantiating it (#12316)
* [models] respect dtype of the model when instantiating it

* cleanup

* cleanup

* rework to handle non-float dtype

* fix

* switch to fp32 tiny model

* improve

* use dtype.is_floating_point

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix the doc

* recode to use explicit torch_dtype_auto_detect, torch_dtype args

* docs and tweaks

* docs and tweaks

* docs and tweaks

* merge 2 args, add docs

* fix

* fix

* better doc

* better doc

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
stas00 and sgugger authored Jun 29, 2021
1 parent 31c3e7e commit 7682e97
Showing 8 changed files with 222 additions and 27 deletions.
2 changes: 2 additions & 0 deletions docs/source/main_classes/deepspeed.rst
@@ -1549,6 +1549,8 @@ Note: If the fp16 weights of the model can't fit onto the memory of a single GPU
For full details on this method and other related features please refer to `Constructing Massive Models
<https://deepspeed.readthedocs.io/en/latest/zero3.html#constructing-massive-models>`__.

Also, when loading fp16-pretrained models, you will want to tell ``from_pretrained`` to use
``torch_dtype=torch.float16``. For details, see :ref:`from_pretrained-torch-dtype`.
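
For example (a minimal sketch; ``t5-small`` is only an illustrative checkpoint, any model saved in fp16
works the same way):

.. code-block:: python

    import torch
    from transformers import AutoModelForSeq2SeqLM

    # instantiate the model directly in fp16 so no full fp32 copy is ever materialized
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", torch_dtype=torch.float16)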


Gathering Parameters
33 changes: 32 additions & 1 deletion docs/source/main_classes/model.rst
@@ -1,4 +1,4 @@
..
Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
@@ -38,6 +38,37 @@ PreTrainedModel
:members:


.. _from_pretrained-torch-dtype:

Model Instantiation dtype
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Under PyTorch a model normally gets instantiated with ``torch.float32`` format. This can be an issue if one tries to
load a model whose weights are in fp16, since it'd require twice as much memory. To overcome this limitation, you can
either explicitly pass the desired ``dtype`` using the ``torch_dtype`` argument:

.. code-block:: python

    model = T5ForConditionalGeneration.from_pretrained("t5", torch_dtype=torch.float16)

or, if you want the model to always load in the most optimal memory pattern, you can use the special value ``"auto"``,
and then ``dtype`` will be automatically derived from the model's weights:

.. code-block:: python

    model = T5ForConditionalGeneration.from_pretrained("t5", torch_dtype="auto")

Models instantiated from scratch can also be told which ``dtype`` to use with:

.. code-block:: python

    config = T5Config.from_pretrained("t5")
    model = AutoModel.from_config(config, torch_dtype=torch.float16)

Due to PyTorch design, this functionality is only available for floating dtypes.
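
As a concrete illustration of that restriction (a minimal sketch; the checkpoint name is only
illustrative), passing a non-floating ``dtype`` is rejected up front:

.. code-block:: python

    import torch
    from transformers import T5ForConditionalGeneration

    # raises ValueError: "Can't instantiate ... since it is not a floating point dtype"
    model = T5ForConditionalGeneration.from_pretrained("t5", torch_dtype=torch.int64)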



ModuleUtilsMixin
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

7 changes: 7 additions & 0 deletions src/transformers/configuration_utils.py
@@ -192,6 +192,12 @@ class PretrainedConfig(PushToHubMixin):
- **tie_word_embeddings** (:obj:`bool`, `optional`, defaults to :obj:`True`) -- Whether the model's input and
output word embeddings should be tied. Note that this is only relevant if the model has an output word
embedding layer.
- **torch_dtype** (:obj:`str`, `optional`) -- The :obj:`dtype` of the weights. This attribute can be used to
initialize the model to a non-default ``dtype`` (which is normally ``float32``) and thus allow for optimal
storage allocation. For example, if the saved model is ``float16``, ideally we want to load it back using the
minimal amount of memory needed to load ``float16`` weights. Since the config object is stored in plain text,
this attribute contains just the floating type string without the ``torch.`` prefix. For example, for
``torch.float16``, ``torch_dtype`` is the ``"float16"`` string.

TensorFlow specific parameters
@@ -207,6 +213,7 @@ def __init__(self, **kwargs):
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
self.output_attentions = kwargs.pop("output_attentions", False)
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
self.torch_dtype = kwargs.pop("torch_dtype", None) # Only used by PyTorch models
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
self.pruned_heads = kwargs.pop("pruned_heads", {})
self.tie_word_embeddings = kwargs.pop(
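
Since ``config.torch_dtype`` is stored as a plain string (e.g. ``"float16"``), converting it back to a real
``torch.dtype`` is a one-line attribute lookup. A standalone sketch of the round trip (the library doesn't
consult this attribute automatically yet; see the ``save_pretrained`` comment further down):

    import torch

    # forward direction, as save_pretrained does: torch.float16 -> "float16"
    dtype_str = str(torch.float16).split(".")[1]
    # reverse direction: "float16" -> torch.float16
    dtype = getattr(torch, dtype_str)
    assert dtype is torch.float16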
7 changes: 7 additions & 0 deletions src/transformers/modeling_flax_utils.py
@@ -111,6 +111,13 @@ def __init__(
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
raise NotImplementedError(f"init method has to be implemented for {self}")

@classmethod
def _from_config(cls, config, **kwargs):
"""
All context managers that the model should be initialized under go here.
"""
return cls(config, **kwargs)

@property
def config(self) -> PretrainedConfig:
return self._config
7 changes: 7 additions & 0 deletions src/transformers/modeling_tf_utils.py
@@ -643,6 +643,13 @@ def __init__(self, config, *inputs, **kwargs):
self.config = config
self.name_or_path = config.name_or_path

@classmethod
def _from_config(cls, config, **kwargs):
"""
All context managers that the model should be initialized under go here.
"""
return cls(config, **kwargs)

@tf.function(
input_signature=[
{
120 changes: 107 additions & 13 deletions src/transformers/modeling_utils.py
@@ -23,7 +23,7 @@
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
from torch import Tensor, device, dtype, nn
from torch import Tensor, device, nn
from torch.nn import CrossEntropyLoss

from .activations import get_activation
@@ -201,7 +201,7 @@ def device(self) -> device:
return get_parameter_device(self)

@property
def dtype(self) -> dtype:
def dtype(self) -> torch.dtype:
"""
:obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
"""
@@ -464,6 +464,66 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
self.config = config
self.name_or_path = config.name_or_path

@classmethod
def _from_config(cls, config, **kwargs):
"""
All context managers that the model should be initialized under go here.

Args:
torch_dtype (:obj:`torch.dtype`, `optional`):
Override the default ``torch.dtype`` and load the model under this dtype.
"""
torch_dtype = kwargs.pop("torch_dtype", None)

# override default dtype if needed
dtype_orig = None
if torch_dtype is not None:
dtype_orig = cls._set_default_torch_dtype(torch_dtype)

if is_deepspeed_zero3_enabled():
import deepspeed

logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
# this immediately partitions the model across all gpus, to avoid the overhead in time
# and memory copying it on CPU or each GPU first
with deepspeed.zero.Init(config=deepspeed_config()):
model = cls(config, **kwargs)
else:
model = cls(config, **kwargs)

# restore default dtype if it was modified
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

return model

@classmethod
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
"""
Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
under a specific dtype.

Args:
    dtype (:obj:`torch.dtype`):
        a floating dtype to set to.

Returns:
    :obj:`torch.dtype`: the original ``dtype`` that can be used to restore ``torch.set_default_dtype(dtype)``
    if it was modified. If it wasn't, returns :obj:`None`.

Note ``set_default_dtype`` currently only works with floating-point types and asserts if, for example,
``torch.int64`` is passed. So if a non-float ``dtype`` is passed this function will throw an exception.
"""
if not dtype.is_floating_point:
raise ValueError(
f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
)

logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
dtype_orig = torch.get_default_dtype()
torch.set_default_dtype(dtype)
return dtype_orig
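
For intuition, switching the default dtype is what makes every freshly constructed submodule allocate its
parameters directly in the target dtype. A standalone sketch (not part of the diff):

    import torch

    prev = torch.get_default_dtype()              # normally torch.float32
    torch.set_default_dtype(torch.float16)
    layer = torch.nn.Linear(4, 4)                 # parameters are created directly in fp16
    assert layer.weight.dtype == torch.float16
    torch.set_default_dtype(prev)                 # restore, mirroring what _from_config does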

@property
def base_model(self) -> nn.Module:
"""
@@ -876,6 +936,11 @@ def save_pretrained(
# Only save the model itself if we are using distributed training
model_to_save = unwrap_model(self)

# save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
# we currently don't use this setting automatically, but may start to use it with v5
dtype = get_parameter_dtype(model_to_save)
model_to_save.config.torch_dtype = str(dtype).split(".")[1]

# Attach architecture to the config
model_to_save.config.architectures = [model_to_save.__class__.__name__]

@@ -993,6 +1058,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
Please refer to the mirror site for more information.
_fast_init (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to disable fast initialization.
torch_dtype (:obj:`str` or :obj:`torch.dtype`, `optional`):
Override the default ``torch.dtype`` and load the model under this dtype. If ``"auto"`` is passed the
dtype will be automatically derived from the model's weights.
.. warning::
@@ -1058,6 +1126,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
_fast_init = kwargs.pop("_fast_init", True)
torch_dtype = kwargs.pop("torch_dtype", None)

from_pt = not (from_tf | from_flax)

user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
if from_pipeline is not None:
@@ -1162,6 +1233,34 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else:
resolved_archive_file = None

# load pt weights early so that we know which dtype to init the model under
if from_pt:
if state_dict is None:
try:
state_dict = torch.load(resolved_archive_file, map_location="cpu")
except Exception:
raise OSError(
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
f"at '{resolved_archive_file}'"
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
)

# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
# weights entry - we assume all weights are of the same dtype
# we also may have config.torch_dtype available, but we won't rely on it till v5
dtype_orig = None
if torch_dtype is not None:
if isinstance(torch_dtype, str):
if torch_dtype == "auto":
torch_dtype = next(iter(state_dict.values())).dtype
else:
raise ValueError(
f"`torch_dtype` can be either a `torch.dtype` or `auto`, but received {torch_dtype}"
)
dtype_orig = cls._set_default_torch_dtype(torch_dtype)

config.name_or_path = pretrained_model_name_or_path

# Instantiate model.
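
The ``"auto"`` branch above boils down to peeking at the first tensor of the already-loaded checkpoint.
A standalone sketch (``sd`` stands in for a loaded ``state_dict``):

    import torch

    sd = {"encoder.embed.weight": torch.ones(2, 2, dtype=torch.float16)}
    # all weights are assumed to share one dtype, so the first entry is enough
    detected = next(iter(sd.values())).dtype
    assert detected == torch.float16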
@@ -1178,6 +1277,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
with no_init_weights(_enable=_fast_init):
model = cls(config, *model_args, **model_kwargs)

if from_pt:
# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

if from_tf:
if resolved_archive_file.endswith(".index"):
# Load from a TensorFlow 1.X checkpoint - provided by original authors
@@ -1205,17 +1309,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
)
raise
else:
if state_dict is None:
try:
state_dict = torch.load(resolved_archive_file, map_location="cpu")
except Exception:
raise OSError(
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
f"at '{resolved_archive_file}'"
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
)

elif from_pt:
model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model(
model, state_dict, pretrained_model_name_or_path, _fast_init=_fast_init
)
13 changes: 2 additions & 11 deletions src/transformers/models/auto/auto_factory.py
@@ -17,7 +17,6 @@
import types

from ...configuration_utils import PretrainedConfig
from ...deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from ...file_utils import copy_func
from ...utils import logging
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
@@ -367,16 +366,8 @@ def __init__(self, *args, **kwargs):
def from_config(cls, config, **kwargs):
if type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
if is_deepspeed_zero3_enabled():
import deepspeed

logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
# this immediately partitions the model across all gpus, to avoid the overhead in time
# and memory copying it on CPU or each GPU first
with deepspeed.zero.Init(config=deepspeed_config()):
return model_class(config, **kwargs)
else:
return model_class(config, **kwargs)
return model_class._from_config(config, **kwargs)

raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."