
BOFT bug fix when saving #1994

Merged (9 commits) on Aug 7, 2024
Conversation

Zeju1997 (Contributor) commented Aug 7, 2024

Fixes a non-contiguous tensor error when saving the model after merge_and_unload().

BenjaminBossan (Member)

Thanks for adding this fix. Do you have an example where lack of contiguity leads to an error?

Zeju1997 (Contributor, Author) commented Aug 7, 2024

> Thanks for adding this fix. Do you have an example where lack of contiguity leads to an error?

Sure, here is where I noticed the error. A small code snippet for the training step:

    import os

    from peft import BOFTConfig
    from transformers import AutoModelForCausalLM
    from trl import SFTConfig, SFTTrainer

    oft_config = BOFTConfig(
        boft_block_size=args.block_size, # 32
        boft_n_butterfly_factor=args.n_butterfly_factor, # 1
        boft_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    train_data.start_iteration = 0

    print("Starting main loop")

    training_args = SFTConfig(
        output_dir=args.output_dir,
        dataloader_drop_last=True,
        eval_strategy="steps",
        num_train_epochs=args.num_train_epochs,
        eval_steps=args.eval_freq,
        save_strategy=args.save_strategy,
        logging_steps=args.log_freq,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.num_warmup_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=args.gradient_checkpointing,
        fp16=args.fp16,
        bf16=args.bf16,
        weight_decay=args.weight_decay,
        run_name="llama-7b-finetuned",
        report_to="wandb",
        ddp_find_unused_parameters=False,
        disable_tqdm=False,
        max_seq_length=args.seq_length,
        dataset_text_field = "text",
    )

    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, 
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        peft_config=oft_config,
        packing=False,
    )

    print_trainable_parameters(trainer.model)

    trainer.train()
    
    print("Saving last checkpoint of the model")

    trainer.model = trainer.model.merge_and_unload()
    trainer.save_model(os.path.join(args.output_dir, str(args.num_train_epochs)))
The error occurs at trainer.save_model: it states that some tensors are not contiguous. I noticed that this is caused by the torch.transpose() and torch.mm() operations performed in the merge() function.
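The cause described above can be reproduced in isolation. This is a minimal sketch, not the PEFT code itself: an identity matrix stands in for the real butterfly factor, but the sequence of operations mirrors merge(). The final transpose returns a strided view, which is what the save step later rejects.

```python
import torch

orig_weight = torch.randn(4, 8)     # stand-in for a (out_features, in_features) weight
butterfly_oft_mat = torch.eye(8)    # stand-in for the real butterfly factor

out = torch.transpose(orig_weight, 0, 1)  # view with swapped strides
out = torch.mm(butterfly_oft_mat, out)    # mm allocates a fresh, contiguous tensor
out = torch.transpose(out, 0, 1)          # but the final transpose is a view again
print(out.is_contiguous())                # False: this is what breaks saving

out = out.contiguous()                    # the fix applied in this PR
print(out.is_contiguous())                # True
```

torch.save tolerates such views, but safetensors refuses to serialize them, which is why the error only surfaced at save time.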

BenjaminBossan (Member) left a comment

> sure, here is where I noticed the error, a small code snippet for training step:

Thanks, I can confirm that this raises the error. Let's add a unit test to ensure that this works correctly. I think the best location would be after this line:

assert torch.allclose(logits_peft, logits_unloaded, atol=atol, rtol=rtol)

Here is the code that I used:

        # serializing works without errors
        with tempfile.TemporaryDirectory() as tmp_dirname:
            # serializing with torch.save works
            torch.save(model_unloaded.state_dict(), os.path.join(tmp_dirname, "model.bin"))

            # serializing with safetensors works
            save_file(model_unloaded.state_dict(), os.path.join(tmp_dirname, "model.safetensors"))

The save_file function must be imported: from safetensors.torch import save_file.

When I ran these extended tests, I noticed that OFT is also failing. However, this can be resolved in the same way in this line:

base_layer.weight.data = new_weights

-                base_layer.weight.data = new_weights
+                base_layer.weight.data = new_weights.contiguous()

Would you be so kind to fix that too?

    orig_weight = torch.transpose(orig_weight, 0, 1)
    orig_weight = torch.transpose(orig_weight, 0, 1).contiguous()
    orig_weight = torch.mm(butterfly_oft_mat, orig_weight).contiguous()
    orig_weight = torch.transpose(orig_weight, 0, 1).contiguous()

BenjaminBossan (Member) commented on these lines:
I wonder if we can simplify this by only calling .contiguous() at the very last step, i.e. line 532:

- self.base_layer.weight.data = orig_weight
+ self.base_layer.weight.data = orig_weight.contiguous()

Same for line 532. And since there is also Conv2d, should the call be added there as well? So that would be lines 820 and 834.
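The suggested simplification can be sanity-checked in isolation. In this sketch (an identity matrix stands in for the real butterfly factor), the intermediates are left as views and only the final result is made contiguous; the values match the call-it-every-step variant.

```python
import torch

torch.manual_seed(0)
orig_weight = torch.randn(6, 10)     # stand-in (out_features, in_features) weight
butterfly_oft_mat = torch.eye(10)    # stand-in square factor

# Variant A: .contiguous() after every intermediate step
a = torch.transpose(orig_weight, 0, 1).contiguous()
a = torch.mm(butterfly_oft_mat, a).contiguous()
a = torch.transpose(a, 0, 1).contiguous()

# Variant B: leave intermediates as views; only the final result is made contiguous
b = torch.transpose(orig_weight, 0, 1)
b = torch.mm(butterfly_oft_mat, b)
b = torch.transpose(b, 0, 1).contiguous()

print(torch.allclose(a, b), b.is_contiguous())  # the two variants agree
```

Since .contiguous() only changes memory layout, never values, deferring it to the final assignment is safe and avoids two intermediate copies.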

Zeju1997 (Contributor, Author) commented Aug 7, 2024

Hi, thanks for the comment. I updated the unit test and also changed self.base_layer.weight.data = orig_weight.contiguous() for both BOFT and OFT. Best,

BenjaminBossan (Member) left a comment

Thanks for the updates. Just two very small change requests, otherwise this looks good.

@@ -763,6 +763,10 @@ def _test_safe_merge(self, model_id, config_cls, config_kwargs):
# check that the logits are the same after unloading
assert torch.allclose(logits_peft, logits_unloaded, atol=atol, rtol=rtol)

# serializing with safetensors works
from safetensors.torch import save_file
BenjaminBossan (Member):

Let's move this import to the top of the file.

@@ -763,6 +763,10 @@ def _test_safe_merge(self, model_id, config_cls, config_kwargs):
# check that the logits are the same after unloading
assert torch.allclose(logits_peft, logits_unloaded, atol=atol, rtol=rtol)

# serializing with safetensors works
BenjaminBossan (Member):

Suggested change
# serializing with safetensors works
# Ensure that serializing with safetensors works, there was an error when weights were not contiguous

Zeju1997 (Contributor, Author) commented Aug 7, 2024

Can you check again? I noticed that in the previous commit I did not copy the whole unit test. Best.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Zeju1997 (Contributor, Author) commented Aug 7, 2024

@BenjaminBossan What do I need to change in testing_common.py?

    ruff check src tests examples docs scripts docker
    All checks passed!
    ruff format --check src tests examples docs scripts docker
    Would reformat: tests/testing_common.py
    1 file would be reformatted, 188 files already formatted
    make: *** [Makefile:10: quality] Error 1
    Error: Process completed with exit code 2.

BenjaminBossan (Member)

I think your last commit was the missing piece 🤞

Zeju1997 (Contributor, Author) commented Aug 7, 2024

@BenjaminBossan Hi, the tests are still failing. Are they because of my added code?

BenjaminBossan (Member)

> Hi, the tests are still failing, are they because of my added code?

No, I don't think so, this looks more like a HF Hub issue. This can sometimes happen, but not with the whole test matrix. I'll restart later, don't worry about it.

BenjaminBossan (Member) left a comment

Thanks for the fixes to merging BOFT and OFT.

@BenjaminBossan BenjaminBossan merged commit 9988cb9 into huggingface:main Aug 7, 2024
14 checks passed