[modeling utils] revamp from_pretrained(..., low_cpu_mem_usage=True) + tests #16657

Merged · 17 commits · Apr 15, 2022
.circleci/config.yml (1 addition, 1 deletion)
@@ -217,7 +217,7 @@ jobs:
keys:
- v0.4-torch-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng time
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
src/transformers/modeling_utils.py (147 additions, 99 deletions)
@@ -400,6 +400,95 @@ def load(module: nn.Module, prefix=""):
return error_msgs


def find_submodule_and_param_name(model, long_key, start_prefix):
"""
A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied, it will be removed
from the start of the key.
"""

if len(start_prefix) > 0 and long_key.startswith(start_prefix):
long_key = ".".join(long_key.split(".")[1:])

split_key = long_key.split(".")
submodule = model
while len(split_key) > 1:
if hasattr(submodule, split_key[0]):
submodule = getattr(submodule, split_key[0])
del split_key[0]
else:
submodule = None
break
if submodule == model:
submodule = None
return submodule, split_key[0]
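
As an editorial illustration (not part of the diff): a minimal sketch of how this helper resolves a dotted key, assuming it is importable from `transformers.modeling_utils` once this change is applied; the toy model below is hypothetical and only mirrors the `bert.pooler.dense.weight` example used in the docstrings.

```python
from torch import nn
from transformers.modeling_utils import find_submodule_and_param_name  # assumes this PR is applied

# Toy stand-ins for a model with a `pooler.dense` submodule.
class Pooler(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4, 4)

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.pooler = Pooler()

model = TinyModel()

# "bert." is stripped via start_prefix, then the key is walked one attribute at a time.
submodule, param_name = find_submodule_and_param_name(model, "bert.pooler.dense.weight", "bert.")
assert submodule is model.pooler.dense and param_name == "weight"

# If no matching submodule path exists, the helper returns None for the submodule.
submodule, param_name = find_submodule_and_param_name(model, "bert.encoder.layer.0.weight", "bert.")
assert submodule is None
```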


def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
"""
Moves the params/buffers named in `loaded_state_dict_keys` to the meta device, which frees up the memory taken by those params.

`start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
`bert.pooler.dense.weight`

"""

# meta device was added in pt=1.9
require_version_core("torch>=1.9")

# dematerialize param storage for keys that are going to be replaced by state_dict, by
# putting those on the meta device
for k in loaded_state_dict_keys:
submodule, param_name = find_submodule_and_param_name(model, k, start_prefix)
if submodule is not None:
# selectively switch to the meta device only those params/buffers that will
# be replaced next from the state_dict. This is a complex way to do p.to_("meta")
# since we have no in-place to_ for tensors.
new_val = getattr(submodule, param_name)
if isinstance(new_val, torch.nn.Parameter):
# isinstance returns False for Params on meta device, so switch after the check
new_val = torch.nn.Parameter(new_val.to("meta"))
else:
new_val = new_val.to("meta")
setattr(submodule, param_name, new_val)


def _load_state_dict_into_meta_model(model, state_dict, start_prefix, loaded_state_dict_keys):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
params back to the normal device, but only for `loaded_state_dict_keys`.

`start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
`bert.pooler.dense.weight`

"""

# XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
# - deepspeed zero 3 support
# - need to copy metadata if any - see _load_state_dict_into_model
# - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
# - is there a situation where some keys aren't in `loaded_state_dict_keys`, in which case
# they won't get loaded?

if is_deepspeed_zero3_enabled():
raise ValueError("low_cpu_mem_usage arg cannot currently be used with DeepSpeed ZeRO-3")

error_msgs = []

# materialize state_dict entries one by one on CPU
for k in loaded_state_dict_keys:
if k in state_dict:
submodule, param_name = find_submodule_and_param_name(model, k, start_prefix)
if submodule is not None:
param_dtype = getattr(submodule, param_name).dtype
new_val = state_dict[k].to(param_dtype)
if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
new_val = torch.nn.Parameter(new_val)
setattr(submodule, param_name, new_val)

return error_msgs
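
As an editorial aside, the meta-device trick that these two helpers implement can be seen end to end on a toy module; the module and sizes below are made up for illustration, and the meta device requires torch>=1.9.

```python
from torch import nn

# 1. start with a materialized module and a (cloned) checkpoint-like state_dict
linear = nn.Linear(1024, 1024)
state_dict = {k: v.clone() for k, v in linear.state_dict().items()}  # stands in for a loaded checkpoint

# 2. dematerialize: re-register each param as a meta-device Parameter; the original
#    storage is freed once nothing else references it
for name in ("weight", "bias"):
    setattr(linear, name, nn.Parameter(getattr(linear, name).to("meta")))

# 3. rematerialize: copy the checkpoint values back in, one tensor at a time
for name, tensor in state_dict.items():
    setattr(linear, name, nn.Parameter(tensor))

assert linear.weight.device.type == "cpu"
```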


class ModuleUtilsMixin:
"""
A few utilities for `torch.nn.Modules`, to be used as a mixin.
@@ -1528,7 +1617,24 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
>>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
>>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
>>> model = BertModel.from_pretrained("bert-base-uncased", from_flax=True)
```"""
```

* `low_cpu_mem_usage` algorithm:

This is an experimental option that loads the model using only ~1.x the model size in CPU memory.

Before it gets called we do:

1. save which state_dict keys we have
2. drop the state_dict before the model is created, since the latter takes 1x model size in CPU memory
3. after the model has been instantiated, switch all params/buffers that are going to be replaced
from the loaded state_dict to the meta device
4. load the state_dict a second time
5. replace the params/buffers from the state_dict

Currently it can't handle DeepSpeed ZeRO stage 3 and doesn't handle loading errors. (A usage sketch
follows this docstring.)

"""
config = kwargs.pop("config", None)
state_dict = kwargs.pop("state_dict", None)
cache_dir = kwargs.pop("cache_dir", None)
@@ -1777,6 +1883,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
if not is_sharded and state_dict is None:
# Time to load the checkpoint
state_dict = load_state_dict(resolved_archive_file)

# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
@@ -1800,13 +1907,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
)
dtype_orig = cls._set_default_torch_dtype(torch_dtype)

if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
loaded_state_dict_keys = [k for k in state_dict.keys()]
if low_cpu_mem_usage:
# save the keys
if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
loaded_state_dict_keys = [k for k in state_dict.keys()]
del state_dict # free CPU memory - will reload again later
state_dict = None

config.name_or_path = pretrained_model_name_or_path

@@ -1824,11 +1930,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
with no_init_weights(_enable=_fast_init):
model = cls(config, *model_args, **model_kwargs)

if from_pt:
# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

if from_tf:
if resolved_archive_file.endswith(".index"):
# Load from a TensorFlow 1.X checkpoint - provided by original authors
@@ -1858,18 +1959,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
raise
elif from_pt:

if low_cpu_mem_usage:
cls._load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file)
else:
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
)
# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=low_cpu_mem_usage,
)

# make sure token embedding weights are still tied if needed
model.tie_weights()
@@ -1893,16 +1997,17 @@ def _load_pretrained_model(
cls,
model,
state_dict,
loaded_keys,
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=False,
sharded_metadata=None,
_fast_init=True,
low_cpu_mem_usage=False,
):
# Retrieve missing & unexpected_keys
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
loaded_keys = list(state_dict.keys()) if state_dict is not None else sharded_metadata["all_checkpoint_keys"]
prefix = model.base_model_prefix

def _fix_key(key):
@@ -1993,9 +2098,12 @@ def _find_mismatched_keys(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]

return mismatched_keys

if low_cpu_mem_usage:
model_state_dict = None # free references to model's params to allow memory freeing
_move_model_to_meta(model, loaded_keys, start_prefix)

if state_dict is not None:
# Whole checkpoint
mismatched_keys = _find_mismatched_keys(
@@ -2008,7 +2116,8 @@ def _find_mismatched_keys(
)
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
else:
# Sharded checkpoint
# Sharded checkpoint or whole but low_cpu_mem_usage==True

# This should always be a list, but just to be sure.
if not isinstance(resolved_archive_file, list):
resolved_archive_file = [resolved_archive_file]
@@ -2017,6 +2126,10 @@ def _find_mismatched_keys(
mismatched_keys = []
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)

if low_cpu_mem_usage:
model_state_dict = model.state_dict()

# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys += _find_mismatched_keys(
@@ -2027,7 +2140,13 @@ def _find_mismatched_keys(
remove_prefix_from_model,
ignore_mismatched_sizes,
)
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)

if low_cpu_mem_usage:
error_msgs += _load_state_dict_into_meta_model(
model_to_load, state_dict, start_prefix, loaded_keys
)
else:
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)

if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
@@ -2091,77 +2210,6 @@ def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=Fal

return retrieved_modules

@staticmethod
def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file):
Contributor:

Note that this method is called in src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py - might be nice to change it to the standard one now.

Contributor Author @stas00 (Apr 14, 2022):

Nice catch, Patrick

It's all modular now, so if you agree we can add a convenience wrapper:

    @staticmethod
    def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file, start_prefix=""):
        """
        This is an experimental function that loads the model using ~1.x model size CPU memory

        Before it gets called we do:

        1. save which state_dict keys we have
        2. drop state_dict before model is created, since the latter takes 1x model size memory

        Here then we continue:

        3. switch to the meta device all params/buffers that are going to be replaced from the loaded state_dict
        4. load state_dict 2nd time
        5. replace the params/buffers from the state_dict

        Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed.
        """

        _move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
        state_dict = load_state_dict(resolved_archive_file)
        error_msgs = _load_state_dict_into_meta_model(model, state_dict, start_prefix, loaded_state_dict_keys)
        return error_msgs

which restores the original function.
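
For context, here is a hypothetical sketch of how a caller (such as a conversion script) could use the proposed wrapper; the model class, config, and checkpoint path below are placeholders chosen for illustration, not taken from the actual regnet script.

```python
import torch
from transformers import BertConfig, BertModel
from transformers.modeling_utils import no_init_weights

checkpoint_path = "pytorch_model.bin"  # placeholder path

# 1-2. record which keys the checkpoint has, then drop the state_dict before the model exists
state_dict = torch.load(checkpoint_path, map_location="cpu")
loaded_state_dict_keys = list(state_dict.keys())
del state_dict

# 3-5. instantiate without weight init, then let the wrapper meta-swap and reload the checkpoint
with no_init_weights():
    model = BertModel(BertConfig())
error_msgs = BertModel._load_pretrained_model_low_mem(model, loaded_state_dict_keys, checkpoint_path)
```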

Contributor Author:

and if so, how can I test src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py?

Contributor Author:

I went ahead and added it, so just need to test that conversion script once I know how.

Contributor:

Yeah, maybe it's a bit overkill to test the script since the model is huge, and conversion scripts aren't tested anyway 😅 I'd be fine with just changing the function and "trusting" that it works.

Collaborator:

We don't test conversion scripts. (And the conversion script shouldn't use a private method from modeling_utils; missed that in the review...)

Contributor Author @stas00 (Apr 14, 2022):

It probably indicates a need for a low-memory-usage "update model from state_dict" utility. Perhaps once it's exercised some more we can make it a public util function.

"""
This is an experimental function that loads the model using ~1.x model size CPU memory

Before it gets called we do:

1. save which state_dict keys we have
2. drop state_dict before model is created, since the latter takes 1x model size memory

Here then we continue:

3. switch to the meta device all params/buffers that are going to be replaced from the loaded state_dict
4. load state_dict 2nd time
5. replace the params/buffers from the state_dict

Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed.
"""
require_version_core("torch>=1.9")
if is_deepspeed_zero3_enabled():
raise ValueError("low_cpu_mem_usage arg cannot be used with DeepSpeed ZeRO-3")

# a helper util to find the last sub-module and the param/buffer name
def find_submodule_and_param_name(model, long_key):
split_key = long_key.split(".")
submodule = model
while len(split_key) > 1:
if hasattr(submodule, split_key[0]):
submodule = getattr(submodule, split_key[0])
del split_key[0]
else:
submodule = None
break
return submodule, split_key[0]

# dematerialize param storage for keys that are going to be replaced by state_dict, by
# putting those on the meta device
for k in loaded_state_dict_keys:
submodule, param_name = find_submodule_and_param_name(model, k)
if submodule is not None:
# selectively switch to the meta device only those params/buffers that will
# be next replaced from state_dict. This is a complex way to do p.to_("meta")
# since we have no in-place to_ for tensors.
new_val = getattr(submodule, param_name)
if isinstance(new_val, torch.nn.Parameter):
# isinstance returns False for Params on meta device, so switch after the check
new_val = torch.nn.Parameter(new_val.to("meta"))
else:
new_val = new_val.to("meta")
setattr(submodule, param_name, new_val)

# only now can load state_dict(s)
if not isinstance(resolved_archive_file, list):
resolved_archive_file = [resolved_archive_file]

for archive_file in resolved_archive_file:
state_dict = torch.load(archive_file, map_location="cpu")

# materialize state_dict entries one by one on CPU
for k in loaded_state_dict_keys:
if k in state_dict:
submodule, param_name = find_submodule_and_param_name(model, k)
if submodule is not None:
param_dtype = getattr(submodule, param_name).dtype
new_val = state_dict[k].to(param_dtype)
if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
new_val = torch.nn.Parameter(new_val)
setattr(submodule, param_name, new_val)

del state_dict

@classmethod
def register_for_auto_class(cls, auto_class="AutoModel"):
"""