
Changes required to save_model for certain models (e.g., Phi 3.5 Vision) #34690

Open
jjbuck opened this issue Nov 11, 2024 · 4 comments

@jjbuck

jjbuck commented Nov 11, 2024

Feature request

This request proposes one of three changes (see Motivation for background, and Your contribution for more thoughts on possible solutions) to allow saving a certain class of models, including but not limited to Phi 3.5 Vision.

  1. Accept a state_dict argument in the Trainer class's save_model() method (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L3719-L3768). This state_dict parameter should then be passed down to the call to the private _save() method (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L3842), which does accept a state_dict argument.
  2. Rather than adding state_dict as an argument to save_model(), determine an appropriate heuristic so that Phi 3.5 Vision and other architecturally similar models can be saved successfully.
  3. Some change to the way transformers handles shared tensors...?

Motivation

I encountered an issue while trying to fine-tune Phi 3.5 Vision using the Trainer class from transformers. In particular, when calling the Trainer's save_model() or the model's save_pretrained(), transformers throws the following error:

RuntimeError: The weights trying to be saved contained shared tensors [{'model.vision_embed_tokens.wte.weight', 'model.embed_tokens.weight'}] that are mismatching the transformers base configuration.
Try saving using `safe_serialization=False` or remove this tensor sharing.
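For context, two state_dict entries count as "shared" when they point at the same underlying storage, and safetensors does not serialize such aliasing on its own. The sketch below shows, under simplifying assumptions, how the two names in the error end up grouped together. `find_shared_groups` is a hypothetical helper, and plain integers stand in for real tensor storage so the sketch runs without torch; the layout is illustrative, not the actual Phi 3.5 Vision state_dict.

```python
from collections import defaultdict


def find_shared_groups(storage_ids):
    """Group state_dict keys that point at the same underlying storage.

    `storage_ids` maps each key to a hashable storage identifier. In real
    PyTorch code this would be derived from the tensor's storage; here it is a
    plain stand-in so the sketch runs without torch.
    """
    groups = defaultdict(list)
    for name, sid in storage_ids.items():
        groups[sid].append(name)
    # Only storages referenced by more than one key count as "shared".
    return [sorted(names) for names in groups.values() if len(names) > 1]


# Hypothetical layout reproducing the two names from the error message;
# "model.norm.weight" is an unrelated, unshared parameter for contrast.
phi3v_like = {
    "model.embed_tokens.weight": 0,
    "model.vision_embed_tokens.wte.weight": 0,
    "model.norm.weight": 1,
}
print(find_shared_groups(phi3v_like))
# [['model.embed_tokens.weight', 'model.vision_embed_tokens.wte.weight']]
```

Any group this returns must be resolved (one key kept, the others dropped) before a safetensors save can succeed.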

Below are two minimal reproducible examples:
Example #1

from transformers import AutoModelForCausalLM
model_id = "microsoft/Phi-3.5-vision-instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="cuda", trust_remote_code=True, torch_dtype="auto"
)
model.save_pretrained("out")

Example #2

from transformers import (
    Trainer,
    TrainingArguments,
)

# `model` is the Phi 3.5 Vision model loaded in Example #1
training_args = TrainingArguments(
    save_only_model=True,
    output_dir='./out/',
    save_strategy='no',
)
trainer = Trainer(
    model=model,
    args=training_args
)
trainer.save_model()

Others have encountered this issue as well; see the list of related reports under "Issues" below.

A contributor to the Phi 3 Vision cookbook suggested the following solution, stating "You need to remove the wte weight. It's okay because when the model is loaded from the checkpoint, it will automatically copy the weight from the embedding weight."

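# Drop the duplicated vision embedding weight ("wte"); per the contributor,
# it is copied back from the embedding weight when the checkpoint is loaded.
# `args` and `processor` come from the surrounding fine-tuning script.
state_dict = model.state_dict()
state_dict = {k: v for k, v in state_dict.items() if "wte" not in k}
model.save_pretrained(args.save_model_path, state_dict=state_dict, safe_serialization=True)
processor.save_pretrained(args.save_model_path)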

This does indeed seem to work. However, it doesn't fit a use case that relies on the Trainer abstraction: the Trainer class's save_model() method doesn't accept a state_dict argument (see https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L3719-L3768).
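To make option 1 concrete, here is a minimal sketch of the proposed signature change. `SketchTrainer` is a hypothetical stand-in, not `transformers.Trainer`: it records what would be written instead of touching disk, and it ignores the distributed, DeepSpeed, FSDP, and TPU branches the real save_model() has to handle. It only shows save_model() forwarding an optional state_dict to a private _save() that already accepts one.

```python
class SketchTrainer:
    """Hypothetical stand-in illustrating option 1; NOT transformers.Trainer."""

    def __init__(self, model_state):
        self.model_state = model_state  # stands in for model.state_dict()
        self.saved = None  # records (output_dir, weights) instead of writing files

    def _save(self, output_dir=None, state_dict=None):
        # Like the private _save() referenced above, accept an optional state_dict.
        if state_dict is None:
            state_dict = self.model_state
        self.saved = (output_dir, dict(state_dict))

    def save_model(self, output_dir=None, state_dict=None):
        # Proposed change: accept state_dict and forward it to _save() unchanged.
        self._save(output_dir, state_dict=state_dict)


weights = {
    "model.embed_tokens.weight": "E",
    "model.vision_embed_tokens.wte.weight": "E",  # same data as embed_tokens
}
trainer = SketchTrainer(weights)
# Same filtering as the cookbook workaround, now routed through save_model().
filtered = {k: v for k, v in weights.items() if "wte" not in k}
trainer.save_model("./out/", state_dict=filtered)
print(trainer.saved)  # ('./out/', {'model.embed_tokens.weight': 'E'})
```

Keeping state_dict optional means existing callers of save_model() would see no behavior change.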

Issues

  1. RuntimeError: The weights trying to be saved contained shared tensors [{'model.vision_embed_tokens.wte.weight', 'model.embed_tokens.weight'}] kazuar/Phi3-Vision-ft#2
  2. https://discuss.huggingface.co/t/runtimeerror-when-saving-phi-3-5-vision-due-to-shared-tensors/116457
  3. Saving Phi 3 vision fails due to tensor sharing #32354
  4. https://discuss.huggingface.co/t/using-trainer-to-save-a-bartforsequenceclassification-model/81606

Your contribution

I'd be glad to submit a PR, but I think some discussion is needed from the appropriate transformers stakeholders.

It's not clear to me whether the most appropriate change here is to modify the function signature.

Alternatively, maybe there's a heuristic by which we could determine whether the architecture is such that one needs to save everything but the wte weights. I don't know the answer to that off-hand. It may require a deep dive from Phi 3/3.5 Vision SMEs.
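One possible shape for such a heuristic, sketched under an explicit assumption: given the groups of keys that share storage, keep one key per group and drop the rest. `drop_shared_duplicates` and its `keep=min` rule are hypothetical. Choosing the lexicographically smallest name happens to keep `model.embed_tokens.weight` here, but that is a coincidence of naming. A correct rule would have to know which aliases the model restores on load, which is exactly the open question.

```python
def drop_shared_duplicates(state_dict, shared_groups, keep=min):
    """Keep exactly one key from each group of keys that share storage.

    `keep` chooses the surviving key. `min` (lexicographic order) is an
    arbitrary placeholder rule, not a choice known to be correct for every
    architecture: the dropped keys must be ones the model restores on load.
    """
    to_drop = set()
    for names in shared_groups:
        survivor = keep(names)
        to_drop.update(name for name in names if name != survivor)
    return {k: v for k, v in state_dict.items() if k not in to_drop}


state = {
    "model.embed_tokens.weight": "E",
    "model.vision_embed_tokens.wte.weight": "E",
    "model.norm.weight": "N",
}
groups = [["model.embed_tokens.weight", "model.vision_embed_tokens.wte.weight"]]
print(drop_shared_duplicates(state, groups))
# {'model.embed_tokens.weight': 'E', 'model.norm.weight': 'N'}
```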

Or more broadly, perhaps there's some change to the way transformers handles shared tensors in the base configuration that would be most appropriate.

@LysandreJik
Member

Thanks a lot for your thorough issue @jjbuck!

Let me ping a few people who could help: @muellerzr for the Trainer, and @SunMarc and @Wauplin for safetensors serialization.

@Vectorrent

Vectorrent commented Nov 16, 2024

Yeah, I just ran into this one as well. I've been working on a custom architecture that passes a bunch of layers around and assigns them to "self" for convenience. It turns out that doing this throws the error above.

Fortunately, using safe_serialization=False does address the problem.

@kylesayrs
Contributor

It's been a while, but when I last looked at this issue, it seemed potentially solvable by listing vision_embed_tokens and embed_tokens in the list of tied weights, although I did not verify this.

@SunMarc
Member

SunMarc commented Nov 25, 2024

Thanks for the report @jjbuck, this is indeed something that should have been removed automatically when saving with safetensors. Since this model runs remote code, the issue most likely lies in their code (see #32354).

Just checked: you need the following in their custom code:

class Phi3VForCausalLM(Phi3VPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight", "model.vision_embed_tokens.wte.weight"]

Feel free to open a PR on the model repo.
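Why the `_tied_weights_keys` declaration helps: during a safetensors save, keys that share storage and match a declared tied-weights entry can be dropped as expected aliases, while undeclared sharing raises the error reported above. The sketch below is a simplified model of that behavior, assuming entries are matched as regular-expression patterns against key names. It is not the library's actual save_pretrained() code, and `resolve_shared` is a hypothetical name.

```python
import re


def resolve_shared(shared_groups, tied_weights_keys):
    """Return the keys to drop before saving, or raise if sharing is undeclared.

    Simplified model: within each shared group, keys matching a declared
    tied-weights pattern are dropped, but at least one key is always kept.
    A group that still holds more than one key is an error.
    """
    to_drop = set()
    for names in shared_groups:
        remaining = list(names)
        for name in sorted(names):
            is_tied = any(re.search(pattern, name) for pattern in tied_weights_keys)
            if is_tied and len(remaining) > 1:
                remaining.remove(name)
                to_drop.add(name)
        if len(remaining) > 1:
            raise RuntimeError(f"shared tensors {sorted(remaining)} are not declared as tied")
    return to_drop


groups = [["model.embed_tokens.weight", "model.vision_embed_tokens.wte.weight"]]

# With the suggested declaration, the wte alias is dropped and saving can proceed.
print(resolve_shared(groups, ["lm_head.weight", "model.vision_embed_tokens.wte.weight"]))
# {'model.vision_embed_tokens.wte.weight'}

# Without it, the shared pair stays unresolved, mirroring the reported error.
try:
    resolve_shared(groups, ["lm_head.weight"])
except RuntimeError as err:
    print(err)
```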
