
[Add] DoRA Embedding #2006

Merged: 7 commits into huggingface:main on Aug 23, 2024

Conversation

@ariG23498 (Contributor) commented Aug 14, 2024

This PR adds DoRA for the embedding layer.

TODOs:

  • Check whether this implementation is correct
  • Write tests for the implementation

@BenjaminBossan can you give me an initial review to set the tone of the PR?
Fixes: #1677
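
For context, here is a minimal sketch (not code from this PR) of how DoRA would be enabled on an embedding layer through the existing LoraConfig; the model ID, the module name "embed_tokens", and the hyperparameters are illustrative assumptions and depend on the base model.

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Hedged sketch: "embed_tokens" is the embedding module name for OPT/Llama-style
# models; adjust target_modules for other architectures.
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.bfloat16)
config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["embed_tokens", "q_proj", "v_proj"],
    use_dora=True,  # DoRA on all targeted modules, including the embedding layer added by this PR
)
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()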

@BenjaminBossan (Member)

Thanks @ariG23498, this looks good so far. Probably there could be some refactoring in dora.py as the code duplication accumulates, but let's keep it as is for now.

Unfortunately, I don't think there is any reference that we can compare this to, as the DoRA paper does not mention application to embedding layers (unless I missed it). Probably a good next step would be to take one of the existing examples and check what happens when the new DoRA embedding layer is used. I would expect results to be very similar to LoRA embeddings, with (hopefully) a slight edge in scores for DoRA.

@sayakpaul (Member)

We could also request a review from the DoRA author (sorry, I forgot their name), like we did when we added Conv DoRA support.

@ariG23498 (Contributor Author) commented Aug 16, 2024

@sayakpaul I agree. That would be great. I looked into the previous PRs and got hold of the author's GitHub username, nbasyl. I cannot @-mention him in this PR. Maybe @BenjaminBossan can?

Separately, I have used the QDoRA finetuning example and added the embedding layer to the LoRA config. Here is the colab notebook. Let me know what you all think.

@BenjaminBossan (Member) commented Aug 16, 2024

Yes, we could ping nbasyl but let's wait until the PR is ready for review or until we have concrete questions to ask.

Thanks for providing the example notebook. I left it running for a bit longer (also increased batch size to 2) and it appears that the presence of the DoRA embedding layer destabilizes the loss.

| Step | Train Loss w/ embedding | Train Loss w/o embedding | Train Loss LoRA w/ embedding |
|------|-------------------------|--------------------------|------------------------------|
| 1    | 1.618500                | 1.618500                 | 1.820700                     |
| 2    | 2.022700                | 2.022700                 | 1.574100                     |
| 3    | 1.945600                | 1.875700                 | 2.133800                     |
| 4    | 1.202400                | 1.103300                 | 2.512300                     |
| 5    | 1.750000                | 1.550400                 | 2.857700                     |
| 6    | 1.878900                | 1.489500                 | 2.298900                     |
| 7    | 2.393900                | 1.128800                 | 3.082700                     |
| 8    | 2.650000                | 1.485200                 | 2.467200                     |
| 9    | 3.748100                | 1.486900                 | 2.687300                     |
| 10   | 2.576500                | 1.257300                 | 2.307900                     |
| 11   | 2.309400                | 1.323600                 | 3.414000                     |
| 12   | 2.405300                | 1.302100                 | 2.644900                     |
| 13   | 3.725500                | 1.107000                 | 2.551000                     |
| 14   | 2.729300                | 1.712600                 | 2.283900                     |
| 15   | 3.939800                | 1.561200                 | 2.173500                     |
| 16   | 2.531300                | 1.121500                 | 2.497700                     |
| 17   | 2.303800                | 1.077900                 | 2.501000                     |
| 18   | 2.544800                | 1.680800                 | 2.314400                     |
| 19   | 2.462000                | 1.553900                 | 2.776200                     |
| 20   | 2.115400                | 1.218000                 | 2.103700                     |

I'm not exactly sure why that is, but I checked the adapter weights on the layers and found that the L2 norms of the embedding layer's adapter weights are much larger than those of the other layers:

embed_tokens.lora_embedding_A.default: 29.023
embed_tokens.lora_embedding_B.default: 181.432
embed_tokens.lora_magnitude_vector.default.weight: 213.857
layers.0.self_attn.q_proj.lora_A.default.weight: 1.633
layers.0.self_attn.q_proj.lora_B.default.weight: 0.123
layers.0.self_attn.q_proj.lora_magnitude_vector.default.weight: 72.825
layers.0.self_attn.k_proj.lora_A.default.weight: 1.644
layers.0.self_attn.k_proj.lora_B.default.weight: 0.068
layers.0.self_attn.k_proj.lora_magnitude_vector.default.weight: 52.904
... values for linear layers are all very similar
layers.31.mlp.down_proj.lora_A.default.weight: 1.647
layers.31.mlp.down_proj.lora_B.default.weight: 0.119
layers.31.mlp.down_proj.lora_magnitude_vector.default.weight: 89.978

Maybe this means that the lr for embedding layers should be lowered? I haven't checked further, but lora+ could be an option here. Also, it could just be a fluke on this dataset.
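
One hedged way to try the lower-lr idea (a quick sketch, not something validated in this PR) would be to give the embedding adapter its own optimizer parameter group; the module name "embed_tokens" and the two learning rates below are assumptions for illustration.

import torch

# model is assumed to be the PEFT-wrapped model from the notebook above
embed_params, other_params = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    (embed_params if "embed_tokens" in name else other_params).append(param)

optimizer = torch.optim.AdamW(
    [
        {"params": embed_params, "lr": 2e-5},  # smaller lr for the embedding adapter
        {"params": other_params, "lr": 2e-4},
    ]
)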

@ariG23498 (Contributor Author)

@BenjaminBossan could the L2 norm of the embedding weights be large due to the size of the embedding weights?

I created a simple base model with torch and printed the L2 norm of the weights as follows:

import torch


class BaseModel(torch.nn.Module):
    def __init__(self, vocab_size, emb_dims, out_dims):
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dims)
        self.proj = torch.nn.Linear(in_features=emb_dims, out_features=out_dims)

    def forward(self, x):
        x = self.emb(x)
        x = self.proj(x)
        return x


model = BaseModel(
    vocab_size=100,
    emb_dims=16,
    out_dims=32,
)

for name, param in model.named_parameters():
    print(name, torch.linalg.norm(param, ord=2).item())

which prints:

emb.weight 13.42896842956543  # this is significantly larger than the projection layer
proj.weight 1.2026439905166626
proj.bias 0.7463299036026001

Do you want me to train the model using the lora+ optimizer?

@BenjaminBossan (Member)

Thanks for checking. Note that I didn't check the norm of the embedding layer itself but of the associated LoRA weights. Still, the norm could of course be related to the overall shape of the layer. Maybe this check makes no sense, but I just wanted to get a hint on what could be going wrong.

I did another run with LoRA instead of DoRA applied to the embedding and linear layers, and there the loss looks similarly bad as with DoRA on the embedding. So it's most likely more of an issue with adapting the embedding layer than with DoRA specifically. Still, it would be nice if we could find an example where DoRA embedding indeed improves model performance before merging the PR. LoRA+ was just one idea.

@ariG23498 (Contributor Author)

@BenjaminBossan I noticed that @pacman100 had a working E2E example on training an embedding layer with LoRA here. Unfortunately I do not have the GPU compute to run the example. Would you mind running it with the current DoRA implementation?

@BenjaminBossan (Member)

Good point. I used this notebook instead (with some minor changes like using bfloat16), which is similar but a bit more up-to-date. Below are some results:

[image: evaluation results comparing LoRA, DoRA, and DoRA & LoRA+]

As you can see, LoRA, DoRA, and DoRA & LoRA+ are all very similar. I also noticed that re-running the same notebook with a different seed yielded quite different results, so these differences are not significant. I think we can proceed with this PR, as we have no indication so far that DoRA works worse on embeddings than LoRA.

ariG23498 marked this pull request as ready for review on August 19, 2024, 18:08
ariG23498 changed the title from [WIP][Add] DoRA Embedding to [Add] DoRA Embedding on Aug 20, 2024
@ariG23498 (Contributor Author)

@BenjaminBossan I have marked the PR for review.

@BenjaminBossan (Member)

Would you mind adding some tests for DoRA embeddings? There are already some LoRA embedding test cases in test_custom_models.py that can be copied and modified to use DoRA. There are also some initialization tests that need to be extended.
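
As a rough starting point, a DoRA-embedding test could look something like the sketch below; this is illustrative only (TinyEmbModel and the assertions are hypothetical), and the actual tests should follow the parametrized patterns already used in test_custom_models.py.

import torch
from torch import nn
from peft import LoraConfig, get_peft_model

class TinyEmbModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(100, 8)
        self.lin = nn.Linear(8, 2)

    def forward(self, x):
        return self.lin(self.emb(x))

def test_dora_embedding_forward():
    config = LoraConfig(target_modules=["emb"], use_dora=True)
    model = get_peft_model(TinyEmbModel(), config)
    x = torch.randint(0, 100, (4, 10))
    out = model(x)
    assert out.shape == (4, 10, 2)
    # the DoRA magnitude vector should be registered for the embedding layer
    assert any("emb" in name and "lora_magnitude_vector" in name for name, _ in model.named_parameters())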

Comment on lines 804 to 815
if not self.use_dora[active_adapter]:
    after_A = self._embed(x, embedding_A)
    result = result + (after_A @ embedding_B) * scaling
else:
    result = result + self.lora_magnitude_vector[active_adapter](
        x,
        lora_A=embedding_A,
        lora_B=embedding_B,
        scaling=scaling,
        base_layer=self.get_base_layer(),
        embed_fn=self._embed,
    )
Contributor Author

With the previous commits, the DoRA embedding layer did not forward propagate the inputs. With this code snippet, we are not forward propagating the inputs through the DoRA embedding layer.

Contributor Author

@BenjaminBossan for visibility.

Member

> With this code snippet, we are not forward propagating the inputs through the DoRA embedding layer.

I don't understand.

Contributor Author

I apologise for not being able to communicate well.

I meant that the code snippet added in this commit (the one that is linked) now lets the inputs propagate through the DoRA embedding layer. Prior to this, the DoRA embedding layer was being created but not used; the inputs were being forward propagated through the DoRA linear layer (as it was the parent class).

This also means that the results you noted in this comment should be re-generated in order to make use of the embedding adapters properly.

Member

I see, I think the confusion comes from the "not" in the quoted snippet, which is probably meant to be "now".

I reran the experiments with the new branch and the losses are now much closer:

[image: loss comparison after re-running with the updated branch]

It's still seed dependent, but at least there does not appear to be anything completely off.

@ariG23498 (Contributor Author)

@BenjaminBossan with the current implementation, I have a test case failing. I have isolated the test case and am running it individually like so:

import copy
import torch
from torch import nn
from peft import get_peft_model, LoraConfig
from transformers.pytorch_utils import Conv1D


class ModelEmbConv1D(nn.Module):
    def __init__(self, emb_size=100):
        super().__init__()
        self.emb = nn.Embedding(emb_size, 5)
        self.conv1d = Conv1D(1, 5)
        self.relu = nn.ReLU()
        self.flat = nn.Flatten()
        self.lin0 = nn.Linear(10, 2)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        # print("getting in embedding")
        X = self.emb(X)
        X = self.conv1d(X)
        X = self.relu(X)
        X = self.flat(X)
        X = self.lin0(X)
        X = self.sm(X)
        return X


# INPUTS
X = torch.arange(90).view(9, 10)
X = {"X":X}

# MODEL
config_cls = LoraConfig
config = LoraConfig(
    target_modules=[
        "emb",
        "conv1d"
    ],
    use_dora=True,
)

model = ModelEmbConv1D()
model = get_peft_model(model, config)
model_before = copy.deepcopy(model)

model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

# train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
# breaking of some LoRA layers that are initialized with constants)
for _ in range(3):
    optimizer.zero_grad()
    y_pred = model(**X)
    loss = y_pred.sum()
    loss.backward()
    optimizer.step()

tol = 1e-4
params_before = dict(model_before.named_parameters())
params_after = dict(model.named_parameters())
assert params_before.keys() == params_after.keys()

prefix = "lora_"
for name, param_before in params_before.items():
    param_after = params_after[name]
    if (prefix in name) or ("modules_to_save" in name):
        # target_modules and modules_to_save _are_ updated
        assert not torch.allclose(param_before, param_after, atol=tol, rtol=tol)
    else:
        assert torch.allclose(param_before, param_after, atol=tol, rtol=tol)

Somehow the lora_embedding_B does not get updated. Running the script sometimes results in everything passing, which makes this a flaky issue. Could you point me to places I can look into? I have tried printing the gradients of the parameters, and all the DoRA parameters do have a .grad attached, so they should be getting updated by the optimizer.

What are your thoughts on this?

@BenjaminBossan (Member)

Thanks for working on the tests, and good job isolating this flaky one. I think this is a case of the learning rate being too small: the changes are then too small to exceed the tolerance of the check. The solution is simply to increase the learning rate in this specific setting. I tried with 1.0 and the test passed 10 times. If you check the other tests in this file, you'll see that the learning rate has to be adjusted in a couple of places; it's just not possible to find one value that fits all cases.

@ariG23498 (Contributor Author)

Updated the learning rate. The tests pass on my system.

Note: I have kept the lr considerably high, at 100.0.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member) left a comment

Thanks for the updates, this looks pretty good already. I have some minor comments regarding the tests, please check.

@@ -874,7 +887,8 @@ def test_only_params_are_updated(self, test_name, model_id, config_cls, config_k
     model_before = copy.deepcopy(model)

     model.train()
-    optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
+    lr = 100.0 if config_kwargs.get("use_dora") else 0.5
Member

Wow, is 100 really necessary? For me, even 1.0 seems to work (tested 10x on CPU and GPU).

Contributor Author

Oh yeah!

I ran the 1.0 experiment 1000x on CPU and GPU, and it did not go well. I tried increasing the lr in the following manner: 1 -> 2 -> 4 -> 10 -> 50 -> 100. 100 was the only lr that passed all 1000 runs.

Member

That's wild. Hopefully such a high lr is only needed for this dummy model and not an indicator for real world usage (which does not seem so based on the tests above).

Member

This is only required for dora + embedding, right? So let's adjust the check, same as below. And let's add a comment that this high learning rate was found through testing to be necessary to avoid flakiness.

Contributor Author

Updated!

@BenjaminBossan (Member)

@nbasyl This PR adds DoRA to embedding layers. It's pretty similar to how DoRA is implemented for linear layers. If you have time, it would be great if you could give this a look.

@nbasyl commented Aug 23, 2024

@BenjaminBossan @ariG23498 Thanks for tagging me and for the implementation; everything seems correct, just a few minor suggestions. From what I understand about adapting the embedding layer, there is no dropout on the input, so this part could potentially be optimized further to improve inference speed, since we no longer need to handle the dropout alignment issue that arises when adapting linear layers:
[image: the current DoRA embedding forward implementation being referenced]

to something like:

weight_norm = weight_norm.detach()
mag_norm_scale = magnitude / weight_norm
result_dora = mag_norm_scale * (embed_fn(x, lora_A) @ lora_B) * scaling
return mag_norm_scale, result_dora

and the forward pass of the embedding layer becomes:

if not self.use_dora[active_adapter]:
    after_A = self._embed(x, embedding_A)
    result = result + (after_A @ embedding_B) * scaling
else:
    mag_norm_scale, dora_result = self.lora_magnitude_vector[active_adapter](
        x,
        lora_A=embedding_A,
        lora_B=embedding_B,
        scaling=scaling,
        base_layer=self.get_base_layer(),
        embed_fn=self._embed,
    )
    result = mag_norm_scale * result + dora_result

I'm not quite sure whether the speed-up is significant, as we are only skipping one forward call of embed_fn(x, weight).

@ariG23498 (Contributor Author)

@nbasyl thank you for reviewing the PR. I had totally overlooked the inference-time optimization opportunity.

@BenjaminBossan I have made the suggested changes, and all the tests pass on my end.

@BenjaminBossan (Member)

Yes, good catch @nbasyl. I haven't tested it specifically for DoRA embeddings, but in the past I did test what would happen if we had that optimization on linear layers. There, it specifically helped reduce memory usage, so it is indeed good to have.

I was thinking about adding this optimization to linear layers when:

  1. the user chooses a dropout rate of 0, or
  2. the model is in inference mode.

The disadvantage, of course, is that it makes the code more complicated, and it would not help in most training cases, which is where it matters most.

Btw, I think this part of the DoRA implementation is often wrong in other packages, but it does make their DoRA layers more efficient than PEFT's :-/

@BenjaminBossan (Member) left a comment

Thanks for updating the code with the suggested optimization.

I just found a tiny issue in one of the tests, otherwise this is good to be merged.

For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
output.
"""
lora_weight = (lora_A @ lora_B).T
Member

Ah, too bad we can't use the same approach as in the other lora layers but instead have to multiply the parameters directly. This may cause trouble in some situations like with FSDP. But still, it's better than not having support for embeddings at all.

@ariG23498 (Contributor Author) commented Aug 23, 2024

I agree!

I tried with a linear layer to mimic the embedding weights, but the code was getting a little complicated.


@BenjaminBossan (Member) left a comment

Thanks a lot, LGTM.

@BenjaminBossan merged commit 900f96c into huggingface:main on Aug 23, 2024
14 checks passed
@nbasyl commented Sep 27, 2024

Hi @BenjaminBossan , I've been experimenting with DoRA finetuning on tasks that weren't covered in the paper, like Function Calling and Role Play. Based on my results, it doesn't seem to matter much whether dropout is applied or not. Given our previous discussions, I was wondering if we could integrate this more efficient forward pass into the codebase when the dropout rate is set to zero. Let me know your thoughts and how I can assist!

@nbasyl commented Sep 27, 2024

Also tagging @ariG23498 to see if you'd be interested in helping with this. I'm not entirely sure if making these changes to the forward pass will require significant modifications to other parts of the code.

@ariG23498 (Contributor Author)

I would ask @BenjaminBossan to take the lead on this one!

I would be more than happy to contribute if we see that this could potentially lead to a speedup, as mentioned here.

@BenjaminBossan (Member)

@nbasyl Thanks for the additional information. Just to clarify, do you mean that dropout in general is not that helpful, or just specifically with regards to the embedding layer (since you posted in this PR)?

I think we can take a look at implementing the more efficient path where we check if the model is in eval mode or if the dropout is 0 (i.e. using nn.Identity). In that case, we can pass the computed result from the base model and re-use that. I can put this on my TODO list.
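
For illustration, here is a rough sketch of that check (the names are mine, not PEFT's actual internals): the extra base-layer forward inside the DoRA path could be skipped whenever dropout is a no-op, because the dropped and un-dropped inputs then coincide.

import torch.nn as nn

# Hedged sketch: condition under which the base-layer output that was already
# computed could be re-used instead of calling the base layer again in DoRA.
def can_reuse_base_result(module: nn.Module, dropout: nn.Module) -> bool:
    # lora_dropout == 0 is represented by nn.Identity; eval mode disables
    # dropout as well, so in both cases dropout(x) == x.
    return isinstance(dropout, nn.Identity) or not module.training

# Inside a hypothetical DoRA forward:
#     if can_reuse_base_result(self, self.lora_dropout[adapter]):
#         base_result = result  # re-use the output the base layer already produced
#     else:
#         base_result = self.get_base_layer()(dropout(x))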

@nbasyl commented Sep 27, 2024

@BenjaminBossan, I'm talking about dropout in general. You're right: we should reuse the computed output from the base model when dropout is set to 0 or when in evaluation mode. Feel free to let me know if there's any way I can help!

@BenjaminBossan (Member)

Great. I created an issue to track this (#2107).

Successfully merging this pull request may close these issues.

DoRA support for Embedding
5 participants