
Finetuned model does not load trained weights properly, but instead uses random initialization #577

Closed
2 of 4 tasks
ilektram opened this issue Jun 14, 2023 · 14 comments

@ilektram
ilektram commented Jun 14, 2023

System Info

torch==2.0.0
peft==0.3.0
accelerate==0.20.3
transformers==4.30.1

Who can help?

@pacman100 @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

A fine-tuned PEFT model based on LlamaForSequenceClassification does not load the trained weights of the final layer from local disk; instead, they are initialised randomly each time the model is loaded.

import os
import torch
from peft import PeftConfig, PeftModelForSequenceClassification, get_peft_model, get_peft_model_state_dict
from transformers import LlamaTokenizer, LlamaForSequenceClassification

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

try:
    if torch.backends.mps.is_available():
        device = "mps"
except AttributeError:  # older torch builds without the mps backend
    pass

id2label = {0: "ham", 1: "spam"}
label2id = {"ham": 0, "spam": 1}

load_8bit = True

peft_model_id = "experiments_openllama"
config = PeftConfig.from_pretrained(peft_model_id,)

base_model = config.base_model_name_or_path
lora_weights = peft_model_id
if device == "cuda":
    model = LlamaForSequenceClassification.from_pretrained(
        base_model,
        load_in_8bit=load_8bit,
        torch_dtype=torch.float16,
        device_map="auto",
        num_labels=2,
        id2label=id2label,
        label2id=label2id,
        problem_type="single_label_classification",
    )
    model = PeftModelForSequenceClassification.from_pretrained(
        model, lora_weights, torch_dtype=torch.float16, config=config
    )
elif device == "mps":
    model = LlamaForSequenceClassification.from_pretrained(
        base_model, device_map={"": device}, torch_dtype=torch.float16,
    )
    model = PeftModelForSequenceClassification.from_pretrained(
        model,
        lora_weights,
        device_map={"": device},
        torch_dtype=torch.float16,
        config=config,
    )
else:
    model = LlamaForSequenceClassification.from_pretrained(
        base_model, device_map={"": device}, low_cpu_mem_usage=True
    )
    model = PeftModelForSequenceClassification.from_pretrained(
        model, lora_weights, device_map={"": device}, config=config
    )

model.load_state_dict(
    torch.load(os.path.join(lora_weights, "adapter_model.bin")), strict=False
)

print(get_peft_model_state_dict(model)['base_model.model.score.weight'])

Happy to share the fine-tuned model if there is a way to upload a compressed version.
I have observed similar behaviour when loading via both the torch state_dict() method and from_pretrained() with the respective models; neither seems to load the fine-tuned weights properly.

Expected behavior

When loading a fine-tuned model into PEFT from local disk, the model weights should always be the same, and the validation-set accuracy should match the value logged for the latest epoch in the trainer_state.json file produced during training and evaluation.

@younesbelkada
Contributor

Hi @ilektram
Thanks for the issue.
From what I can see, you are calling model.load_state_dict with strict=False (therefore ignoring the mismatched keys); if you set that to True you will probably get a lot of errors.
In PEFT we use the method load_adapter:

model.load_adapter(model_id, adapter_name, **kwargs)

and directly pass the path where the adapter weights are stored. Can you please try that method instead? Thanks!
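For context, a minimal sketch of that flow against the directory from the report above (the second adapter name below is purely illustrative):

from peft import PeftConfig, PeftModelForSequenceClassification
from transformers import LlamaForSequenceClassification

peft_model_id = "experiments_openllama"  # local directory with adapter_config.json and adapter_model.bin
config = PeftConfig.from_pretrained(peft_model_id)

# Build the base model, then attach the trained adapter instead of calling load_state_dict manually
base = LlamaForSequenceClassification.from_pretrained(config.base_model_name_or_path, num_labels=2)
model = PeftModelForSequenceClassification.from_pretrained(base, peft_model_id)

# load_adapter attaches further adapters to an already-built PeftModel
model.load_adapter(peft_model_id, adapter_name="spam_v2")  # "spam_v2" is a hypothetical name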

@ilektram
Author

ilektram commented Jun 16, 2023

Thank you for the reply.

In the meantime I have tried using the following method:

import os
import torch
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict, prepare_model_for_int8_training
from transformers import LlamaForSequenceClassification

id2label = {0: "ham", 1: "spam"}
label2id = {"ham": 0, "spam": 1}

load_8bit = True
# Load peft config for pre-trained checkpoint etc.
peft_model_id = "experiments_openllama/checkpoint-478"
# peft_model_id = "results/experiments_openllama/checkpoint-478"

CUTOFF_LEN = 512

LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = [
    "q_proj",
    "v_proj",
]
BASE_MODEL = "openlm-research/open_llama_7b"

config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=LORA_TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="SEQ_CLS",
    inference_mode=True,
)
model = LlamaForSequenceClassification.from_pretrained(
    BASE_MODEL,
    load_in_8bit=True,  # TODO these 3 lines, when commented in result in faster inference but random results that are different each time
    torch_dtype=torch.float16,
    device_map="auto",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
    problem_type="single_label_classification",
)
model = prepare_model_for_int8_training(model)
model = get_peft_model(model, config)

set_peft_model_state_dict(
    model, torch.load(os.path.join(peft_model_id, "adapter_model.bin"))
)

I have observed that the weight tensors are populated; however, the inference issue persists, as I am still obtaining accuracy around 50%.

The odd thing is that each time I load the model and run inference on the validation set the result is slightly different (always between 45-55%, as opposed to the validation run after training, where accuracy was always 93%). I am therefore wondering whether there is a parameter that introduces randomness into the reloaded model? I have used model.eval() but that does not resolve the issue.

I have also tried replacing the line

set_peft_model_state_dict(
    model, torch.load(os.path.join(peft_model_id, "adapter_model.bin"))
)

with

model.load_adapter(peft_model_id, 'adapter_model')

but the same behaviour was observed.
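A possible diagnostic here (a sketch, assuming the key names in adapter_model.bin match those returned by get_peft_model_state_dict, which is how PEFT writes that file) is to compare the checkpoint on disk against the weights actually present in the loaded model:

import os
import torch
from peft import get_peft_model_state_dict

saved = torch.load(os.path.join(peft_model_id, "adapter_model.bin"), map_location="cpu")
loaded = get_peft_model_state_dict(model)
for name, tensor in saved.items():
    if name not in loaded:
        print(f"{name}: missing from the loaded model")
    elif not torch.allclose(tensor.float().cpu(), loaded[name].float().cpu(), atol=1e-3):
        print(f"{name}: differs between disk and the loaded model")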

Evaluation metrics:

print(metric.compute(predictions=preds, references=predictions.label_ids))
{'accuracy': 0.5342394145321484, 'f1': 0.316193399846508}

I should add as a note that the trained model was saved using the respective callback method.

@pacman100
Contributor

Hello @ilektram, I'm running https://github.com/huggingface/peft/blob/main/examples/sequence_classification/LoRA.ipynb and it works as expected without any performance issues post loading.

Here is the model: https://huggingface.co/smangrul/roberta-large-peft-lora-latest

import torch
from tqdm import tqdm
from peft import PeftModel, PeftConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# device, eval_dataloader and metric are defined earlier in the notebook

peft_model_id = "smangrul/roberta-large-peft-lora-latest"
config = PeftConfig.from_pretrained(peft_model_id)
inference_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Load the Lora model
inference_model = PeftModel.from_pretrained(inference_model, peft_model_id)

inference_model.to(device)
inference_model.eval()
for step, batch in enumerate(tqdm(eval_dataloader)):
    batch.to(device)
    with torch.no_grad():
        outputs = inference_model(**batch)
    predictions = outputs.logits.argmax(dim=-1)
    predictions, references = predictions, batch["labels"]
    metric.add_batch(
        predictions=predictions,
        references=references,
    )

eval_metric = metric.compute()
print(eval_metric)

Output:

Downloading (…)/adapter_config.json: 100%
406/406 [00:00<00:00, 59.7kB/s]
Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.out_proj.weight', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Downloading adapter_model.bin: 100%
7.39M/7.39M [00:00<00:00, 10.0MB/s]
  0%|                                                                        | 0/13 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|███████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  7.92it/s]
{'accuracy': 0.8921568627450981, 'f1': 0.9236111111111112}

@ilektram
Author

Can I ask, is there a reason why you load the model with AutoModelForSequenceClassification & AutoTokenizer rather than the Roberta specific classes?

@thohag

thohag commented Jun 17, 2023

I experienced a similar issue; it turned out that dataset shuffling was non-deterministic. I am still observing some differences between the evaluation during training and the loaded adapter, but <1%.
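For anyone hitting the same thing, a minimal sketch of pinning the seed (an assumption about the setup, not part of the original comment):

from transformers import set_seed

set_seed(42)  # seeds Python, NumPy and PyTorch RNGs so dataset shuffling is reproducible across runs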

@hanyin88

hanyin88 commented Jul 7, 2023

Hello @ilektram, I am wondering if you found a solution to your problem?

Dear @younesbelkada @pacman100, I am running into a similar issue despite updating transformers to the latest version (4.30.2; PS thanks for the many fixes related to saving PEFT in that version!).

In particular, it seems I can't load the actual fine-tuned model; the base model appears to be loaded instead.

Here's my code:

import torch
from peft import PeftConfig, PeftModel
from transformers import LlamaTokenizer, LlamaForSequenceClassification, Trainer

loraweight = "/data/experiments/test_run"
config = PeftConfig.from_pretrained(loraweight)
tokenizer = LlamaTokenizer.from_pretrained(config.base_model_name_or_path, model_max_length=512)
model = LlamaForSequenceClassification.from_pretrained(config.base_model_name_or_path, num_labels=738,
                                                       torch_dtype=torch.float16, device_map="auto")
model = PeftModel.from_pretrained(model, loraweight)

# training_args, tokenized_test, data_collator and compute_metrics are defined elsewhere in my script

trainer = Trainer(
    model=model,
    args=training_args,
    eval_dataset=tokenized_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    )

metrics= trainer.evaluate()

My evaluation metric at the end of training is: {'eval_train_loss': 5.794477462768555, 'eval_train_microF1': 0.08219178082191782, 'eval_train_macroF1': 0.017915627236747927, 'eval_train_microAUC': 0.828394538670285, 'eval_train_macroAUC': 0.5873078248178315, 'eval_train_labels': 128, 'eval_train_count': 200, 'eval_train_acc10': 0.21, 'eval_train_acc5': 0.185, 'eval_train_acc': 0.075}

However, with the loaded model the evaluation metric is: {'eval_loss': 7.802577972412109, 'eval_microF1': 0.0, 'eval_macroF1': 0.0, 'eval_microAUC': 0.5221097862957937, 'eval_macroAUC': 0.532642330878301, 'eval_labels': 128, 'eval_count': 200, 'eval_acc10': 0.015, 'eval_acc5': 0.0, 'eval_acc': 0.0}.

I have confirmed that the evaluation dataset is identical, with 200 samples. I ran the evaluation of the loaded model using both the trainer.evaluate() function and a naive torch evaluation loop and got the same result.

On a separate note, and related to another post, here is my LoRA config during the training phase:

lora_target_modules: List[str] = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
    ]

config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type=TaskType.SEQ_CLS
    )

Not sure if I need to change anything in modules_to_save as suggested in the other post. My previous understanding was that when you set the task type to SEQ_CLS, modules_to_save would be set to the last classifier layer ('score' in my case). Not sure if this is the problem.
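As a quick sanity check of what PEFT will actually train (a generic sketch, applicable to any PeftModel):

# Summarise trainable vs. total parameters
model.print_trainable_parameters()

# List every parameter that will receive gradients, e.g. to confirm the 'score' head is included
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, tuple(param.shape))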

Deeply appreciate any guidance!

Update on 07.07 @ilektram:

It seems the problem is indeed with modules_to_save in LoraConfig, as suggested in the aforementioned post. The loaded model worked as expected after I adjusted my code as below:

lora_target_modules: List[str] = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
    ]

config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type=TaskType.SEQ_CLS,
        modules_to_save=lora_target_modules.append("score"),
    )

Of interest, while previously the trainable parameters at the score layer looked like this:

base_model.model.score.original_module.weight torch.Size([738, 4096])
base_model.model.score.modules_to_save.default.weight torch.Size([738, 4096])

Now they look like this:

base_model.model.score.original_module.weight torch.Size([738, 4096])
base_model.model.score.original_module.lora_A.default.weight torch.Size([8, 4096])
base_model.model.score.original_module.lora_B.default.weight torch.Size([738, 8])
base_model.model.score.modules_to_save.default.weight torch.Size([738, 4096])
base_model.model.score.modules_to_save.default.lora_A.default.weight torch.Size([8, 4096])
base_model.model.score.modules_to_save.default.lora_B.default.weight torch.Size([738, 8])

I am quite puzzled, as per the official tutorial I shouldn't need to specify modules_to_save again. There appears to be a bug.

Deeply appreciate some clarification.

@ilektram
Author

ilektram commented Jul 7, 2023 via email

@vincentmin

I am encountering the same issue where loading the model multiple times gives different results for the same input and accuracy hovers around 50%. I described my situation in more detail here:
huggingface/trl#578

@BenjaminBossan
Member

@hanyin88 I think this line is not doing what you think it does:

config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type=TaskType.SEQ_CLS,
        modules_to_save=lora_target_modules.append("score"),  # <==
    )

lora_target_modules.append("score") returns None, so you end up with modules_to_save=None. At the same time, you appended "score" to lora_target_modules, i.e. to config.target_modules.

Also, in general, the layers targeted for LoRA should not be added to modules_to_save. modules_to_save does full parameter training on those layers, so you lose all the advantages of LoRA.
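If the goal was to fully fine-tune the classification head alongside the LoRA layers, a corrected sketch might look like this (list.append mutates in place and returns None, so the lists are built separately; variable names follow the snippet above):

from peft import LoraConfig, TaskType

lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]

config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    target_modules=lora_target_modules,
    lora_dropout=lora_dropout,
    bias="none",
    task_type=TaskType.SEQ_CLS,
    modules_to_save=["score"],  # full fine-tuning of the head only; keep it out of target_modules
)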

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@TuronLab

TuronLab commented Aug 30, 2023

Hi,

I found the same problem training a BLOOM model for sequence classification with peft v0.4.0.

After some effort, I found that all the adapter weights were correctly loaded after building the PEFT model with PeftModelForSequenceClassification.from_pretrained(model, adapters_path), but it does not load the weights of the classification head: the head keeps the weights that were randomly initialised when the model was assembled.

As far as I understand, there seems to be a mismatch between the name of the model's classification layer and the name the PEFT model expects for it, which causes the PEFT model to keep the randomly initialised weights.

This way, I was able to load the classification head by loading the weights manually (you may have other layer names depending on the model you're using):

import os
import torch
from torch.nn import Parameter
from peft import PeftModelForSequenceClassification
from transformers import AutoModelForSequenceClassification

# adapters_path, peft_config and model_conf are defined earlier in my script
adapters_weights = torch.load(os.path.join(adapters_path, 'adapter_model.bin'))
model = AutoModelForSequenceClassification.from_pretrained(
            peft_config.base_model_name_or_path,
            config=model_conf,
        )
# Load the weights of the trained classification head (you may need to cast the tensor if you're using half precision)
model.score.weight = Parameter(adapters_weights['base_model.model.score.weight'])
model = PeftModelForSequenceClassification.from_pretrained(model, adapters_path)

If PEFT didn't save the weights of the classification head in adapter_model.bin, you may need to specify the task_type in the LoraConfig parameters when training. If it still doesn't save them, you may need to specify which layer it should save using the modules_to_save field in the LoraConfig parameters.

Nevertheless, it looks like this has been fixed with the release of peft v0.5.0.

@hanyin88

hanyin88 commented Aug 30, 2023

Thanks for kindly following up on the issue!

@BenjaminBossan Thanks for the kind insight! You are totally right: what I accidentally (or coincidentally) did was set modules_to_save=None and add the "score" layer to lora_target_modules. So this effectively adds LoRA on the final score layer but skips modules_to_save. I double-checked that with this approach, all model parameters can be loaded correctly after training.

@TuronLab I think you likely found the fundamental problem here, and your solution seems even better. :) As far as I know, this bug only applies to AutoModelForSequenceClassification and not to other model classes. Just to make sure everyone is on the same page, could you kindly share the LoraConfig you used? I think by setting task_type=TaskType.SEQ_CLS, modules_to_save will automatically be the last output layer.

@BenjaminBossan
Member

modules_to_save will automatically be the last output layer

To expand on this, it will add "classifier" and "score" to modules_to_save, which are common names for the classification head. If one of those matches the name of the last layer, it will automatically be saved. If not, you should add it explicitly to modules_to_save.
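For a model whose head has a non-standard name, a sketch of adding it explicitly (the name "custom_head" below is purely illustrative; use your model's actual attribute name):

from peft import LoraConfig, TaskType

config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_CLS,
    modules_to_save=["custom_head"],  # hypothetical head name
)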

@TuronLab

Hi, I'm glad you found this useful.

Yes, as @BenjaminBossan says, when you specify that you're solving a sequence classification problem, it will automatically try to save the "classifier" and "score" layers, as shown in the init of the PeftModelForSequenceClassification class #L717. This was enough for me to save the classification head I needed, since the classification head of the BLOOM model is named "score" (see the BloomForSequenceClassification class, #L995). You will need to check that the model you are training follows the same nomenclature; otherwise you will need to specify it manually.

Here is the LoraConfig that I used for training:

    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["query", "value"],
        lora_dropout=0.5,
        bias="none",
        inference_mode=False,
        task_type=TaskType.SEQ_CLS,
    )
