BOFT bug fix when saving #1994
Conversation
Thanks for adding this fix. Do you have an example where lack of contiguity leads to an error?
Sure, here is where I noticed the error, a small code snippet for the training step:

```python
import os

from peft import BOFTConfig
from transformers import AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer

# args, train_data, and print_trainable_parameters come from the
# surrounding training script
oft_config = BOFTConfig(
    boft_block_size=args.block_size,  # 32
    boft_n_butterfly_factor=args.n_butterfly_factor,  # 1
    boft_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
train_data.start_iteration = 0
print("Starting main loop")
training_args = SFTConfig(
    output_dir=args.output_dir,
    dataloader_drop_last=True,
    eval_strategy="steps",
    num_train_epochs=args.num_train_epochs,
    eval_steps=args.eval_freq,
    save_strategy=args.save_strategy,
    logging_steps=args.log_freq,
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    learning_rate=args.learning_rate,
    lr_scheduler_type=args.lr_scheduler_type,
    warmup_steps=args.num_warmup_steps,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    gradient_checkpointing=args.gradient_checkpointing,
    fp16=args.fp16,
    bf16=args.bf16,
    weight_decay=args.weight_decay,
    run_name="llama-7b-finetuned",
    report_to="wandb",
    ddp_find_unused_parameters=False,
    disable_tqdm=False,
    max_seq_length=args.seq_length,
    dataset_text_field="text",
)
model = AutoModelForCausalLM.from_pretrained(
    args.model_path,
)
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    peft_config=oft_config,
    packing=False,
)
print_trainable_parameters(trainer.model)
trainer.train()
print("Saving last checkpoint of the model")
trainer.model = trainer.model.merge_and_unload()
trainer.save_model(os.path.join(args.output_dir, str(args.num_train_epochs)))
```
The error occurs at `trainer.save_model`; it states that some tensors are not contiguous. I noticed it is because of the `torch.transpose()` and `torch.mm()` operations performed in the `merge()` function.
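For context, here is a minimal sketch (an editorial illustration, not code from the PR) of what "not contiguous" means: `torch.transpose` returns a view with swapped strides instead of copying data, so a merged weight built through transposes can hold the correct values while no longer being laid out densely in memory.

```python
import torch

w = torch.randn(4, 8)
wt = torch.transpose(w, 0, 1)  # a view: same storage, swapped strides
print(wt.shape, wt.is_contiguous())  # torch.Size([8, 4]) False
print(wt.stride())  # (1, 8): rows are no longer dense in memory

# .contiguous() materializes a densely laid-out copy
print(wt.contiguous().is_contiguous())  # True
```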
> sure, here is where I noticed the error, a small code snippet for training step:
Thanks, I can confirm that this raises the error. Let's add a unit test to ensure that this works correctly. I think the best location would be after this line (tests/testing_common.py, line 764 in c869664):

```python
assert torch.allclose(logits_peft, logits_unloaded, atol=atol, rtol=rtol)
```
Here is the code that I used:

```python
# serializing works without errors
with tempfile.TemporaryDirectory() as tmp_dirname:
    # serializing with torch.save works
    torch.save(model_unloaded.state_dict(), os.path.join(tmp_dirname, "model.bin"))
    # serializing with safetensors works
    save_file(model_unloaded.state_dict(), os.path.join(tmp_dirname, "model.safetensors"))
```

The `save_file` function must be imported: `from safetensors.torch import save_file`.
When I ran these extended tests, I noticed that OFT is also failing. However, this can be resolved in the same way in this line (src/peft/tuners/oft/layer.py, line 174 in c869664):

```diff
- base_layer.weight.data = new_weights
+ base_layer.weight.data = new_weights.contiguous()
```

Would you be so kind as to fix that too?
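As context for the test above, a minimal sketch (an editorial illustration, not part of the PR; the `nn.Linear` stand-in and shapes are made up, and it assumes safetensors raises a `ValueError` here, as recent versions do) of why `torch.save` passes while safetensors fails on a non-contiguous weight:

```python
import os
import tempfile

import torch
from safetensors.torch import save_file

lin = torch.nn.Linear(8, 4)
# simulate what an in-place merge can leave behind: a transposed view
# assigned to .data, a valid tensor that is not contiguous in memory
lin.weight.data = torch.randn(8, 4).t()
print(lin.weight.is_contiguous())  # False

with tempfile.TemporaryDirectory() as tmp:
    # torch.save serializes the underlying storage and tolerates views
    torch.save(lin.state_dict(), os.path.join(tmp, "model.bin"))
    try:
        save_file(lin.state_dict(), os.path.join(tmp, "model.safetensors"))
    except ValueError as exc:
        print(exc)  # safetensors refuses non-contiguous tensors
```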
src/peft/tuners/boft/layer.py (outdated):

```diff
- orig_weight = torch.transpose(orig_weight, 0, 1)
+ orig_weight = torch.transpose(orig_weight, 0, 1).contiguous()
+ orig_weight = torch.mm(butterfly_oft_mat, orig_weight).contiguous()
+ orig_weight = torch.transpose(orig_weight, 0, 1).contiguous()
```
I wonder if we can simplify this by only calling `.contiguous()` at the very last step, i.e. line 532:

```diff
- self.base_layer.weight.data = orig_weight
+ self.base_layer.weight.data = orig_weight.contiguous()
```

Same for line 532. And since there is also `Conv2d`, should the call be added there as well? So that would be lines 820 and 834.
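A quick aside on why a single call at the very end is enough (an editorial sketch, not from the PR; `butterfly` and `weight` are made-up stand-ins for `butterfly_oft_mat` and the base layer weight): `torch.mm` allocates a fresh, contiguous output, so only the final `torch.transpose`, which returns a strided view, leaves the result non-contiguous.

```python
import torch

butterfly = torch.randn(8, 8)  # stand-in for butterfly_oft_mat
weight = torch.randn(16, 8)    # stand-in for the base layer weight

w = torch.transpose(weight, 0, 1)  # (8, 16) view: not contiguous
print(w.is_contiguous())           # False
w = torch.mm(butterfly, w)         # mm writes into a fresh output tensor
print(w.is_contiguous())           # True
w = torch.transpose(w, 0, 1)       # (16, 8) view again: not contiguous
print(w.is_contiguous())           # False
print(w.contiguous().is_contiguous())  # True: one call at the end suffices
```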
Hi, thanks for the comment. I updated the unit test and also applied `self.base_layer.weight.data = orig_weight.contiguous()` for BOFT and OFT. Best,
Thanks for the updates. Just two very small change requests, otherwise this looks good.
tests/testing_common.py (outdated):

```diff
@@ -763,6 +763,10 @@ def _test_safe_merge(self, model_id, config_cls, config_kwargs):
     # check that the logits are the same after unloading
     assert torch.allclose(logits_peft, logits_unloaded, atol=atol, rtol=rtol)

+    # serializing with safetensors works
+    from safetensors.torch import save_file
```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's move this import to the top of the file.
tests/testing_common.py (outdated):

```diff
@@ -763,6 +763,10 @@ def _test_safe_merge(self, model_id, config_cls, config_kwargs):
     # check that the logits are the same after unloading
     assert torch.allclose(logits_peft, logits_unloaded, atol=atol, rtol=rtol)

+    # serializing with safetensors works
```
Suggested change:

```diff
- # serializing with safetensors works
+ # Ensure that serializing with safetensors works, there was an error when weights were not contiguous
```
Can you check again? I noticed that in the previous commit I did not copy the whole test into the unit test. Best.
@BenjaminBossan What do I need to change in testing_common.py?
I think your last commit was the missing piece 🤞
@BenjaminBossan Hi, the tests are still failing. Is that because of my added code?
No, I don't think so, this looks more like an HF Hub issue. This can sometimes happen, but not with the whole test matrix. I'll restart the jobs later, don't worry about it.
Thanks for the fixes to merging BOFT and OFT.
Fixes a non-contiguous tensor error when saving the model after `merge_and_unload()`.