Convert Orbax ckpt to HuggingFace #581
base: main
Conversation
Change LGTM, but without a test I'm skeptical. I don't think this needs to be tested exhaustively, but how could we test it somewhat?
@rwitten I have tested locally, but were you thinking of running this at a nightly cadence?
@A9isha Hello, do you by any chance have a script that does the opposite, converting HF to Orbax?
We have the script llama_or_mistral_ckpt.py to convert the original PyTorch Llama2 checkpoint that Meta provides into a MaxText checkpoint. You can see the usage here for Llama2-7b, for example.
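For reference, a hedged sketch of how that script is typically invoked (the flag names and paths below are assumptions drawn from the MaxText README, not from this thread, so check the script's argument parser before relying on them):

```shell
# Convert Meta's original Llama2-7b PyTorch checkpoint into a MaxText (Orbax) checkpoint.
# NOTE: --base-model-path, --maxtext-model-path, and --model-size are assumed flag
# names; the paths are placeholders.
python3 MaxText/llama_or_mistral_ckpt.py \
  --base-model-path /path/to/meta-llama2-7b \
  --maxtext-model-path gs://your-bucket/llama2-7b-maxtext \
  --model-size llama2-7b
```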
Hi @A9isha, I found two bugs in your conversion code; I have fixed them and validated the weights converted from the MaxText version of Llama3-8B against the HF ones. The first one is the
The second bug is related to Q and K. I understand it's easy to make mistakes here because the original LLaMA, LLaMA-HF, and MaxText all store these tensors differently; the correct approach is to first reverse to the original LLaMA weight layout and then convert to the HF layout:
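A hedged NumPy sketch of that Q/K re-layout (the function names and shapes are illustrative, not taken from the PR; HF's convert_llama_weights_to_hf.py applies an equivalent permutation of the rotary halves):

```python
import numpy as np

def permute_to_hf(w: np.ndarray, n_heads: int) -> np.ndarray:
    """Re-interleave the rotary halves of a Q/K weight from the original
    LLaMA layout to the HF layout. w has shape (dim1, dim2)."""
    dim1, dim2 = w.shape
    return (w.reshape(n_heads, dim1 // n_heads // 2, 2, dim2)
             .transpose(0, 2, 1, 3)
             .reshape(dim1, dim2))

def unpermute_from_hf(w: np.ndarray, n_heads: int) -> np.ndarray:
    """Inverse of permute_to_hf: recover the original LLaMA layout
    from an HF-style Q/K weight."""
    dim1, dim2 = w.shape
    return (w.reshape(n_heads, 2, dim1 // n_heads // 2, dim2)
             .transpose(0, 2, 1, 3)
             .reshape(dim1, dim2))
```

The round trip is the identity, which makes an easy sanity check when validating converted weights against a known-good HF checkpoint.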
I think this script is fine, and I have been using it quite a lot. It should be updated for Llama3.1 though (whenever that is merged). And maybe also the 70B models?
Any chance this can be merged, @A9isha?
No description provided.