CLI: convert sharded PT models #17959
Conversation
The documentation is not available anymore as the PR was closed or merged.
for path in pytorch_checkpoint_path:
    pt_path = os.path.abspath(path)
    logger.info(f"Loading PyTorch weights from {pt_path}")
    pt_state_dict.update(torch.load(pt_path, map_location="cpu"))
That's super nice 👍🏻
That is a nice first step, but ideally, we'd want to convert the shards one by one to avoid using too much RAM and be able to convert LLMs checkpoints without needing a battle station.
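A minimal, framework-agnostic sketch of the shard-by-shard idea suggested here (plain dicts stand in for torch state dicts, and all names are hypothetical, not the CLI's actual code):

```python
# Hedged sketch: convert one shard at a time instead of merging every
# shard into a single state dict first, so peak memory stays at roughly
# one shard rather than the whole model.

def convert_shard(shard):
    # Placeholder for the real per-shard PT -> TF weight conversion.
    return dict(shard)

def convert_sharded_checkpoint(shard_loaders):
    """Load, convert, and discard one shard at a time.

    `shard_loaders` is an iterable of zero-argument callables, each
    returning one shard's state dict (e.g. a closure over torch.load).
    """
    converted = {}
    for load in shard_loaders:
        shard = load()                 # materialize a single shard
        converted.update(convert_shard(shard))
        del shard                      # drop it before loading the next
    return converted
```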
haha yes, I had to spin up a machine with >100GB of RAM to convert the RegNets 😬
LGTM thanks for working on that!
LGTM to me too!
BTW could we add 2 tests,
TF shards -> PT probably won't work, but I will add the test for PT shards -> TF 👍
Nice improvements, thanks!
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
    # Load from a sharded PyTorch checkpoint
    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
    is_sharded = True
Nice addition, maybe we should also support loading from a remote sharded checkpoint with from_pt=True? (It should be its own PR if we decide to support this.)
* sharded conversion; add flag to control max hidden error
* better hidden name matching
* Add test: load TF from PT shards
* fix test (PT data must be local)
What does this PR do?
This PR adds a major upgrade and a minor change to the pt-to-tf CLI.

Major upgrade: we can now convert sharded PT models. It updates how from_pt loading works so that it can load from shards, and updates how the pt-to-tf CLI stores the models, so it uses sharding capabilities when needed.

Minor change: adds a flag to control the maximum admissible hidden layer error. It is relatively common to find models where the outputs from the PT and TF models are nearly the same, but the hidden layers have a larger mismatch. This flag allows us to temporarily increase the admissible error if the model seems to be behaving properly (for instance, all RegNet models had a hidden layer difference between 1e-4 and 1e-2, but the outputs were behaving properly).
Example of sharded TF model PR, using the updated tools: https://huggingface.co/facebook/regnet-y-10b-seer-in1k/discussions/1