
Model forgets finetuning after saving / loading #503

Closed
ingo-m opened this issue May 25, 2023 · 18 comments

ingo-m commented May 25, 2023

After finetuning an AutoModelForCausalLM model with PEFT, the model forgets what it learned after saving / loading.

I created a minimal example here: https://colab.research.google.com/drive/1mGpLQk8VMFfh_jcMGPaTfygGOdqlUUTs?usp=sharing
The colab notebook can run on a (free) T4 instance.

In the minimal example, I'm using PEFT on a bigscience/bloom-560m base model. As a toy example, during finetuning, the model learns the alphabet in reverse order.

Before training and after each epoch, I let the model generate a prediction from this prompt: "z y x w v u t s r"

INFERENCE before training
z y x w v u t s r t d d d d d d d d d d d d d d d d d d d d d d d d d d d d d d d d d d d d d d d d d d d d d d d
INFERENCE after epoch 0
z y x w v u t s r q p o n m l k j i h g f e d c b a
INFERENCE after epoch 1
z y x w v u t s r q p o n m l k j i h g f e d c b a

After one epoch, the model has already memorized the target sequence (the alphabet in reverse order).
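
For reference, the generation check looks roughly like this (a minimal sketch; the exact code in the notebook differs slightly, and tokenizer / model refer to the tokenizer and PEFT model used for finetuning):

import torch

prompt = "z y x w v u t s r"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=60)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))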

Then I save the model to gdrive, restart the runtime, and load the model. The model has forgotten what it learned during finetuning:

INFERENCE after loading model from disk
z y x w v u t s r caigut en si r son de se l o in nu p me en bl o id le di ge se di ge si bl o id le di ge se di ge si bl o id le di ge se di ge si bl o id le

I observe this “forgetting” both locally and on Google Colab. I also tried pushing the PEFT model to the Hugging Face Hub instead of saving locally; that also results in forgetting.
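
For the Hub variant, the round trip looked roughly like this (a sketch; the repo name is just a placeholder):

from transformers import AutoModelForCausalLM
from peft import PeftModel

model.push_to_hub("my-user/bloom-560m-reverse-alphabet-lora")

# later, in a fresh runtime:
base_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
finetuned_model = PeftModel.from_pretrained(base_model, "my-user/bloom-560m-reverse-alphabet-lora")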

Moreover, it doesn’t matter which model dtype I use (torch.bfloat16, torch.float16, torch.float32), or whether load_in_8bit is True or False.

In the minimal example linked above, I’m using a custom torch training loop, but I get the same result with transformers.Trainer.

Am I overlooking something obvious? How is this possible?


ingo-m commented May 26, 2023

I tried one more thing. The "forgetting" also happens if after finetuning I save the model with torch.save(model.state_dict(), model_save_path) and load it like so:

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
_ = model.load_state_dict(torch.load(model_save_path))

It's weird.


MatthiasEg commented Jun 7, 2023

Hi @ingo-m, I just had a similar issue while saving and loading a PeftModel. Here's a short snippet showing how you can load a checkpoint of a PeftModel after training:

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Load the quantized base model and prepare it for k-bit training
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    trust_remote_code=True
)
model = prepare_model_for_kbit_training(model)

config = LoraConfig(
    r=64,
    lora_alpha=16,
    target_modules=[
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

# Recreate the PEFT wrapper, then load the trained weights into it
model = get_peft_model(model, config)
model.resize_token_embeddings(len(tokenizer))  # Optional
model.load_state_dict(torch.load("..../checkpoint-300/pytorch_model.bin"))
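
Note that the LoraConfig passed to get_peft_model has to match the one used during training; otherwise the keys and shapes in the loaded state dict won't line up.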


ingo-m commented Jun 7, 2023

@MatthiasEg thanks. So checkpoints are not affected by this "forgetting" issue, and can be used as a workaround?

@MatthiasEg

Checkpoints or saving the model - there's no difference.

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
_ = model.load_state_dict(torch.load(model_save_path))

From the small snippet you provided, you are not creating a PEFT model (from the base model) before loading all the weights, meaning all PEFT layers (and possibly other adjustments to the model) are missing. Hence, it's impossible to restore the state the model had before saving.
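
In other words, something along these lines should work (a sketch, assuming config is the same LoraConfig you trained with):

import torch
from transformers import AutoModelForCausalLM
from peft import get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
peft_model = get_peft_model(base_model, config)  # recreate the PEFT wrapper first
peft_model.load_state_dict(torch.load(model_save_path))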


ingo-m commented Jun 7, 2023

you are not creating a PEFT model

Not sure whether I understand this. After finetuning, I load the model like this:

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model for inference

model_save_path = "drive/My Drive/colab/peft_bug_minimal_example"

load_in_8bit = False
torch_dtype = torch.bfloat16

config = PeftConfig.from_pretrained(model_save_path)

base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    load_in_8bit=load_in_8bit,
    torch_dtype=torch_dtype,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

finetuned_model = PeftModel.from_pretrained(
    base_model, model_save_path
)

See https://colab.research.google.com/drive/1mGpLQk8VMFfh_jcMGPaTfygGOdqlUUTs?usp=sharing

My understanding is that the base_model is the same as the original base model (i.e. "bigscience/bloom-560m"), and finetuned_model has the PEFT layers? Isn't the PeftModel.from_pretrained() function supposed to add the PEFT layers?
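
As a rough sanity check on my side (not part of the notebook): freshly initialized LoRA B matrices are all zeros, so after loading a trained adapter they should contain nonzero values:

for name, param in finetuned_model.named_parameters():
    if "lora_B" in name:
        print(name, param.abs().sum().item())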


MatthiasEg commented Jun 7, 2023

Honestly, I don't know what the correct approach to save and load a PEFT fine-tuned model is (maybe someone could elaborate on that). But I know that with the snippet I provided above, I can successfully load my PEFT fine-tuned model from disk without losing the knowledge it gained during training.

Did you try using prepare_model_for_kbit_training and get_peft_model from the peft library, as well as loading the state dict afterwards with load_state_dict? This approach at least worked for me.


ingo-m commented Jun 8, 2023

@MatthiasEg thanks, I can confirm that saving & loading with the torch state_dict works as expected (no "forgetting"):

# After training:
torch.save(model.state_dict(), model_save_path)

# ...

# Load finetuned model from disk:
model = get_peft_model(model, config)  # Previously it didn't work because this line was missing
model.load_state_dict(torch.load(model_save_path))

Notebook with complete example (works as expected): https://colab.research.google.com/drive/1JPevLSsOq6DWr2tKw7reEDcgwmYhbEfp?usp=sharing

I still don't understand why my original example, using the PEFT methods save_pretrained() and from_pretrained(), doesn't work: 🤷

# After training:
model.save_pretrained(model_save_path)

# ...

# Load finetuned model from disk:
finetuned_model = PeftModel.from_pretrained(base_model, model_save_path)

# DOES NOT WORK, MODEL DOES NOT REMEMBER FINETUNING

Notebook with complete example (doesn't work as expected, forgets finetuning): https://colab.research.google.com/drive/1mGpLQk8VMFfh_jcMGPaTfygGOdqlUUTs?usp=sharing

The conclusion for me is to avoid the PEFT methods save_pretrained() and from_pretrained().

ingo-m closed this as completed Jun 8, 2023

ingo-m commented Jun 9, 2023

Actually, let me re-open this in the hope that someone can shed some light on it. Using the state_dict is not a solution for large models, as described here: https://github.com/huggingface/blog/blob/main/accelerate-large-models.md#sharding-state-dicts
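
(One possible middle ground, just an untested thought on my side: persist only the adapter parameters instead of the full state dict, which stays small even for large base models. The key filter below is an assumption about PEFT's parameter naming.)

adapter_state = {
    k: v for k, v in model.state_dict().items()
    if "lora_" in k or "modules_to_save" in k
}
torch.save(adapter_state, adapter_save_path)
# When loading: recreate the PEFT model with get_peft_model first, then
# model.load_state_dict(torch.load(adapter_save_path), strict=False)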

ingo-m reopened this Jun 9, 2023
@younesbelkada (Contributor)

Hi @ingo-m
Sorry for only chiming in now; this seems like a rather important issue.
It is quite strange that the model "forgets" what it learned even if you save it with save_pretrained - it should work, as we test this in the CI:

def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs):

Would you be able to share a small reproducible snippet that is quick to run? 🙏

@younesbelkada (Contributor)

cc @pacman100 also just FYI


ingo-m commented Jun 22, 2023

@younesbelkada thanks for looking into this.

This is a colab notebook with a standalone example (where the model forgets the finetuning), it should run as is on a free colab instance: https://colab.research.google.com/drive/1mGpLQk8VMFfh_jcMGPaTfygGOdqlUUTs?usp=sharing

It's quite a few lines of code, but that's just boilerplate (required for saving & loading the model from disk). When you open the notebook, you should be able to see the output from when I ran it.


ingo-m commented Jul 4, 2023

I coincidentally discovered why my original example didn't work as expected (model "forgets" finetuning after saving & loading). In my example I used this LoraConfig:

config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    modules_to_save=[last_module_name],
)

Full example notebook: https://colab.research.google.com/drive/1mGpLQk8VMFfh_jcMGPaTfygGOdqlUUTs?usp=sharing

The problematic part is modules_to_save=[last_module_name], which is the model's output layer. If we replace this with modules_to_save=["q_proj", "v_proj"] (as is common for LoRA on transformer models), the problem disappears - we can save & load the PEFT model as expected.

I'm not sure whether this can be considered a bug. On the one hand, modules_to_save=[last_module_name] is maybe not a sensible configuration. On the other hand, even with that config, the model learns the task, and the saving & loading fails silently, so not ideal. 🧐
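
(For anyone debugging something similar, one way to see what actually ended up in the saved adapter is to inspect the keys of the saved file directly; the file name below assumes PEFT's default save_pretrained layout at the time.)

import torch

saved = torch.load(f"{model_save_path}/adapter_model.bin", map_location="cpu")
print(list(saved.keys()))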

@younesbelkada (Contributor)

Hi @ingo-m
Thanks for the heads up! I will have a look at what you have shared ASAP.


sjrl commented Jul 24, 2023

I wonder if this issue #602 is related. The workarounds found in the linked issue were disabling load_in_{4,8}_bit or saving the state_dict of the output layer directly with torch (the same solution as referenced here).


ingo-m commented Jul 24, 2023

@sjrl Interesting, looks like it's the same problem. In the original example I had load_in_8bit=False, but here #602 (comment) it's mentioned that the bug can also be triggered if device_map is not None (I used device_map="auto" so that makes sense).

@BenjaminBossan (Member)

@ingo-m Maybe your problem is also solved by the fix in #755; perhaps you could check it out if you have time.


ingo-m commented Jul 27, 2023

@BenjaminBossan Yes, I can confirm that the issue is fixed when installing PEFT from the current main branch (!pip install git+https://github.com/huggingface/peft.git).

Complete example: https://colab.research.google.com/drive/1XMe3iVaSBPjgo72RU9zPaQaZaANNWHHf?usp=sharing

ingo-m closed this as completed Jul 27, 2023
@LazerJesus

+1. This was such an annoying issue for way too long, but I can confirm that this works now.

ADAPTER_DIR = "./peftmodel_save_pretrained"
model = get_peft_model(AutoModelForX.from_pretrained({...}), LoraConfig({...}))
model.train()  # ... run the training loop ...
model.save_pretrained(ADAPTER_DIR)

# ... 2 days later ...

model = PeftModel.from_pretrained(AutoModelForX.from_pretrained({...}), ADAPTER_DIR, is_trainable=True)
model.train()
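
(If I understand the API correctly, is_trainable=True is only needed when you want to continue training the loaded adapter; for inference-only use it can be omitted.)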
