Add nn.Embedding Support to Lora #337
Conversation
The documentation is not available anymore as the PR was closed or merged.
This solved my issue in #349 (comment)
@pacman100 please help review.
Thank you @Splo2t for adding the support to have LoRA layers in Embedding modules 🤗, this is really cool 🔥. Left a few suggestions.
Are there evaluation results on how this improves performance on downstream tasks? It would be great if you could tweet about this feature with some great examples; we can amplify it from our end.
I'm not sure if there have been any significant evaluations so far, but I've run into issues when training attention without embeddings. Training a LoRA on a dataset with an unusual use of tokens, ones the model likely didn't see during full training, results in a high probability of random output during inference. This problem isn't present after fine-tuning the same model on the same dataset. My hypothesis is that if your data contains anything other than very standard text, the embedding layer must be trained along with attention in order to learn the new context of the input. I have a training run going currently using this branch and I'll report back if you're interested.
I have been fine-tuning a GPT model for generating new documents. When using PEFT for fine-tuning, I couldn't achieve the desired generation results. However, when I implemented both the Embedding Layer and the Linear Layer with LoRA based on other code, I was able to obtain the desired outcome. Using the code from this PR, I achieved the same level of satisfaction. I have not yet benchmarked the scores for adding the Embedding Layer to LoRA.
Very interesting pointers @Splo2t and @SOCSChamp. And what is the impact on the trainable parameters percentage with and without targeting the embedding layer?
Excluding the Embedding Layer: trainable params: 8257536 || all params: 6174759936 || trainable%: 0.13373047836009022
Including the Embedding Layer: trainable params: 8531968 || all params: 6175034368 || trainable%: 0.13816875326579559
The results above compare the trainable parameters of a version with the Embedding Layer and a version without it. The trainable parameters were obtained using the print_trainable_parameters() function.
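For readers who want to reproduce this comparison, here is a minimal sketch (not taken from the PR itself): the two LoraConfig objects below differ only in whether the GPT-J embedding module 'wte' is listed in target_modules. The module names follow the GPT-J scripts later in this thread and are assumptions for other architectures; a loaded GPT-J model is assumed for the commented-out calls.

from peft import LoraConfig, get_peft_model
# Attention/MLP modules targeted in both runs (GPT-J naming, as in the scripts below).
attn_and_mlp = ['q_proj', 'k_proj', 'v_proj', 'out_proj', 'fc_in', 'fc_out']
config_without_embedding = LoraConfig(
    r=4, lora_alpha=16, target_modules=attn_and_mlp,
    lora_dropout=0.1, bias="none", task_type="CAUSAL_LM",
)
config_with_embedding = LoraConfig(
    r=4, lora_alpha=16, target_modules=attn_and_mlp + ['wte'],  # add LoRA to the embedding too
    lora_dropout=0.1, bias="none", task_type="CAUSAL_LM",
)
# With a loaded GPT-J `model`, the numbers above come from:
# get_peft_model(model, config_without_embedding).print_trainable_parameters()
# get_peft_model(model, config_with_embedding).print_trainable_parameters()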
Hi @SOCSChamp are you referring to scenarios where additional special tokens are added to the tokenizer's vocabulary (related to #334)? If yes, I'm curious whether your training run gave the desired outputs at inference time and whether you can share how you're loading the model with the expanded vocabulary? For context, I'm fine-tuning LLaMa models with special tokens like |
No, I am not actually expanding the tokenizer vocabulary but using the stock tokenizer with tokens that are used in atypical ways. The tokens themselves aren't (or shouldn't be) new, but the context they're used in and how they relate to each other certainly is. I mentioned #334 because the solution provided there (adding the embedding layer to modules_to_save) wasn't working in my application.
Hello everyone, so adding embedding layers to modules_to_save works (the embedding layer is trained and saved in full).
GPU used for experimentation: A100 80GB
VRAM consumed during training: 36GB
Colab Notebook: https://colab.research.google.com/drive/1Wjm3iMuS43iu9Q_1XwX0NJsFu2d1fZgJ?usp=sharing
Important code snippet:
Output post training:
@pacman100 : does your comment imply that this PR is not necessary, because simply configuring the embeddings layer to be trained and saved is sufficient?
Just giving my 5 cents to this thread.
One would add embedding layers to modules_to_save when adding new tokens to the tokenizer's vocabulary. This PR, which adds LoRA layers to embedding layers, should be used if you aren't expanding the tokenizer with new tokens. As @Splo2t explained, he had a dataset in which existing tokens were used in unusual ways, and he got the expected results by adding LoRA layers to the embedding modules. This PR is very useful in that regard.
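To make that distinction concrete, here is a hedged sketch of the two setups. The module name 'wte' follows the GPT-J scripts later in this thread and is an assumption for other models; the attention module names are illustrative.

from peft import LoraConfig
# Case 1: new tokens added to the tokenizer -> resize the embedding, train it in
# full, and save it alongside the adapter via modules_to_save.
config_new_tokens = LoraConfig(
    r=4, lora_alpha=16,
    target_modules=['q_proj', 'v_proj'],
    modules_to_save=['wte'],  # embedding trained and saved in full
    lora_dropout=0.1, bias="none", task_type="CAUSAL_LM",
)
# Case 2: vocabulary unchanged but existing tokens used in unusual ways ->
# attach LoRA layers to the embedding module itself (what this PR enables).
config_lora_on_embedding = LoraConfig(
    r=4, lora_alpha=16,
    target_modules=['q_proj', 'v_proj', 'wte'],  # 'wte' gets lora_embedding_A/B
    lora_dropout=0.1, bias="none", task_type="CAUSAL_LM",
)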
Excellent - thank you. Yeah, I would like to modify the embeddings layer in my case too, so this PR is much appreciated. Thank you for the clarification!
@pacman100 I've made the suggested changes and rebased on main. And I will post about related performance indicators later. |
Hello @pacman100, I have refactored the Lora Embedding layer's merge and unmerge methods and fixed some bugs. For everyone interested, please see the details below.
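For context on what merging means for an embedding LoRA layer, here is a rough sketch of the underlying arithmetic; the helper names are illustrative and not the PR's API. With lora_embedding_A of shape (r, num_embeddings) and lora_embedding_B of shape (embedding_dim, r), merging adds the transposed product, scaled by lora_alpha / r, onto the frozen embedding weight, and unmerging subtracts it again.

import torch

def merge_embedding_lora(weight, lora_A, lora_B, lora_alpha, r):
    # weight: (num_embeddings, embedding_dim) frozen embedding matrix
    # lora_A: (r, num_embeddings), lora_B: (embedding_dim, r)
    scaling = lora_alpha / r
    return weight + (lora_B @ lora_A).T * scaling

def unmerge_embedding_lora(weight, lora_A, lora_B, lora_alpha, r):
    # Inverse of the merge: subtract the same scaled update.
    scaling = lora_alpha / r
    return weight - (lora_B @ lora_A).T * scaling

# Tiny shape check mirroring the sizes printed later in this thread (63999 x 4096, r=4):
W = torch.zeros(10, 8)   # stand-in embedding weight
A = torch.randn(4, 10)   # stand-in lora_embedding_A
B = torch.randn(8, 4)    # stand-in lora_embedding_B
assert merge_embedding_lora(W, A, B, lora_alpha=16, r=4).shape == W.shape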
Thank you @Splo2t for the detailed experimental results 🔥. A question to clarify: you used this PR for the cases wherein no new tokens were added to the vocab, right? Other than that, great work and thank you for iterating!
Thank you @Splo2t for iterating, LGTM! 🚀
In our previous experiments, we did not add new tokens to the vocabulary and proceeded with training. However, we became curious about whether it would be possible to train by adding new tokens to the vocabulary. Initially, we intended to train with quantization to INT8, but an error occurred during the process of resizing the Embedding Layer to match the vocabulary due to incompatible tensor types. As a result, we proceeded with training in FP16 for now.
I added a '<|neweos|>' token to the vocabulary and appended it to the end of every sentence. Then, I resized the Embedding Layer and trained for 1000 steps with a batch size of 8. Ultimately, we were able to confirm that the model generates output with the '<|neweos|>' token attached at the end of the sentence. This demonstrates that applying Lora to the Embedding allows for the expansion of the vocabulary.
Test prompt: '역시 스파이더맨' (As expected, Spider-Man)
Please refer to the source code used below.
import torch
from datasets import Dataset, DatasetDict, load_dataset
from enum import Enum
from transformers import GPTJForCausalLM, AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
class SpecialTokens(str, Enum):
new_eos = "<|neweos|>"
@classmethod
def list(cls):
return [c.value for c in cls]
tokenizer = AutoTokenizer.from_pretrained(
'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b-float16', # or float32 version: revision=KoGPT6B-ryan1.5b
bos_token='[BOS]', eos_token='[EOS]', unk_token='[UNK]', pad_token='[EOS]', mask_token='[MASK]', additional_special_tokens=SpecialTokens.list(), use_fast=False
)
model = GPTJForCausalLM.from_pretrained(
'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b-float16', # or float32 version: revision=KoGPT6B-ryan1.5b
pad_token_id=tokenizer.eos_token_id,
use_cache=False,
#device_map={'':rank},
torch_dtype=torch.float16,
#load_in_8bit=True,
#num_labels=2
)
model.resize_token_embeddings(len(tokenizer))
train_cs = load_dataset("nsmc", split="train")
dataset = DatasetDict({
'train': train_cs,
})
def tokenize_function(examples):
s = [i + ' <|neweos|>' for i in examples['document']]
a = tokenizer(s, truncation = True, max_length = 128, padding = 'max_length')
return dict({"input_ids": a['input_ids'], "labels": a['input_ids']})
tokenized_datasets = dataset.map(tokenize_function, remove_columns=["id", "document", "label"],batched=True)
print(tokenized_datasets)
print(tokenized_datasets['train'][0])
print(tokenizer.decode(tokenized_datasets['train'][0]['input_ids']))
"""
DatasetDict({
train: Dataset({
features: ['input_ids', 'labels'],
num_rows: 150000
})
})
{'input_ids': [508, 37056, 1939, 3328, 11006, 475, 2538, 3399, 327, 63998, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [508, 37056, 1939, 3328, 11006, 475, 2538, 3399, 327, 63998, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
아 더빙.. 진짜 짜증나네요 목소리 <|neweos|>[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS]
"""
from peft import LoraConfig, PeftModel, get_peft_model#, prepare_model_for_int8_training
#model = prepare_model_for_int8_training(model)
target_modules = [ 'q_proj', 'k_proj', 'v_proj', 'out_proj', 'fc_in', 'fc_out', 'wte']
config = LoraConfig(
r=4, lora_alpha=16, target_modules=target_modules, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM"
)
model = get_peft_model(model, config)
print(type(model))
print(model.model.transformer.wte)
print(model.print_trainable_parameters())
print(model.model.transformer.h[0].attn.q_proj)
"""
<class 'peft.peft_model.PeftModelForCausalLM'>
Embedding(
63999, 4096
(lora_dropout): ModuleDict(
(default): Dropout(p=0.1, inplace=False)
)
(lora_A): ModuleDict()
(lora_B): ModuleDict()
(lora_embedding_A): ParameterDict( (default): Parameter containing: [torch.FloatTensor of size 4x63999])
(lora_embedding_B): ParameterDict( (default): Parameter containing: [torch.FloatTensor of size 4096x4])
)
Linear(
in_features=4096, out_features=4096, bias=False
(lora_dropout): ModuleDict(
(default): Dropout(p=0.1, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=4096, out_features=4, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=4, out_features=4096, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
)
trainable params: 8529916 || all params: 6170829307 || trainable%: 0.13822965400005352
"""
from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(
output_dir="test_trainer",
fp16 = True,
warmup_ratio = 0.05,
learning_rate=5e-5,
max_steps=1000,
logging_dir='./logs',
logging_steps=100,
seed=42,
dataloader_drop_last=False,
dataloader_num_workers=2,
per_device_train_batch_size= 8,
per_device_eval_batch_size=32,
deepspeed = 'ds_config.json'
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets['train'],
tokenizer=tokenizer
)
trainer.train()
prompt= '역시 스파이더맨'
device = "cuda:0"
tokens = tokenizer.encode(prompt, return_tensors='pt').to(device=device, non_blocking=True)
gen_tokens = model.generate(input_ids = tokens,
do_sample=True,
temperature=0.8,
top_p=0.99,
max_new_tokens=64,
early_stopping=True,
)
generated = tokenizer.batch_decode(gen_tokens)[0]
print(generated)
"""
역시 스파이더맨의 매력을 더 느끼게 해준 영화.. <|neweos|>[EOS]
""" |
Thank you for your thorough review and valuable feedback @pacman100 ! |
I think a possible solution would be resizing the model embedding layers, saving them, and then loading in 8-bit as normal.
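A minimal sketch of that suggestion (paths are placeholders, and the tokenizer with the added '<|neweos|>' token is assumed to be built as in the earlier snippet): resize the embeddings in half precision, save the resized checkpoint, then reload it with load_in_8bit=True before applying prepare_model_for_int8_training and the LoRA config. The full script in the next comment shows the int8 setup that was ultimately used.

import torch
from transformers import GPTJForCausalLM
# `tokenizer` is assumed to be the one built earlier, with additional_special_tokens=['<|neweos|>'].
# Resize in fp16, save, then reload the resized checkpoint in 8-bit.
model = GPTJForCausalLM.from_pretrained(
    'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b-float16',
    torch_dtype=torch.float16,
)
model.resize_token_embeddings(len(tokenizer))   # vocab now includes '<|neweos|>'
model.save_pretrained("kogpt-resized")          # placeholder path
model_8bit = GPTJForCausalLM.from_pretrained(
    "kogpt-resized", load_in_8bit=True, device_map="auto",
)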
I'd like to thank you for your input, @flozi00.
Execution code: torchrun --standalone --nnodes=1 --nproc_per_node=4 app.py
import torch
from datasets import Dataset, DatasetDict, load_dataset
from enum import Enum
from transformers import GPTJForCausalLM, AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
import torch.distributed as dist
dist.init_process_group("nccl")
rank = dist.get_rank()
class SpecialTokens(str, Enum):
new_eos = "<|neweos|>"
#end_target = "<|endtarget|>"
@classmethod
def list(cls):
return [c.value for c in cls]
tokenizer = AutoTokenizer.from_pretrained(
'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b-float16', # or float32 version: revision=KoGPT6B-ryan1.5b
bos_token='[BOS]', eos_token='[EOS]', unk_token='[UNK]', pad_token='[EOS]', mask_token='[MASK]', additional_special_tokens=SpecialTokens.list(), use_fast=False
)
model = GPTJForCausalLM.from_pretrained(
'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b-float16', # or float32 version: revision=KoGPT6B-ryan1.5b
pad_token_id=tokenizer.eos_token_id,
use_cache=False,
device_map={'':rank},
torch_dtype=torch.float16,
load_in_8bit=True,
)
model.resize_token_embeddings(len(tokenizer))
train_cs = load_dataset("nsmc", split="train")
dataset = DatasetDict({
'train': train_cs,
})
def tokenize_function(examples):
s = [i + ' <|neweos|>' for i in examples['document']]
a = tokenizer(s, truncation = True, max_length = 128, padding = 'max_length')
return dict({"input_ids": a['input_ids'], "labels": a['input_ids']})
tokenized_datasets = dataset.map(tokenize_function, remove_columns=["id", "document", "label"],batched=True)
print(tokenized_datasets)
print(tokenized_datasets['train'][0])
print(tokenizer.decode(tokenized_datasets['train'][0]['input_ids']))
"""
DatasetDict({
train: Dataset({
features: ['input_ids', 'labels'],
num_rows: 150000
})
})
{'input_ids': [508, 37056, 1939, 3328, 11006, 475, 2538, 3399, 327, 63998, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [508, 37056, 1939, 3328, 11006, 475, 2538, 3399, 327, 63998, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
아 더빙.. 진짜 짜증나네요 목소리 <|neweos|>[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS]
"""
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_int8_training
model = prepare_model_for_int8_training(model)
target_modules = [ 'q_proj', 'k_proj', 'v_proj', 'out_proj', 'fc_in', 'fc_out', 'wte']
config = LoraConfig(
r=4, lora_alpha=16, target_modules=target_modules, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM"
)
model = get_peft_model(model, config)
model.gradient_checkpointing_enable()
print(type(model))
print(model.model.transformer.wte)
print(model.model.transformer.h[0].attn.q_proj)
print(model.print_trainable_parameters())
"""
<class 'peft.peft_model.PeftModelForCausalLM'>
Embedding(
63999, 4096
(lora_dropout): ModuleDict(
(default): Dropout(p=0.1, inplace=False)
)
(lora_A): ModuleDict()
(lora_B): ModuleDict()
(lora_embedding_A): ParameterDict( (default): Parameter containing: [torch.cuda.FloatTensor of size 4x63999 (GPU 2)])
(lora_embedding_B): ParameterDict( (default): Parameter containing: [torch.cuda.FloatTensor of size 4096x4 (GPU 2)])
)
Linear8bitLt(
in_features=4096, out_features=4096, bias=False
(lora_dropout): ModuleDict(
(default): Dropout(p=0.1, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=4096, out_features=4, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=4, out_features=4096, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
)
trainable params: 8529916 || all params: 6170829307 || trainable%: 0.13822965400005352
"""
from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(
output_dir="test_trainer",
fp16 = True,
warmup_ratio = 0.05,
learning_rate=5e-5,
max_steps=1000,
logging_dir='./logs',
logging_steps=100,
seed=42,
dataloader_drop_last=False,
dataloader_num_workers=2,
ddp_find_unused_parameters = False,
per_device_train_batch_size= 8,
per_device_eval_batch_size=32,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets['train'],
)
trainer.train()
prompt= '역시 스파이더맨'
device = "cuda:0"
tokens = tokenizer.encode(prompt, return_tensors='pt').to(device=device, non_blocking=True)
gen_tokens = model.generate(input_ids = tokens,
do_sample=True,
temperature=0.8,
top_p=0.99,
max_new_tokens=64,
early_stopping=True,
)
generated = tokenizer.batch_decode(gen_tokens)[0]
print(generated)
"""
역시 스파이더맨 너무 좋아. <|neweos|>[EOS]
""" |
@Splo2t Thank you for sharing the experimental results. I would like to ask a question: if LoRA is used for embedding training, should the learning rate be adjusted accordingly?
Thank you for your question @Sanster. When training NSMC using AutoModelForSequenceClassification, we observed that adjusting the learning rate can lead to a difference of up to 3% in accuracy. Although not mentioned in the experimental results above, there were cases where the NSMC accuracy exceeded 91% when the learning rate was adjusted while using LoRA for embedding training. For generative models using AutoModelForCausalLM, it is difficult to make quantitative comparisons, but there were no significant issues when using the same learning rate.
Sir? If I add new vocab to the tokenizer, and I add wte to target_modules, will the embedding layer also be trained in PEFT? This is my code for phi-1.5:
The embeddings themselves will stay frozen, but the LoRA weights for the embeddings will be updated. If you want to train the embeddings fully, add them to modules_to_save.
Thanks sir, so I have to set modules_to_save in the training steps to train the full embeddings?
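For completeness, a sketch of the setup being discussed. This is an illustration under assumptions, not code verified against phi-1.5: the embedding and attention module names below are placeholders, so confirm the real names with model.named_modules().

from peft import LoraConfig, get_peft_model
# Keep LoRA on the attention projections, but train and save the (resized)
# embedding layer in full via modules_to_save. Module names are placeholders.
config = LoraConfig(
    r=8, lora_alpha=16,
    target_modules=['q_proj', 'k_proj', 'v_proj'],
    modules_to_save=['wte'],   # or 'embed_tokens', depending on the model implementation
    lora_dropout=0.1, bias="none", task_type="CAUSAL_LM",
)
# model = get_peft_model(model, config)  # embeddings now train fully and are saved with the adapter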
What does this PR do?
Pull Request Description
This pull request includes the following changes:
Added nn.Embedding support to the LoRA fine-tuning layers: with this update, LoRA layers can now be applied to nn.Embedding modules, allowing for better performance and efficiency during the fine-tuning process.
GPT-J fine-tuning example with the nn.Embedding layer and 8-bit quantization using PEFT in lora.py: to demonstrate the usage of GPT-J with LoRA, we've added an example to the lora.py file. This example showcases fine-tuning GPT-J using the nn.Embedding layer and 8-bit quantization with the help of PEFT, providing users with an efficient method for optimizing their GPT-J models.