
Add PEFT training and explicit kwarg passthrough #3480

Merged · 27 commits merged into flairNLP:master from the qlora branch on Jul 15, 2024

Conversation

@janpf (Contributor) commented on Jul 1, 2024

This PR adds the ability to train models using PEFT (LoRA and QLoRA), along with cleaner handling of explicit kwargs for the model and the config. For example, passing kwargs through to the model but not to the config was not possible before.

If PEFT is not installed and not used, no error is thrown either.
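For illustration, a minimal sketch of what the separated passthrough might look like (the transformers_config_kwargs name is an assumption in this sketch):

from flair.embeddings import TransformerDocumentEmbeddings

# kwargs meant only for AutoModel.from_pretrained(), e.g. device placement,
# go through transformers_model_kwargs; kwargs meant for the config object
# would go through a separate argument (transformers_config_kwargs is an
# assumed name in this sketch), so the two no longer need to be mixed.
embeddings = TransformerDocumentEmbeddings(
    "bert-base-uncased",
    fine_tune=True,
    transformers_model_kwargs={"device_map": "auto"},
    transformers_config_kwargs={"hidden_dropout_prob": 0.2},
)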

@alanakbik (Collaborator) commented

Hello @janpf, could you provide a small test script showing how to train a model (for instance for NER) using PEFT? That would make it easier to test.

@janpf (Contributor, Author) commented on Jul 2, 2024

Will do, hopefully this week :)

@janpf (Contributor, Author) commented on Jul 3, 2024

OK, I've put together a minimal example, adapted from this tutorial: https://flairnlp.github.io/docs/tutorial-training/how-to-train-text-classifier

requirements.txt:

git+https://github.com/flairNLP/flair.git@refs/pull/3480/merge
bitsandbytes
peft
scipy==1.10.1

The code then looks like this:

from flair.data import Corpus
from flair.datasets import TREC_6
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

corpus: Corpus = TREC_6()
label_type = "question_class"
label_dict = corpus.make_label_dictionary(label_type=label_type)

# this is new
from peft import LoraConfig, TaskType
import torch
import bitsandbytes as bnb

# set the quantization config (bitsandbytes)
bnb_config = {
    "device_map": "auto",
    "load_in_8bit": True,
}
# set lora config (peft)
peft_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    inference_mode=False,
)
document_embeddings = TransformerDocumentEmbeddings(
    "uklfr/gottbert-base",
    fine_tune=True,
    # pass both configs using the newly introduced kwargs
    transformers_model_kwargs=bnb_config,
    peft_config=peft_config,
)

classifier = TextClassifier(
    document_embeddings, label_dictionary=label_dict, label_type=label_type
)
trainer = ModelTrainer(classifier, corpus)
trainer.fine_tune(
    "resources/taggers/question-classification-with-transformer",
    learning_rate=5.0e-5,
    mini_batch_size=4,
    # explicitly swapping out the optimizer for the bitsandbytes variant is recommended
    optimizer=bnb.optim.adamw.AdamW,
    max_epochs=1,
)

The resulting model is quite bad, but all QLoRA hyperparameters have been kept at their default values.
The logs also show that the model has been correctly quantised (lora.Linear8bitLt) and that the adapters have been inserted (lora_A & lora_B):

2024-07-03 16:45:06,535 Model: "TextClassifier(
  (embeddings): TransformerDocumentEmbeddings(
    (model): PeftModelForFeatureExtraction(
      (base_model): LoraModel(
        (model): RobertaModel(
          (embeddings): RobertaEmbeddings(
            (word_embeddings): Embedding(52010, 768)
            (position_embeddings): Embedding(514, 768, padding_idx=1)
            (token_type_embeddings): Embedding(1, 768)
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (encoder): RobertaEncoder(
            (layer): ModuleList(
              (0-11): 12 x RobertaLayer(
                (attention): RobertaAttention(
                  (self): RobertaSelfAttention(
                    (query): lora.Linear8bitLt(
                      (base_layer): Linear8bitLt(in_features=768, out_features=768, bias=True)
                      (lora_dropout): ModuleDict(
                        (default): Identity()
                      )
                      (lora_A): ModuleDict(
                        (default): Linear(in_features=768, out_features=8, bias=False)
                      )
                      (lora_B): ModuleDict(
                        (default): Linear(in_features=8, out_features=768, bias=False)
                      )
                      (lora_embedding_A): ParameterDict()
                      (lora_embedding_B): ParameterDict()
                    )
                    (key): Linear8bitLt(in_features=768, out_features=768, bias=True)
                    (value): lora.Linear8bitLt(
                      (base_layer): Linear8bitLt(in_features=768, out_features=768, bias=True)
                      (lora_dropout): ModuleDict(
                        (default): Identity()
                      )
                      (lora_A): ModuleDict(
                        (default): Linear(in_features=768, out_features=8, bias=False)
                      )
                      (lora_B): ModuleDict(
                        (default): Linear(in_features=8, out_features=768, bias=False)
                      )
                      (lora_embedding_A): ParameterDict()
                      (lora_embedding_B): ParameterDict()
                    )
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (output): RobertaSelfOutput(
                    (dense): Linear8bitLt(in_features=768, out_features=768, bias=True)
                    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
                (intermediate): RobertaIntermediate(
                  (dense): Linear8bitLt(in_features=768, out_features=3072, bias=True)
                  (intermediate_act_fn): GELUActivation()
                )
                (output): RobertaOutput(
                  (dense): Linear8bitLt(in_features=3072, out_features=768, bias=True)
                  (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              )
            )
          )
          (pooler): RobertaPooler(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (activation): Tanh()
          )
        )
      )
    )
  )
  (decoder): Linear(in_features=768, out_features=6, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (locked_dropout): LockedDropout(p=0.0)
  (word_dropout): WordDropout(p=0.0)
  (loss_function): CrossEntropyLoss()
  (weights): None
  (weight_tensor) None
)"

and also: 2024-07-03 16:45:06,517 trainable params: 294,912 || all params: 126,279,936 || trainable%: 0.2335
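For anyone who wants to sanity-check that summary, the same numbers can be recomputed with plain PyTorch; a minimal sketch that works on any nn.Module, e.g. the classifier built above:

import torch

def count_trainable(model: torch.nn.Module) -> None:
    # with LoRA, only the injected adapter matrices (lora_A / lora_B) should
    # still have requires_grad=True; the quantized base weights stay frozen
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable params: {trainable:,} || all params: {total:,} || trainable%: {100 * trainable / total:.4f}")

count_trainable(classifier)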

@alanakbik (Collaborator) commented

Hi @janpf, this looks good. I tested with a standard BERT model (for which quantization does not seem to be available), and I'm getting results competitive with full fine-tuning when setting a slightly higher learning rate for LoRA:

from peft import LoraConfig, TaskType

document_embeddings = TransformerDocumentEmbeddings(
    "bert-base-uncased",
    fine_tune=True,
    # set LoRA config
    peft_config=LoraConfig(
        task_type=TaskType.FEATURE_EXTRACTION,
        inference_mode=False,
    ),
)

classifier = TextClassifier(document_embeddings, label_dictionary=label_dict, label_type=label_type)
trainer = ModelTrainer(classifier, corpus)
trainer.fine_tune(
    "resources/taggers/question-classification-with-transformer",
    learning_rate=5.0e-4,
    mini_batch_size=4,
    max_epochs=1,
)

Unfortunately, I don't know what is causing the storage error. This is now affecting all PRs.
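Once training has run, the saved model can be tried out in the usual Flair way; a minimal sketch, assuming the default final-model.pt written to the output directory used above:

from flair.data import Sentence
from flair.models import TextClassifier

# load the fine-tuned classifier from the trainer's output directory
classifier = TextClassifier.load(
    "resources/taggers/question-classification-with-transformer/final-model.pt"
)

# classify a single question and print the predicted label
sentence = Sentence("Who wrote the novel Moby-Dick?")
classifier.predict(sentence)
print(sentence.labels)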

@alanakbik (Collaborator) commented

Thanks again for adding this @janpf! Since the tests now pass, we can merge!

@alanakbik merged commit b7cc211 into flairNLP:master on Jul 15, 2024 (1 check passed)
@janpf deleted the qlora branch on Jul 16, 2024