Commit

Load sharded pt to flax (#18419)
* initial commit

* add small test

* add cross pt tf flag to test

* fix quality

* style

* update test with new repo

* fix failing test

* update

* fix wrong param ordering

* style

* update based on review

* update related to recent new caching mechanism

* quality

* Update based on review

Co-authored-by: sgugger <sylvain.gugger@gmail.com>

* quality and style

* Update src/transformers/modeling_flax_utils.py
Co-authored-by: sgugger <sylvain.gugger@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
3 people authored Aug 12, 2022
1 parent c8b6ae8 commit bce36ee
Showing 3 changed files with 94 additions and 8 deletions.
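For context, this change lets `FlaxPreTrainedModel.from_pretrained` consume a PyTorch checkpoint that is split into several shards. A minimal usage sketch (the repo name is the tiny sharded checkpoint referenced by the new test below; both torch and flax need to be installed):

```python
# Sketch: load a sharded PyTorch checkpoint directly into a Flax model.
from transformers import FlaxBertModel

# `from_pt=True` selects the PyTorch -> Flax conversion path; with this commit it also
# handles checkpoints stored as several pytorch_model-*.bin shards plus an index file.
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
```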
74 changes: 67 additions & 7 deletions src/transformers/modeling_flax_pytorch_utils.py
@@ -38,7 +38,9 @@
#####################


-def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, allow_missing_keys=False):
+def load_pytorch_checkpoint_in_flax_state_dict(
+    flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
+):
    """Load pytorch checkpoints in a flax model"""
    try:
        import torch  # noqa: F401
@@ -50,14 +52,17 @@ def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, allow_missing_keys=False):
        )
        raise

-    pt_path = os.path.abspath(pytorch_checkpoint_path)
-    logger.info(f"Loading PyTorch weights from {pt_path}")
+    if not is_sharded:
+        pt_path = os.path.abspath(pytorch_checkpoint_path)
+        logger.info(f"Loading PyTorch weights from {pt_path}")

-    pt_state_dict = torch.load(pt_path, map_location="cpu")
-    logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")

-    flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
+        pt_state_dict = torch.load(pt_path, map_location="cpu")
+        logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")

+        flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
+    else:
+        # model is sharded and pytorch_checkpoint_path already contains the list of .pt shard files
+        flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)
    return flax_state_dict


@@ -156,6 +161,61 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    return unflatten_dict(flax_state_dict)


+############################
+# Sharded Pytorch => Flax #
+############################


+def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
+    import torch

+    # Load the index
+    flax_state_dict = {}
+    for shard_file in shard_filenames:
+        # load each shard with torch, then convert the tensors to numpy
+        pt_state_dict = torch.load(shard_file)
+        pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

+        model_prefix = flax_model.base_model_prefix
+        random_flax_state_dict = flatten_dict(flax_model.params)

+        load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
+            model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        )
+        load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
+            model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        )
+        # Need to change some parameters name to match Flax names
+        for pt_key, pt_tensor in pt_state_dict.items():

+            pt_tuple_key = tuple(pt_key.split("."))

+            # remove base model prefix if necessary
+            has_base_model_prefix = pt_tuple_key[0] == model_prefix
+            if load_model_with_head_into_base_model and has_base_model_prefix:
+                pt_tuple_key = pt_tuple_key[1:]

+            # Correctly rename weight parameters
+            flax_key, flax_tensor = rename_key_and_reshape_tensor(
+                pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
+            )
+            # add model prefix if necessary
+            require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
+            if load_base_model_into_model_with_head and require_base_model_prefix:
+                flax_key = (model_prefix,) + flax_key

+            if flax_key in random_flax_state_dict:
+                if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
+                    raise ValueError(
+                        f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
+                        f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
+                    )

+            # also add unexpected weight so that warning is thrown
+            flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
+    return unflatten_dict(flax_state_dict)


#####################
# Flax => PyTorch   #
#####################
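`convert_pytorch_sharded_state_dict_to_flax` expects an already-resolved list of shard paths; resolving a checkpoint into that list is handled by `from_pretrained` (next file). As a hedged illustration of where such a list could come from, the sketch below reads a local `pytorch_model.bin.index.json` and collects the shard files it references (the helper name is invented for the example):

```python
# Illustrative only: enumerate the shard files referenced by a sharded PyTorch checkpoint.
import json
import os


def list_pytorch_shards(checkpoint_dir, index_name="pytorch_model.bin.index.json"):
    """Return the sorted, de-duplicated shard paths listed in the checkpoint index."""
    with open(os.path.join(checkpoint_dir, index_name), encoding="utf-8") as f:
        index = json.load(f)
    # "weight_map" maps each parameter name to the shard file that stores it.
    shard_names = sorted(set(index["weight_map"].values()))
    return [os.path.join(checkpoint_dir, name) for name in shard_names]
```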
20 changes: 19 additions & 1 deletion src/transformers/modeling_flax_utils.py
@@ -40,6 +40,7 @@
from .utils import (
    FLAX_WEIGHTS_INDEX_NAME,
    FLAX_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    PushToHubMixin,
    add_code_sample_docstrings,
@@ -650,6 +651,10 @@ def from_pretrained(
                if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+                elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
+                    # Load from a sharded pytorch checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
+                    is_sharded = True
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
                    # Load from a Flax checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
@@ -700,6 +705,13 @@ def from_pretrained(
                        )
                        if resolved_archive_file is not None:
                            is_sharded = True
+                    # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case.
+                    elif resolved_archive_file is None and from_pt:
+                        resolved_archive_file = cached_file(
+                            pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
+                        )
+                        if resolved_archive_file is not None:
+                            is_sharded = True
                    if resolved_archive_file is None:
                        # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
                        # message.
@@ -714,6 +726,12 @@
f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
" load this model from those weights."
)
elif has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use"
" `from_pt=True` to load this model from those weights."
)
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named"
@@ -761,7 +779,7 @@ def from_pretrained(
        model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)

        if from_pt:
-            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
+            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
        else:

            if is_sharded:
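The three local-file branches above give a precedence order: an unsharded PyTorch file, then a PyTorch shard index, then a Flax file. A simplified, hedged sketch of that ordering (a standalone rewrite, not the library code; the file-name constants follow transformers' conventions):

```python
# Simplified sketch of the local checkpoint resolution order added in this commit.
import os

WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"


def pick_local_archive(model_dir, from_pt):
    """Return (archive_file, is_sharded) following the order checked in from_pretrained."""
    if from_pt and os.path.isfile(os.path.join(model_dir, WEIGHTS_NAME)):
        return os.path.join(model_dir, WEIGHTS_NAME), False
    if from_pt and os.path.isfile(os.path.join(model_dir, WEIGHTS_INDEX_NAME)):
        # A sharded PyTorch checkpoint: the index file stands in for the archive.
        return os.path.join(model_dir, WEIGHTS_INDEX_NAME), True
    if os.path.isfile(os.path.join(model_dir, FLAX_WEIGHTS_NAME)):
        return os.path.join(model_dir, FLAX_WEIGHTS_NAME), False
    raise EnvironmentError(f"No PyTorch or Flax weights found in {model_dir}")
```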
8 changes: 8 additions & 0 deletions tests/test_modeling_flax_common.py
@@ -1099,6 +1099,14 @@ def test_checkpoint_sharding_local(self):
                for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
                    self.assertTrue(np.allclose(np.array(p1), np.array(p2)))

+    @is_pt_flax_cross_test
+    def test_from_sharded_pt(self):
+        model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
+        ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only")
+        for key, ref_val in flatten_dict(ref_model.params).items():
+            val = flatten_dict(model.params)[key]
+            assert np.allclose(np.array(val), np.array(ref_val))

    def test_gradient_checkpointing(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
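The new test relies on a pre-sharded Hub repo, and the `@is_pt_flax_cross_test` marker keeps it out of single-framework test runs. A hedged local variant of the same round trip, forcing sharding with `save_pretrained`'s `max_shard_size` (the shard size and model names are illustrative; both torch and flax are required):

```python
# Sketch: shard a tiny PyTorch checkpoint locally, then reload it as Flax and compare.
import tempfile

import numpy as np
from flax.traverse_util import flatten_dict
from transformers import BertModel, FlaxBertModel

pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)

with tempfile.TemporaryDirectory() as tmp_dir:
    # A small max_shard_size forces several pytorch_model-*.bin shards plus an index file.
    pt_model.save_pretrained(tmp_dir, max_shard_size="50kB")
    model = FlaxBertModel.from_pretrained(tmp_dir, from_pt=True)

for key, ref_val in flatten_dict(ref_model.params).items():
    assert np.allclose(np.array(flatten_dict(model.params)[key]), np.array(ref_val))
```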
