How to load sharded checkpoints? #31
Could you try using our consolidation script for these smaller models, and try loading the consolidated checkpoint instead? Instructions here: https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/download_opt175b.md#reshard-the-shards
Hey @suchenzang, thanks for your answer. We did try out these scripts, but they don't work, for a couple of reasons. If you look into:
You can see that the function expects the filenames to follow a syntax that is different from the filenames of the 125m checkpoints, which are
Accordingly, if you run the command, you'll soon get the following error:
Also, more generally, this function cannot work as it never loads the correct module, and no metadata is stored anywhere in the 125m checkpoints. Could you try to load the 125m checkpoints and share the script here? It'd be immensely helpful for the community to understand how to load the flat 1D-array weights, I think :-) Thanks a lot!
More generally, the 125m checkpoints have no information about:
IMO this is required to load the flat 1D tensors correctly into a fairseq model. Do you know where we could find the shard_metadata information for the 125m and other checkpoints?
I only interacted with this code when it was in a branch off
@suchenzang @stephenroller any way that you guys could send us or open-source the
I've tried for quite some time now to reproduce the correct parameter mapping, without much success. It's not really stated on how many GPUs (
Also, there is one thing I don't fully understand:
It would be extremely useful if you guys could provide some kind of script that allows loading the sharded checkpoints on CPU :-)
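In the meantime, a minimal way to at least inspect the sharded files on CPU is plain `torch.load`. A sketch is below; the local folder name and the assumption that the weights sit under a top-level `"model"` key as flattened buffers may not match every checkpoint exactly:

```python
import torch

# Hypothetical local paths to the two 125m model-parallel parts.
paths = [
    "model/reshard-model_part-0.pt",
    "model/reshard-model_part-1.pt",
]

for p in paths:
    # map_location="cpu" lets us look at the checkpoint without a GPU.
    ckpt = torch.load(p, map_location="cpu")
    print(p, "top-level keys:", list(ckpt.keys()))

    # Assumption: weights live under a "model" key, stored as flat 1D buffers.
    for name, tensor in ckpt.get("model", {}).items():
        print(f"  {name}: {tuple(tensor.shape)}")
```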
Hi @patrickvonplaten, thanks for your code. I met the error that
Thanks!
Okay, I wrote #60 to help us out here. It outputs the full unflattened and non-sharded checkpoint and should be pretty easy to load into Hugging Face. See the docstring for usage.
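To give an idea of what to do with the output, here is a sketch of loading the consolidated file and checking that the parameters are back to their normal shapes. The filename `restored.pt` and the `"model"` key are assumptions; use whatever the script from #60 actually writes:

```python
import torch

# Placeholder name for the consolidated, unflattened checkpoint from #60.
ckpt = torch.load("restored.pt", map_location="cpu")

# fairseq/metaseq-style checkpoints usually keep the weights under "model".
state_dict = ckpt["model"]

# After consolidation these should be ordinary named tensors (2D weight
# matrices, 1D biases) rather than flattened FSDP buffers.
for name, tensor in list(state_dict.items())[:10]:
    print(name, tuple(tensor.shape))
```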
Keep in mind that metaseq, just like fairseq, actually prompts the model on the EOS token.
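So when comparing against another implementation, the prompt ids should start with the EOS token. A small illustration, using the stock `gpt2` tokenizer purely as a stand-in:

```python
import torch
from transformers import GPT2Tokenizer

# Stand-in tokenizer; the OPT checkpoints use a GPT-2-style BPE tokenizer.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

prompt = "Hello, my name is"

# metaseq/fairseq prepend EOS before the prompt, so do the same when you
# want generations/logits to line up.
input_ids = [tokenizer.eos_token_id] + tokenizer.encode(prompt)
input_ids = torch.tensor([input_ids])
print(input_ids)
```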
❓ Questions and Help
After having set up the libraries as described in https://github.com/facebookresearch/metaseq/blob/main/docs/setup.md, it is possible to load the 350m checkpoint (since it is not sharded) as follows:
Next, we need to comment out one line in the Megatron-LM library which is only relevant for training (initializing different random seeds across pp ranks):
Comment out this line: https://github.com/ngoyal2707/Megatron-LM/blob/ae0b844c1f6725c3433a95e42cac760b3885170b/megatron/initialize.py#L65 in your local clone of Megatron-LM.
Now we write the following Python script to a `run_model.py` file:
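A minimal sketch of what such a `run_model.py` can look like; the model dimensions below follow the 350m configuration from the OPT paper, while the `arg_overrides` keys and the launch requirements are assumptions that may need adjusting:

```python
import os

from megatron.initialize import initialize_megatron
from metaseq import checkpoint_utils

path = "./350m"  # hypothetical folder holding reshard.pt, vocab.json, merges.txt

# 350m dimensions (24 layers, hidden size 1024, 16 heads), passed as
# Megatron defaults so initialization doesn't require a full CLI config.
initialize_megatron(args_defaults={
    "micro_batch_size": 1,
    "num_layers": 24,
    "hidden_size": 1024,
    "num_attention_heads": 16,
    "max_position_embeddings": 2048,
    "encoder_seq_length": 2048,
})

# Assumption: the tokenizer files are wired in via these arg_overrides keys.
models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
    [os.path.join(path, "reshard.pt")],
    arg_overrides={
        "vocab_filename": os.path.join(path, "vocab.json"),
        "merges_filename": os.path.join(path, "merges.txt"),
    },
)
model = models[0].eval()
print(model)
```

Depending on the setup, `initialize_megatron` may expect a distributed environment (i.e. being launched the way metaseq normally launches scripts), so treat this as a sketch of the flow rather than a turnkey script.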
Problem
This only works for the 350m checkpoint!!! For the other checkpoints this doesn't work.
E.g. when replacing `[os.path.join(path, "reshard.pt")]` with `[os.path.join(path, "reshard-model_part-0.pt"), os.path.join(path, "reshard-model_part-1.pt")]` (part-0 and part-1 of the 125M model), we're getting an error because the weights are all flattened into 1D arrays.
Using #29 sadly also doesn't help, since the checkpoints don't seem to be in the `*shard*` format required by metaseq/metaseq/distributed/stitch_fsdp_ckpt.py (line 45 at 48b9b6c).
The parameter flattening seems to come from FairScale, and we've found some functionality to unflatten it here: https://github.com/facebookresearch/fairscale/blob/51b53ddb6c3aa77426c7d5cc0b543b79628053c4/fairscale/nn/misc/flatten_params_wrapper.py#L358, but we can't quite wrap our heads around how to make it work exactly.
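For reference, the basic idea behind unflattening is just slicing the flat 1D buffer back into per-parameter views using each parameter's shape, which is why per-parameter names and shapes (the shard metadata mentioned above) are needed in the first place. A small sketch of the principle, with made-up shapes rather than metaseq's actual metadata format:

```python
import torch

# Hypothetical ordered parameter names and shapes; real FSDP shard metadata
# also records padding, but the slicing idea is the same.
param_shapes = {
    "decoder.embed_tokens.weight": (50272, 768),
    "decoder.layers.0.self_attn.q_proj.weight": (768, 768),
    "decoder.layers.0.self_attn.q_proj.bias": (768,),
}

# A flat 1D buffer, as stored by FlattenParamsWrapper / FSDP.
total = sum(torch.Size(s).numel() for s in param_shapes.values())
flat_param = torch.randn(total)

# Slice the flat buffer by each parameter's numel and reshape the views.
state_dict = {}
offset = 0
for name, shape in param_shapes.items():
    numel = torch.Size(shape).numel()
    state_dict[name] = flat_param[offset:offset + numel].view(shape)
    offset += numel

for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
```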
@stephenroller @suchenzang @zhiqwang - any pointers on how we could load the 125M model (and the others) into a `model` instance of metaseq?