
Add nn.Embedding Support to Lora #337

Merged: 4 commits merged into huggingface:main on May 3, 2023
Conversation

@Splo2t (Contributor) commented Apr 19, 2023

What does this PR do?

  1. Added nn.Embedding to the set of layers supported by LoRA fine-tuning.
  2. Added a GPT-J fine-tuning example with an nn.Embedding layer and 8-bit quantization using PEFT in lora.py.

Pull Request Description

This pull request includes the following changes:

  1. Added nn.Embedding to the LoRA fine-tuning supported layers: with this update, LoRA can now be applied to nn.Embedding modules, allowing for better performance and efficiency during fine-tuning.

  2. GPT-J fine-tuning example with an nn.Embedding layer and 8-bit quantization using PEFT in lora.py: to demonstrate the usage of GPT-J with LoRA, we've added an example in the lora.py file. This example showcases fine-tuning GPT-J using the nn.Embedding layer and 8-bit quantization with the help of PEFT, giving users an efficient method for optimizing their GPT-J models (a usage sketch follows below).
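
For readers who want to see what the new capability looks like in practice, here is a minimal sketch (not taken from the PR's lora.py example; the checkpoint name and hyperparameters are illustrative assumptions, and "wte" is GPT-J's token-embedding module):

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model_name = "EleutherAI/gpt-j-6B"  # assumed checkpoint; any causal LM with an nn.Embedding works similarly
model = AutoModelForCausalLM.from_pretrained(model_name)

config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    # "wte" is the nn.Embedding; with this PR it receives low-rank
    # lora_embedding_A / lora_embedding_B weights while the base embedding stays frozen.
    target_modules=["q_proj", "k_proj", "v_proj", "wte"],
)
model = get_peft_model(model, config)
model.print_trainable_parameters()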

@Splo2t changed the title from "Lora embedding" to "Add nn.Embedding Support to Lora" on Apr 19, 2023
@HuggingFaceDocBuilderDev commented Apr 20, 2023

The documentation is not available anymore as the PR was closed or merged.

@SOCSChamp

This solved my issue in #349 (comment)

@Splo2t (Contributor, Author) commented Apr 26, 2023

@pacman100 please help review.

@pacman100 (Contributor) left a review comment

Thank you @Splo2t for adding the support to have LoRA layers in Embedding modules 🤗, this is really cool 🔥. Left a few suggestions.

Are there evaluation results on how this improves the performance of tasks? It would be great if you can tweet about this feature with some great examples, we can amplify the same from our end.

@SOCSChamp

Thank you @Splo2t for adding the support to have LoRA layers in Embedding modules 🤗, this is really cool 🔥. Left a few suggestions.

Are there evaluation results on how this improves the performance of tasks? It would be great if you can tweet about this feature with some great examples, we can amplify the same from our end.

I'm not sure if any significant evaluations have been done so far, but I've run into issues when training attention without embeddings. Training a LoRA on a dataset with a strange use of tokens that the model likely didn't see during full training results in a high probability of random output during inference. This problem isn't present after fine-tuning the same model on the same dataset.

My hypothesis is that if your data contains anything other than very standard text, the embedding layer must be trained, along with attention, in order to learn the new context of the input. I have a training run going currently using this branch and I'll report back if you're interested.

@Splo2t (Contributor, Author) commented Apr 27, 2023

Thank you @Splo2t for adding the support to have LoRA layers in Embedding modules 🤗, this is really cool 🔥. Left a few suggestions.

Are there evaluation results on how this improves the performance of tasks? It would be great if you can tweet about this feature with some great examples, we can amplify the same from our end.

I have been fine-tuning a GPT model for generating new documents. When using PEFT for fine-tuning, I couldn't achieve the desired generation results. However, when I implemented both the Embedding layer and the Linear layer with Lora, based on other code, I was able to obtain the desired outcome. Using the code from this PR, I achieved the same level of quality.
Currently, I am using a text generator implemented with this PR's code to provide a messenger bot service.

I have not yet benchmarked the scores for adding the Embedding Layer to Lora.

@pacman100 (Contributor) commented Apr 27, 2023

Very interesting pointers @Splo2t and @SOCSChamp. And what is the impact on the trainable parameters percentage with and without targeting embedding layer?

@Splo2t (Contributor, Author) commented Apr 27, 2023

Very interesting pointers @Splo2t and @SOCSChamp. And what is the impact on the trainable parameters percentage with and without targeting embedding layer?

Excluding the Embedding Layer: trainable params: 8257536 || all params: 6174759936 || trainable%: 0.13373047836009022

Including the Embedding Layer: trainable params: 8531968 || all params: 6175034368 || trainable%: 0.13816875326579559

The results above compare the trainable parameters of a version with the Embedding Layer and a version without it. The trainable parameters were obtained using the print_trainable_parameters() function.
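
As a sanity check on these numbers (a sketch with assumed values, since the comment does not state them: r=4, kogpt's 64,512-token vocabulary, and a 4,096 hidden size), the embedding's LoRA pair adds r * (vocab_size + hidden_size) trainable parameters:

# rough arithmetic behind the difference between the two counts above (assumed r=4, vocab=64,512, hidden=4,096)
r, vocab_size, hidden_size = 4, 64512, 4096
extra = r * vocab_size + hidden_size * r   # lora_embedding_A (r x vocab) + lora_embedding_B (hidden x r)
print(extra)                               # 274432 == 8,531,968 - 8,257,536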

@lewtun (Member) commented Apr 27, 2023

Training a LoRA on a dataset with a strange use of tokens that the model likely didn't see during full training results in a high probability of a random output during inference. This problem isn't present after fine tuning the same model on the same dataset.

Hi @SOCSChamp are you referring to scenarios where additional special tokens are added to the tokenizer's vocabulary (related to #334)? If yes, I'm curious whether your training run gave the desired outputs at inference time and whether you can share how you're loading the model with the expanded vocabulary?

For context, I'm fine-tuning LLaMa models with special tokens like <|end|> and targeting the embed_tokens module with @Splo2t's branch isn't producing the expected outputs (generations never produce the <|end|> token).
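
For context, the setup being described would look roughly like the sketch below (a hypothetical reconstruction, not lewtun's actual script; the checkpoint name is a placeholder). As the thread goes on to discuss, when genuinely new tokens are added the embedding usually belongs in modules_to_save rather than target_modules, because LoRA keeps the base embedding rows, including the freshly initialized ones, frozen.

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model_name = "huggyllama/llama-7b"  # placeholder LLaMA checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({"additional_special_tokens": ["<|end|>"]})

model = AutoModelForCausalLM.from_pretrained(model_name)
model.resize_token_embeddings(len(tokenizer))  # add a row for the new <|end|> token

config = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
    # LoRA on the embedding (this branch); the alternative discussed below is
    # modules_to_save=["embed_tokens", "lm_head"] to train the embedding fully.
    target_modules=["q_proj", "v_proj", "embed_tokens"],
)
model = get_peft_model(model, config)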

@SOCSChamp

Training a LoRA on a dataset with a strange use of tokens that the model likely didn't see during full training results in a high probability of a random output during inference. This problem isn't present after fine tuning the same model on the same dataset.

Hi @SOCSChamp are you referring to scenarios where additional special tokens are added to the tokenizer's vocabulary (related to #334)? If yes, I'm curious whether your training run gave the desired outputs at inference time and whether you can share how you're loading the model with the expanded vocabulary?

For context, I'm fine-tuning LLaMa models with special tokens like <|end|> and targeting the embed_tokens module with @Splo2t's branch isn't producing the expected outputs (generations never produce the <|end|> token).

No, I am not actually expanding the tokenizer vocabulary but using the stock tokenizer with tokens that are used in atypical ways. The tokens themselves shouldn't be new, but the context they're used in and how they relate to each other certainly is.

I mentioned it in #334 because the solution provided there (adding the embedding layer to modules_to_save) wasn't working in my application.

@pacman100 (Contributor) commented Apr 28, 2023

Hello everyone, so adding embedding layers to modules_to_save works fine. I ran into issues when using PEFT+INT8, so I tried PEFT + gradient checkpointing to make sure the PEFT method works given sufficient VRAM. Below is an experiment using GPT-J to include new tokens.

GPU used for experimentation: A100 80GB

VRAM consumed during training: 36GB

Colab Notebook: https://colab.research.google.com/drive/1Wjm3iMuS43iu9Q_1XwX0NJsFu2d1fZgJ?usp=sharing

important code snippet:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# model_name and tokenizer (with the new tokens added) are defined earlier in the notebook

# loading base model and resizing embedding layers
model = AutoModelForCausalLM.from_pretrained(model_name)
model.resize_token_embeddings(len(tokenizer))

# gradient checkpointing enabling
model.enable_input_require_grads()
model.gradient_checkpointing_enable()

# peft: fully train the resized embedding ("wte") and lm_head via modules_to_save
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    base_model_name_or_path=model_name,
    modules_to_save=["wte", "lm_head"],
)
model = get_peft_model(model, config)

Output post training:

[screenshot of the generated output]

@Taytay commented Apr 29, 2023

Hello everyone, so adding embedding layers to modules_to_save works fine. I ran into issues when using PEFT+INT8, so I just tried PEFT+gradient checkpointing to make sure that PEFT method works given the sufficient VRAM. Below is an experiment using GPT-J to include new tokens.

@pacman100 : does your comment imply that this PR is not necessary, because simply configuring the embeddings layer to be trained and saved is sufficient?

@flozi00 commented Apr 29, 2023

Just giving my five cents to this thread:
I think enabling 8-bit training is an important part of this PR.
Just adding the embedding layers to modules_to_save crashes in 8-bit.
With 8-bit training, the 6B model would take much less than 36 GB of VRAM, perhaps around 8 GB, so even 12 GB consumer cards could be used for training.

@pacman100 (Contributor)

Hello @Taytay and @flozi00,

One would add embedding layers to modules_to_save when they are expanding the tokenizer and embedding layers to add new tokens, as explained in the example I posted above. I subsequently got this to work with INT8 training too, with VRAM usage going down from 36GB (PEFT + gradient checkpointing, as mentioned above) to 19.975 GB (PEFT + gradient checkpointing + INT8). I will be raising a PR with the fixes next week.

This PR, which adds LoRA layers to embedding layers, should be used when you aren't expanding the tokenizer with new tokens. As @Splo2t explained, he had a dataset in which existing tokens were used in unusual ways, and he got the expected results by adding LoRA layers to the embedding modules. This PR is very useful in that regard.
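
Put side by side, the two configurations described above might look like this (a sketch; module names follow the GPT-J example earlier in the thread and the hyperparameters are placeholders):

from peft import LoraConfig

# Case 1: tokenizer/embedding expanded with new tokens -> train the embedding (and lm_head) fully
config_new_tokens = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
    modules_to_save=["wte", "lm_head"],
)

# Case 2: vocabulary unchanged, existing tokens used in unusual ways -> LoRA weights on the embedding (this PR)
config_existing_tokens = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj", "wte"],
)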

@Taytay commented Apr 30, 2023

Excellent - thank you. Yeah, I would like to modify the embeddings layer in my case too, so this PR is much appreciated. Thank you for the clarification!

@Splo2t (Contributor, Author) commented Apr 30, 2023

@pacman100 I've made the suggested changes and rebased on main.

I will post related performance results later.

@Splo2t (Contributor, Author) commented May 1, 2023

Hello @pacman100 I have refactored the Lora Embedding layer and fixed some bugs.

For everyone,
I am sharing the experimental results.
I have been working on training a kogpt model based on GPT-J to generate documents from keywords. The dataset I used consists of keywords, titles, and content, with the latter two used as labels during training. kogpt is a model with a 64,512-token vocabulary. Before training the embedding layer, I couldn't get a strong sense that the dataset was being learned, as the existing GPT model's characteristics were still prominent. Moreover, the handling of EOS tokens was unstable, causing issues with early stopping during generation.
With this PR, I could verify that the model generates content without the mentioned EOS-token issue and, most importantly, gives a sense of understanding the dataset's structure while generating.
Additionally, I tested the Lora performance by training kogpt on the NSMC dataset using an A100 40GB GPU. The training settings were DeepSpeed + PEFT + FP16, batch size = 8, and maximum token length = 128. The fine-tuning version includes only the embedding layer in modules_to_save, while the remaining linear layers are trained with Lora. The advantages of using Lora for embedding layer training are as follows:

  1. Memory savings
  • Fine-tuning-with-embedding: average 91% (approximately 35,064MiB ~ 39,130MiB) (trainable params: 536,756,224, 8.693%)
  • Lora-with-embedding: 87.52% (approximately 35,478MiB) (trainable params: 8,548,352, 0.145%)
  • Lora-without-embedding: 87.54% (approximately 35,486MiB) (trainable params: 8,273,920, 0.140%)
    The memory usage of the fine-tuning version fluctuated from 35GB up to 40GB, while the other training methods maintained memory usage around 35.4GB.
  2. Time savings
  • Fine-tuning-with-embedding: 1.705 steps/sec, total 6 hours 15 minutes 17 seconds (37,500 steps)
  • Lora-with-embedding: 2.257 steps/sec, total 4 hours 45 minutes 5 seconds (37,500 steps)
  • Lora-without-embedding: 2.161 steps/sec, total 4 hours 57 minutes 19 seconds (37,500 steps)
    The fine-tuning version took 30% longer than Lora with embedding.
  3. Similar evaluation accuracy to fine-tuning the embedding (based on the highest checkpoint scores)
  • The benchmark score for the kogpt developers' fine-tuning technique is 0.917.
  • Fine-tuning-with-embedding: 0.9073
  • Lora-with-embedding: 0.9056
  • Lora-without-embedding: 0.9115
    The highest accuracy score was achieved without training the embedding layer. The score difference between the models that trained the embedding layer is 0.0017, which is within the range that can be closed through hyperparameter tuning.
  4. Memory usage during inference (using PEFT)
  • Fine-tuning-with-embedding: 13,418MiB
  • Lora-with-embedding: 12,918MiB
  • Lora-without-embedding: 12,914MiB
    There is a difference in memory usage during inference, which is likely due to both the original model and the PEFT-adapted model being loaded into memory at the same time.

Please see details below
https://api.wandb.ai/links/splo2t/jl2zotwi

@pacman100 (Contributor)

Thank you @Splo2t for the detailed experimental results 🔥. A question to clarify, you used this PR for the cases wherein no new tokens were added to the vocab, right? Other than that, great work and thank you for iterating!

@pacman100 (Contributor) left a review comment

Thank you @Splo2t for iterating, LGTM! 🚀

@Splo2t (Contributor, Author) commented May 2, 2023

Thank you @Splo2t for the detailed experimental results 🔥. A question to clarify, you used this PR for the cases wherein no new tokens were added to the vocab, right? Other than that, great work and thank you for iterating!

In our previous experiments, we did not add new tokens to the vocabulary and proceeded with training. However, we became curious about whether it would be possible to train by adding new tokens to the vocabulary. Initially, we intended to train with quantization to INT8, but an error occurred during the process of resizing the Embedding Layer to match the vocabulary due to incompatible tensor types. As a result, we proceeded with training in FP16 for now.

I added a '<|neweos|>' token to the vocabulary and appended it to the end of every sentence. Then, I resized the Embedding Layer and trained for 1000 steps with a batch size of 8. Ultimately, we were able to confirm that the model generates output with the '<|neweos|>' token attached at the end of the sentence. This demonstrates that applying Lora to the Embedding allows for the expansion of the vocabulary.

Test prompt: '역시 스파이더맨' (As expected, Spider-Man)
Generated result: '역시 스파이더맨도 이제 나이를 먹었구나. <|neweos|>[EOS]' (As expected, Spider-Man has aged now. <|neweos|>[EOS])

Please refer to the source code used below.

import torch
from enum import Enum
from datasets import Dataset, DatasetDict, load_dataset
from transformers import GPTJForCausalLM, AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding


class SpecialTokens(str, Enum):
    new_eos = "<|neweos|>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]


tokenizer = AutoTokenizer.from_pretrained(
    'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b-float16',  # or float32 version: revision=KoGPT6B-ryan1.5b
    bos_token='[BOS]', eos_token='[EOS]', unk_token='[UNK]', pad_token='[EOS]', mask_token='[MASK]',
    additional_special_tokens=SpecialTokens.list(), use_fast=False
)

model = GPTJForCausalLM.from_pretrained(
    'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b-float16',  # or float32 version: revision=KoGPT6B-ryan1.5b
    pad_token_id=tokenizer.eos_token_id,
    use_cache=False,
    # device_map={'': rank},
    torch_dtype=torch.float16,
    # load_in_8bit=True,
    # num_labels=2
)
model.resize_token_embeddings(len(tokenizer))  # grow the embedding to cover the new <|neweos|> token

train_cs = load_dataset("nsmc", split="train")

dataset = DatasetDict({
    'train': train_cs,
})

def tokenize_function(examples):
    # append the new EOS marker to every review and use the input ids as labels (causal LM objective)
    s = [i + ' <|neweos|>' for i in examples['document']]
    a = tokenizer(s, truncation=True, max_length=128, padding='max_length')
    return {"input_ids": a['input_ids'], "labels": a['input_ids']}

tokenized_datasets = dataset.map(tokenize_function, remove_columns=["id", "document", "label"], batched=True)

print(tokenized_datasets)
print(tokenized_datasets['train'][0])
print(tokenizer.decode(tokenized_datasets['train'][0]['input_ids']))

"""
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 150000
    })
})
{'input_ids': [508, 37056, 1939, 3328, 11006, 475, 2538, 3399, 327, 63998, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [508, 37056, 1939, 3328, 11006, 475, 2538, 3399, 327, 63998, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
아 더빙.. 진짜 짜증나네요 목소리 <|neweos|>[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS]
"""

from peft import LoraConfig, PeftModel, get_peft_model  # , prepare_model_for_int8_training
# model = prepare_model_for_int8_training(model)
target_modules = ['q_proj', 'k_proj', 'v_proj', 'out_proj', 'fc_in', 'fc_out', 'wte']  # 'wte' is the nn.Embedding handled by this PR
config = LoraConfig(
    r=4, lora_alpha=16, target_modules=target_modules, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM"
)
model = get_peft_model(model, config)
print(type(model))
print(model.model.transformer.wte)
model.print_trainable_parameters()  # prints the counts itself; no need to wrap in print()
print(model.model.transformer.h[0].attn.q_proj)
"""
<class 'peft.peft_model.PeftModelForCausalLM'>
Embedding(
  63999, 4096
  (lora_dropout): ModuleDict(
    (default): Dropout(p=0.1, inplace=False)
  )
  (lora_A): ModuleDict()
  (lora_B): ModuleDict()
  (lora_embedding_A): ParameterDict(  (default): Parameter containing: [torch.FloatTensor of size 4x63999])
  (lora_embedding_B): ParameterDict(  (default): Parameter containing: [torch.FloatTensor of size 4096x4])
)
Linear(
  in_features=4096, out_features=4096, bias=False
  (lora_dropout): ModuleDict(
    (default): Dropout(p=0.1, inplace=False)
  )
  (lora_A): ModuleDict(
    (default): Linear(in_features=4096, out_features=4, bias=False)
  )
  (lora_B): ModuleDict(
    (default): Linear(in_features=4, out_features=4096, bias=False)
  )
  (lora_embedding_A): ParameterDict()
  (lora_embedding_B): ParameterDict()
)
trainable params: 8529916 || all params: 6170829307 || trainable%: 0.13822965400005352
"""

from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="test_trainer",
    fp16=True,
    warmup_ratio=0.05,
    learning_rate=5e-5,
    max_steps=1000,
    logging_dir='./logs',
    logging_steps=100,
    seed=42,
    dataloader_drop_last=False,
    dataloader_num_workers=2,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=32,
    deepspeed='ds_config.json'
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    tokenizer=tokenizer
)
trainer.train()
prompt= '역시 스파이더맨'
device = "cuda:0"
tokens = tokenizer.encode(prompt, return_tensors='pt').to(device=device, non_blocking=True)
gen_tokens = model.generate(input_ids = tokens,
                          do_sample=True,
                          temperature=0.8,
                          top_p=0.99,
                          max_new_tokens=64,
                          early_stopping=True,
                         )
generated = tokenizer.batch_decode(gen_tokens)[0]
print(generated)
"""
역시 스파이더맨의 매력을 더 느끼게 해준 영화.. <|neweos|>[EOS]
"""

@Splo2t (Contributor, Author) commented May 2, 2023

Thank you for your thorough review and valuable feedback @pacman100 !

@flozi00 commented May 2, 2023

I think a possible solution would be resizing the model's embedding layers, saving the resized model, and then loading it in 8-bit as usual.
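
Sketched out, that workaround might look like the following (checkpoint name and output paths are placeholders; this is not something tested in the thread):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "EleutherAI/gpt-j-6B"  # placeholder base model
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({"additional_special_tokens": ["<|neweos|>"]})

# 1) resize in full precision and save the resized checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name)
model.resize_token_embeddings(len(tokenizer))
model.save_pretrained("gpt-j-resized")
tokenizer.save_pretrained("gpt-j-resized")

# 2) later, load the already-resized checkpoint in 8-bit for LoRA training
model_8bit = AutoModelForCausalLM.from_pretrained(
    "gpt-j-resized", load_in_8bit=True, device_map="auto"
)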

@Splo2t (Contributor, Author) commented May 2, 2023

I'd like to thank you for your input, @flozi00.
I'd also like to provide some additional information regarding INT8 training. When I mentioned that the issue occurred while resizing the Embedding Layer, that was actually my mistake: the problem arose while using DeepSpeed for multi-GPU training. To resolve this when training with multiple GPUs under Distributed Data Parallel (DDP), you can add ddp_find_unused_parameters=False to the TrainingArguments. Below is the code for INT8 training using Distributed Data Parallel:

Execution command: torchrun --standalone --nnodes=1 --nproc_per_node=4 app.py

import torch
import torch.distributed as dist
from enum import Enum
from datasets import Dataset, DatasetDict, load_dataset
from transformers import GPTJForCausalLM, AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding

dist.init_process_group("nccl")
rank = dist.get_rank()


class SpecialTokens(str, Enum):
    new_eos = "<|neweos|>"
    # end_target = "<|endtarget|>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]


tokenizer = AutoTokenizer.from_pretrained(
    'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b-float16',  # or float32 version: revision=KoGPT6B-ryan1.5b
    bos_token='[BOS]', eos_token='[EOS]', unk_token='[UNK]', pad_token='[EOS]', mask_token='[MASK]',
    additional_special_tokens=SpecialTokens.list(), use_fast=False
)

model = GPTJForCausalLM.from_pretrained(
    'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b-float16',  # or float32 version: revision=KoGPT6B-ryan1.5b
    pad_token_id=tokenizer.eos_token_id,
    use_cache=False,
    device_map={'': rank},        # place the whole model on this process's GPU
    torch_dtype=torch.float16,
    load_in_8bit=True,            # INT8 weights via bitsandbytes
)

model.resize_token_embeddings(len(tokenizer))  # grow the embedding to cover the new <|neweos|> token
train_cs = load_dataset("nsmc", split="train")
dataset = DatasetDict({
        'train': train_cs,
})

def tokenize_function(examples):
    # append the new EOS marker to every review and use the input ids as labels (causal LM objective)
    s = [i + ' <|neweos|>' for i in examples['document']]
    a = tokenizer(s, truncation=True, max_length=128, padding='max_length')
    return {"input_ids": a['input_ids'], "labels": a['input_ids']}

tokenized_datasets = dataset.map(tokenize_function, remove_columns=["id", "document", "label"], batched=True)

print(tokenized_datasets)
print(tokenized_datasets['train'][0])
print(tokenizer.decode(tokenized_datasets['train'][0]['input_ids']))
"""
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 150000
    })
})
{'input_ids': [508, 37056, 1939, 3328, 11006, 475, 2538, 3399, 327, 63998, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [508, 37056, 1939, 3328, 11006, 475, 2538, 3399, 327, 63998, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
아 더빙.. 진짜 짜증나네요 목소리 <|neweos|>[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EO
S][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS
][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS]
[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS]
"""

from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_int8_training

model = prepare_model_for_int8_training(model)  # prepares the INT8 model for training (casts norms, enables input grads)
target_modules = ['q_proj', 'k_proj', 'v_proj', 'out_proj', 'fc_in', 'fc_out', 'wte']  # 'wte' is the nn.Embedding handled by this PR
config = LoraConfig(
    r=4, lora_alpha=16, target_modules=target_modules, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM"
)
model = get_peft_model(model, config)
model.gradient_checkpointing_enable()
print(type(model))
print(model.model.transformer.wte)
print(model.model.transformer.h[0].attn.q_proj)
model.print_trainable_parameters()  # prints the counts itself; no need to wrap in print()
"""
<class 'peft.peft_model.PeftModelForCausalLM'>
Embedding(
  63999, 4096
  (lora_dropout): ModuleDict(
    (default): Dropout(p=0.1, inplace=False)
  )
  (lora_A): ModuleDict()
  (lora_B): ModuleDict()
  (lora_embedding_A): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 4x63999 (GPU 2)])
  (lora_embedding_B): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 4096x4 (GPU 2)])
)
Linear8bitLt(
  in_features=4096, out_features=4096, bias=False
  (lora_dropout): ModuleDict(
    (default): Dropout(p=0.1, inplace=False)
  )
  (lora_A): ModuleDict(
    (default): Linear(in_features=4096, out_features=4, bias=False)
  )
  (lora_B): ModuleDict(
    (default): Linear(in_features=4, out_features=4096, bias=False)
  )
  (lora_embedding_A): ParameterDict()
  (lora_embedding_B): ParameterDict()
)
trainable params: 8529916 || all params: 6170829307 || trainable%: 0.13822965400005352
"""

from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="test_trainer",
    fp16=True,
    warmup_ratio=0.05,
    learning_rate=5e-5,
    max_steps=1000,
    logging_dir='./logs',
    logging_steps=100,
    seed=42,
    dataloader_drop_last=False,
    dataloader_num_workers=2,
    ddp_find_unused_parameters=False,  # required here: avoids the DDP unused-parameter error with PEFT
    per_device_train_batch_size=8,
    per_device_eval_batch_size=32,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
)

trainer.train()
prompt= '역시 스파이더맨'
device = "cuda:0"
tokens = tokenizer.encode(prompt, return_tensors='pt').to(device=device, non_blocking=True)
gen_tokens = model.generate(input_ids = tokens,
                          do_sample=True,
                          temperature=0.8,
                          top_p=0.99,
                          max_new_tokens=64,
                          early_stopping=True,
                         )
generated = tokenizer.batch_decode(gen_tokens)[0]
print(generated)
"""
역시 스파이더맨 너무 좋아. <|neweos|>[EOS]
"""

@pacman100 pacman100 merged commit 6a18585 into huggingface:main May 3, 2023
@Sanster commented May 4, 2023

@Splo2t Thank you for sharing the experimental results. I would like to ask a question: If Lora is used for embedding training, should the learning rate be adjusted accordingly?

@Splo2t (Contributor, Author) commented May 4, 2023

Thank you for your question @Sanster. When training on NSMC using AutoModelForSequenceClassification, we observed that adjusting the learning rate can lead to a difference of up to 3% in accuracy. Although not mentioned in the experimental results above, there were cases where NSMC accuracy exceeded 91% when the learning rate was adjusted while using Lora for embedding training. For generative models using AutoModelForCausalLM it is difficult to make quantitative comparisons, but there were no significant issues when using the same learning rate.
Although I have not looked into related papers, my impression is that some of the training instability commonly observed in encoder models is reflected here to some extent. If the task does not require the highest performance, adjusting the learning rate might not be necessary, but it would be beneficial to do so.

@lastrei commented Nov 23, 2023

If I add new vocab to the tokenizer and also add wte to target_modules, will the embedding layer be trained with PEFT as well?

This is my code for phi-1.5:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# tokenizer (with the new vocab added) and bnb_config are defined elsewhere in the script
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-1_5",
    device_map={"": 0},
    trust_remote_code=True,
    quantization_config=bnb_config
)
# resize the model's embedding layer to match the updated tokenizer vocabulary
model.resize_token_embeddings(len(tokenizer))

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["Wqkv", "out_proj", "wte"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

model.print_trainable_parameters()

@BenjaminBossan (Member)

i add wte in target_modules, the embeddings layer will also train in peft?

The embeddings themselves will stay frozen, but the LoRA weights for the embeddings will be updated. If you want to train the embeddings fully, add them to modules_to_save instead. But this increases the number of trainable parameters considerably, so just training the LoRA weights should be evaluated first.
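
For reference, the two options map onto LoraConfig roughly like this (a sketch using the module names from the phi-1.5 snippet above; hyperparameters as in that snippet):

from peft import LoraConfig

# Option A: LoRA weights on the embedding; the base embedding matrix stays frozen
lora_on_embedding = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
    target_modules=["Wqkv", "out_proj", "wte"],
)

# Option B: train (and save) the full embedding matrix, LoRA on the remaining modules
full_embedding = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
    target_modules=["Wqkv", "out_proj"],
    modules_to_save=["wte"],
)

Calling model.print_trainable_parameters() after get_peft_model(model, ...) makes the difference in trainable parameters between the two explicit.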

@lastrei commented Nov 24, 2023

i add wte in target_modules, the embeddings layer will also train in peft?

The embeddings themselves will stay frozen, but the LoRA weights for the embeddings will be updated. If you want to train the embeddings fully, add them to modules_to_save instead. But this increases the number of trainable parameters considerably, so just training the LoRA weights should be evaluated first.

Thanks sir, so I have to set

modules_to_save=["wte"]

in the training setup to train the full embeddings?

@ariG23498 mentioned this pull request on Aug 18, 2024