
CLI: convert sharded PT models #17959

Merged: 4 commits into huggingface:main on Jun 30, 2022

Conversation

@gante (Member) commented on Jun 30, 2022:

What does this PR do?

This PR adds a major upgrade and a minor change to the pt-to-tf CLI.

Major upgrade: we can now convert sharded PT models. The PR updates from_pt loading so that it can read from shards, and it updates how the pt-to-tf CLI saves the converted models, sharding them when needed.

Minor change: adds a flag to control the maximum admissible hidden-layer error. It is relatively common to find models where the outputs of the PT and TF models match closely, but the hidden layers show a larger mismatch. This flag lets us temporarily raise the admissible error when the model otherwise behaves properly (for instance, all RegNet models had hidden-layer differences between 1e-4 and 1e-2, while their outputs matched).

Example of sharded TF model PR, using the updated tools: https://huggingface.co/facebook/regnet-y-10b-seer-in1k/discussions/1
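For illustration, a minimal sketch of the conversion path this PR enables, assuming a local directory containing a sharded PyTorch checkpoint (the paths are placeholders; TFAutoModel, from_pt, and save_pretrained are existing transformers APIs):

from transformers import TFAutoModel

# With this PR, from_pt=True can also resolve a local *sharded* PyTorch
# checkpoint (pytorch_model.bin.index.json plus shard files) instead of
# requiring a single pytorch_model.bin.
tf_model = TFAutoModel.from_pretrained("./local-pt-checkpoint", from_pt=True)

# Saving uses the TF sharding capabilities when the weights are large enough.
tf_model.save_pretrained("./converted-tf-checkpoint")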

@HuggingFaceDocBuilderDev commented on Jun 30, 2022:

The documentation is not available anymore as the PR was closed or merged.

Comment on lines +126 to +129
# Merge every shard into a single state dict; all shards end up in memory at once.
for path in pytorch_checkpoint_path:
    pt_path = os.path.abspath(path)
    logger.info(f"Loading PyTorch weights from {pt_path}")
    pt_state_dict.update(torch.load(pt_path, map_location="cpu"))
Collaborator:

That's super nice 👍🏻

Collaborator:

That is a nice first step, but ideally we'd want to convert the shards one by one, to avoid using too much RAM and to be able to convert LLM checkpoints without needing a battle station.

@gante (Member, Author) replied on Jun 30, 2022:

haha yes, I had to spin up a machine with >100GB of RAM to convert the RegNets 😬
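
As an illustration of the shard-by-shard idea suggested above, not what this PR implements: a sketch that walks the shard index and yields one shard's state dict at a time, keeping peak memory near the size of a single shard. The index filename matches transformers' WEIGHTS_INDEX_NAME; the helper name is hypothetical.

import gc
import json
import os

import torch

def iter_pt_shards(checkpoint_dir):
    # Hypothetical helper: yield shards one by one so the fully merged
    # state dict never has to live in memory.
    index_path = os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]
    for shard_file in sorted(set(weight_map.values())):
        state_dict = torch.load(os.path.join(checkpoint_dir, shard_file), map_location="cpu")
        yield state_dict
        del state_dict
        gc.collect()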

@ArthurZucker (Collaborator) left a review:

LGTM, thanks for working on that!

@Rocketknight1 (Member) left a review:

LGTM to me too!

@ArthurZucker (Collaborator) commented:
BTW, could we add 2 tests, test_load_sharded_tf_to_pt and load_sharded_pt_to_tf?

@gante (Member, Author) commented on Jun 30, 2022:

TF shards -> PT probably won't work, but I will add the test for PT shards -> TF 👍
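
A rough sketch of what that test could look like; the model choice, sizes, and shard size below are illustrative, not the PR's actual test code:

import tempfile

from transformers import BertConfig, BertModel, TFBertModel

def test_load_sharded_pt_to_tf():
    # Tiny config keeps the test fast; a small max_shard_size forces the PT
    # checkpoint to be written as multiple shards plus an index file.
    config = BertConfig(
        vocab_size=100, hidden_size=32, num_hidden_layers=2,
        num_attention_heads=2, intermediate_size=64,
    )
    pt_model = BertModel(config)
    with tempfile.TemporaryDirectory() as tmp_dir:
        pt_model.save_pretrained(tmp_dir, max_shard_size="50kB")
        # The new code path: build the TF model from the sharded PT checkpoint.
        tf_model = TFBertModel.from_pretrained(tmp_dir, from_pt=True)
        assert tf_model is not None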

@sgugger (Collaborator) left a review:

Nice improvements, thanks!


Comment on lines +2161 to +2164
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
    # Load from a sharded PyTorch checkpoint
    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
    is_sharded = True
Collaborator:
Nice addition, maybe we should also support loading from a remote sharded checkpoint with from_pt=True? (It should be its own PR if we decide to support this.)
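
For context on the branch above: WEIGHTS_INDEX_NAME is the shard index file, pytorch_model.bin.index.json, which maps each weight name to the shard file that stores it. Roughly, its contents look like the following (weight names and sizes below are made up):

index = {
    "metadata": {"total_size": 39802168320},
    "weight_map": {
        "embedder.convolution.weight": "pytorch_model-00001-of-00002.bin",
        "classifier.1.weight": "pytorch_model-00002-of-00002.bin",
    },
}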

@gante merged commit 91e1f24 into huggingface:main on Jun 30, 2022.
@gante deleted the sharded_tf_conversion branch on Jun 30, 2022, at 15:51.
viclzhu pushed a commit to viclzhu/transformers that referenced this pull request on Jul 18, 2022:
* sharded conversion; add flag to control max hidden error

* better hidden name matching

* Add test: load TF from PT shards

* fix test (PT data must be local)