Load sharded pt model to flax (from the hub) #18026

69 changes: 62 additions & 7 deletions src/transformers/modeling_flax_pytorch_utils.py
@@ -38,7 +38,9 @@
#####################


-def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, allow_missing_keys=False):
+def load_pytorch_checkpoint_in_flax_state_dict(
+    flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
+):
"""Load pytorch checkpoints in a flax model"""
try:
import torch # noqa: F401
@@ -50,14 +52,17 @@ def load_pytorch_checkpoint_in_flax_state_dict(
)
raise

-    pt_path = os.path.abspath(pytorch_checkpoint_path)
-    logger.info(f"Loading PyTorch weights from {pt_path}")
-
-    pt_state_dict = torch.load(pt_path, map_location="cpu")
-    logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
-
-    flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
+    if not is_sharded:
+        pt_path = os.path.abspath(pytorch_checkpoint_path)
+        logger.info(f"Loading PyTorch weights from {pt_path}")
+
+        pt_state_dict = torch.load(pt_path, map_location="cpu")
+        logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
+
+        flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
+    else:
+        # model is sharded and pytorch_checkpoint_path already contains the list of shard files
+        flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)
return flax_state_dict


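The new is_sharded flag changes what pytorch_checkpoint_path means: for a single-file checkpoint it is the path to the .bin file, while for a sharded checkpoint the caller passes the already-resolved list of shard files. A rough sketch of the two call shapes (paths and shard names are illustrative only):

# single-file checkpoint: path to the .bin file, is_sharded=False
state = load_pytorch_checkpoint_in_flax_state_dict(flax_model, "ckpt/pytorch_model.bin", False)

# sharded checkpoint: list of shard files already resolved from the index, is_sharded=True
shards = ["ckpt/pytorch_model-00001-of-00002.bin", "ckpt/pytorch_model-00002-of-00002.bin"]
state = load_pytorch_checkpoint_in_flax_state_dict(flax_model, shards, True)
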
@@ -156,6 +161,56 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
return unflatten_dict(flax_state_dict)


def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
import torch

# the index has already been resolved into a list of shard filenames by the caller; build the flax state dict shard by shard
flax_state_dict = {}
for shard_file in shard_filenames:
# load the shard with torch and convert the tensors to numpy arrays
pt_state_dict = torch.load(shard_file)
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

model_prefix = flax_model.base_model_prefix
random_flax_state_dict = flatten_dict(flax_model.params)

load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
)
load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
)
# Need to change some parameters name to match Flax names
for pt_key, pt_tensor in pt_state_dict.items():

pt_tuple_key = tuple(pt_key.split("."))

# remove base model prefix if necessary
has_base_model_prefix = pt_tuple_key[0] == model_prefix
if load_model_with_head_into_base_model and has_base_model_prefix:
pt_tuple_key = pt_tuple_key[1:]

# Correctly rename weight parameters
flax_key, flax_tensor = rename_key_and_reshape_tensor(
pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
)
# add model prefix if necessary
require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
if load_base_model_into_model_with_head and require_base_model_prefix:
flax_key = (model_prefix,) + flax_key

if flax_key in random_flax_state_dict:
if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
raise ValueError(
f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
)

# add the weight; unexpected keys are kept so that a warning can be raised later
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
return unflatten_dict(flax_state_dict)

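In from_pretrained the shard list is produced by get_checkpoint_shard_files from the downloaded index file. As a rough, local-only sketch of where that list comes from (the helper name and directory layout are hypothetical, not part of this patch), the index simply maps each parameter name to the shard file that stores it:

import json
import os

def gather_local_shard_files(checkpoint_dir):
    # hypothetical helper: read the sharded-checkpoint index and collect the unique shard files
    with open(os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")) as f:
        index = json.load(f)
    shard_names = sorted(set(index["weight_map"].values()))
    return [os.path.join(checkpoint_dir, name) for name in shard_names]

# the resulting list is what convert_pytorch_sharded_state_dict_to_flax expects as shard_filenames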

#####################
# Flax => PyTorch #
#####################
17 changes: 15 additions & 2 deletions src/transformers/modeling_flax_utils.py
@@ -42,6 +42,7 @@
FLAX_WEIGHTS_INDEX_NAME,
FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
EntryNotFoundError,
PushToHubMixin,
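
For reference, WEIGHTS_NAME and WEIGHTS_INDEX_NAME are the standard PyTorch checkpoint filenames, pytorch_model.bin and pytorch_model.bin.index.json respectively, so this import lets from_pretrained look for a sharded-checkpoint index next to the single-file weights.
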
@@ -58,7 +59,7 @@
logging,
replace_return_docstrings,
)
-from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
+from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files, hf_url_exists


logger = logging.get_logger(__name__)
@@ -639,6 +640,10 @@ def from_pretrained(
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
# Load from a sharded pytorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
is_sharded = True
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
@@ -660,6 +665,14 @@
)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
# check if an index file exists
elif hf_url_exists(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME):
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=WEIGHTS_INDEX_NAME,
revision=revision,
)
is_sharded = True
else:
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
archive_file = hf_bucket_url(
@@ -780,7 +793,7 @@ def from_pretrained(
model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)

if from_pt:
-    state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
+    state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
else:

if is_sharded:
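
With both branches in place, a Flax model can be instantiated straight from a Hub repo that only publishes sharded PyTorch weights, mirroring the test added below:

from transformers import FlaxBertModel

model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
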
5 changes: 5 additions & 0 deletions src/transformers/utils/hub.py
@@ -119,6 +119,11 @@ def is_remote_url(url_or_filename):
return parsed.scheme in ("http", "https")


def hf_url_exists(path, file):
r = requests.head(os.path.join("https://huggingface.co/", path, "raw/main/", file))
return r.status_code == requests.codes.ok

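hf_url_exists issues a HEAD request against the raw file URL on the main branch (for example https://huggingface.co/<repo_id>/raw/main/pytorch_model.bin.index.json) and treats a 200 response as "the file exists". A usage sketch, assuming this patch is applied (the repo id is the one used in the new test):

from transformers.utils import WEIGHTS_INDEX_NAME  # "pytorch_model.bin.index.json"
from transformers.utils.hub import hf_url_exists

# True if the repo publishes a sharded PyTorch checkpoint index on its main branch
repo_is_sharded = hf_url_exists("hf-internal-testing/tiny-random-bert-sharded", WEIGHTS_INDEX_NAME)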

def hf_bucket_url(
model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
) -> str:
8 changes: 8 additions & 0 deletions tests/test_modeling_flax_common.py
@@ -1099,6 +1099,14 @@ def test_checkpoint_sharding_local(self):
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
self.assertTrue(np.allclose(np.array(p1), np.array(p2)))

@is_pt_flax_cross_test
def test_from_sharded_pt(self):
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only")
for key, ref_val in flatten_dict(ref_model.params).items():
val = flatten_dict(model.params)[key]
assert np.allclose(np.array(val), np.array(ref_val))

def test_gradient_checkpointing(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
