Load sharded pt to flax #18419

Merged on Aug 12, 2022 (22 commits from ArthurZucker's convert_sharded_pt_to_flax branch into huggingface:main).

Commits (22):
0a0d875 initial commit (ArthurZucker, Jul 5, 2022)
4dcc32e add small test (ArthurZucker, Jul 5, 2022)
556fa9a add cross pt tf flag to test (ArthurZucker, Jul 5, 2022)
2e8b241 fix quality (ArthurZucker, Jul 5, 2022)
cfda646 style (ArthurZucker, Jul 5, 2022)
4e49a48 update test with new repo (ArthurZucker, Jul 5, 2022)
74a451b fix failing test (ArthurZucker, Aug 2, 2022)
708bcb6 update (ArthurZucker, Aug 2, 2022)
b9a9baf Merge branch 'main' of https://github.com/huggingface/transformers in… (ArthurZucker, Aug 2, 2022)
c49bb18 fix wrong param ordering (ArthurZucker, Aug 2, 2022)
e20f971 style (ArthurZucker, Aug 2, 2022)
fdc54b6 Merge branch 'main' of https://github.com/huggingface/transformers in… (ArthurZucker, Aug 2, 2022)
7495554 Merge branch 'huggingface:main' into convert_sharded_pt_to_flax (ArthurZucker, Aug 2, 2022)
7f83b7d Merge branch 'convert_sharded_pt_to_flax' of https://github.com/Arthu… (ArthurZucker, Aug 11, 2022)
0e037a4 update based on review (ArthurZucker, Aug 11, 2022)
4e5c03c Merge branch 'main' of https://github.com/huggingface/transformers in… (ArthurZucker, Aug 11, 2022)
2847af1 update related to recent new caching mechanism (ArthurZucker, Aug 11, 2022)
55e8f7c quality (ArthurZucker, Aug 11, 2022)
f871230 Update based on review (ArthurZucker, Aug 11, 2022)
b584707 quality and style (ArthurZucker, Aug 11, 2022)
0957399 Update src/transformers/modeling_flax_utils.py (ArthurZucker, Aug 11, 2022)
a43af40 Merge branch 'main' of https://github.com/huggingface/transformers in… (ArthurZucker, Aug 11, 2022)
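For readers skimming the diff below: the end-to-end behavior this PR adds is loading a PyTorch checkpoint that was saved in shards (a pytorch_model.bin.index.json index plus pytorch_model-XXXXX-of-XXXXX.bin files) directly into a Flax model. A minimal sketch of the new capability, using the repo from the test added at the bottom of this PR:

```python
from transformers import FlaxBertModel

# Before this PR, from_pt=True only handled a single pytorch_model.bin file;
# with it, each shard is converted and merged into one Flax state dict.
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
```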
74 changes: 67 additions & 7 deletions src/transformers/modeling_flax_pytorch_utils.py
@@ -38,7 +38,9 @@
 #####################


-def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, allow_missing_keys=False):
+def load_pytorch_checkpoint_in_flax_state_dict(
+    flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
+):
     """Load pytorch checkpoints in a flax model"""
     try:
         import torch  # noqa: F401
@@ -50,14 +52,17 @@ def load_pytorch_checkpoint_in_flax_state_dict(
         )
         raise

-    pt_path = os.path.abspath(pytorch_checkpoint_path)
-    logger.info(f"Loading PyTorch weights from {pt_path}")
+    if not is_sharded:
+        pt_path = os.path.abspath(pytorch_checkpoint_path)
+        logger.info(f"Loading PyTorch weights from {pt_path}")

-    pt_state_dict = torch.load(pt_path, map_location="cpu")
-    logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
+        pt_state_dict = torch.load(pt_path, map_location="cpu")
+        logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")

-    flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
+        flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
+    else:
+        # model is sharded and pytorch_checkpoint_path already contains the list of .pt shard files
+        flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)
     return flax_state_dict


@@ -156,6 +161,61 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
     return unflatten_dict(flax_state_dict)


+############################
+# Sharded Pytorch => Flax #
+############################
+
+
+def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
+    import torch
+
+    # Load each shard and accumulate the converted weights
+    flax_state_dict = {}
+    for shard_file in shard_filenames:
+        # load the shard with torch (shards are regular PyTorch checkpoints)
+        pt_state_dict = torch.load(shard_file)
+        pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+
+        model_prefix = flax_model.base_model_prefix
+        random_flax_state_dict = flatten_dict(flax_model.params)
+
+        load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
+            model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        )
+        load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
+            model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        )
+        # Need to change some parameter names to match Flax names
+        for pt_key, pt_tensor in pt_state_dict.items():
+
+            pt_tuple_key = tuple(pt_key.split("."))
+
+            # remove base model prefix if necessary
+            has_base_model_prefix = pt_tuple_key[0] == model_prefix
+            if load_model_with_head_into_base_model and has_base_model_prefix:
+                pt_tuple_key = pt_tuple_key[1:]
+
+            # Correctly rename weight parameters
+            flax_key, flax_tensor = rename_key_and_reshape_tensor(
+                pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
+            )
+            # add model prefix if necessary
+            require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
+            if load_base_model_into_model_with_head and require_base_model_prefix:
+                flax_key = (model_prefix,) + flax_key
+
+            if flax_key in random_flax_state_dict:
+                if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
+                    raise ValueError(
+                        f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
+                        f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
+                    )
+
+            # also add unexpected weights so that a warning is raised
+            flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
+    return unflatten_dict(flax_state_dict)
+
+
 #####################
 # Flax => PyTorch #
 #####################
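For orientation, a note on where `shard_filenames` comes from: `from_pretrained` resolves the pytorch_model.bin.index.json index into a list of local shard paths before calling the converter. The helper below is a simplified stand-in to show the shape of that resolution, not the library's actual API (`resolve_shard_files` is a hypothetical name):

```python
import json
import os

def resolve_shard_files(checkpoint_dir):
    # The index file maps each parameter name to the shard that stores it.
    index_path = os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")
    with open(index_path) as f:
        index = json.load(f)
    # Deduplicate and sort so each .bin shard is loaded exactly once.
    shard_names = sorted(set(index["weight_map"].values()))
    return [os.path.join(checkpoint_dir, name) for name in shard_names]
```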
20 changes: 19 additions & 1 deletion src/transformers/modeling_flax_utils.py
@@ -40,6 +40,7 @@
 from .utils import (
     FLAX_WEIGHTS_INDEX_NAME,
     FLAX_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     PushToHubMixin,
     add_code_sample_docstrings,
@@ -650,6 +651,10 @@ def from_pretrained(
             if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                 # Load from a PyTorch checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+            elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
+                # Load from a sharded pytorch checkpoint
+                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
+                is_sharded = True
             elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
                 # Load from a Flax checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
@@ -700,6 +705,13 @@ def from_pretrained(
                     )
                     if resolved_archive_file is not None:
                         is_sharded = True
+                # Maybe the checkpoint is sharded PyTorch; in that case, try to grab the PyTorch index file.
+                elif resolved_archive_file is None and from_pt:
+                    resolved_archive_file = cached_file(
+                        pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
+                    )
+                    if resolved_archive_file is not None:
+                        is_sharded = True
                 if resolved_archive_file is None:
                     # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
                     # message.
@@ -714,6 +726,12 @@ def from_pretrained(
                             f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
                             " load this model from those weights."
                         )
+                    elif has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs):
+                        raise EnvironmentError(
+                            f"{pretrained_model_name_or_path} does not appear to have a file named"
+                            f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use"
+                            " `from_pt=True` to load this model from those weights."
+                        )
                     else:
                         raise EnvironmentError(
                             f"{pretrained_model_name_or_path} does not appear to have a file named"
@@ -761,7 +779,7 @@ def from_pretrained(
             model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)

         if from_pt:
-            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
+            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
         else:

             if is_sharded:
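One user-visible consequence of the new error branch: requesting Flax weights from a repo that only hosts a sharded PyTorch checkpoint should now fail with a pointer to `from_pt=True` instead of a generic missing-file error. A rough sketch, assuming the test repo below contains only sharded PyTorch weights:

```python
from transformers import FlaxBertModel

try:
    # No from_pt=True: the loader looks for Flax weights, finds only a
    # sharded PyTorch checkpoint, and hits the new EnvironmentError branch.
    FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
except EnvironmentError as err:
    print(err)  # message suggests passing from_pt=True
```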
8 changes: 8 additions & 0 deletions tests/test_modeling_flax_common.py
@@ -1099,6 +1099,14 @@ def test_checkpoint_sharding_local(self):
         for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
             self.assertTrue(np.allclose(np.array(p1), np.array(p2)))

+    @is_pt_flax_cross_test
+    def test_from_sharded_pt(self):
+        model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
+        ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only")
+        for key, ref_val in flatten_dict(ref_model.params).items():
+            val = flatten_dict(model.params)[key]
+            assert np.allclose(np.array(val), np.array(ref_val))
+
     def test_gradient_checkpointing(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
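The new test only runs in the PT/Flax cross-testing job (the `@is_pt_flax_cross_test` gate). A standalone version of the same equivalence check, outside the test harness, might look like this; it mirrors the test body above and requires both torch and flax installed:

```python
import numpy as np
from flax.traverse_util import flatten_dict
from transformers import FlaxBertModel

# Convert the sharded PyTorch checkpoint on the fly...
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
# ...and compare against a reference saved natively in Flax.
ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only")

for key, ref_val in flatten_dict(ref_model.params).items():
    val = flatten_dict(model.params)[key]
    assert np.allclose(np.array(val), np.array(ref_val)), f"mismatch at {key}"
```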