Add BLIP2 Example #260

Merged · merged 10 commits on Apr 6, 2023
6 changes: 6 additions & 0 deletions README.md
@@ -274,6 +274,12 @@ An example is provided in `~examples/causal_language_modeling/peft_lora_clm_acce
| ViT | ✅ | | | |
| Swin | ✅ | | | |

### Image to text (Multi-modal models)

| Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning |
| --------- | ---- | ---- | ---- | ---- |
| Blip-2 | ✅ | | | |

___Note that we have tested LoRA for [ViT](https://huggingface.co/docs/transformers/model_doc/vit) and [Swin](https://huggingface.co/docs/transformers/model_doc/swin) for fine-tuning on image classification. However, it should be possible to use LoRA for any compatible model [provided](https://huggingface.co/models?pipeline_tag=image-classification&sort=downloads&search=vit) by 🤗 Transformers. Check out the respective
examples to learn more. If you run into problems, please open an issue.___
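
To illustrate the note, here is a minimal sketch of applying LoRA to a ViT image classifier; the checkpoint name, `target_modules` (the `query`/`value` projections in ViT self-attention), and the `modules_to_save` choice are illustrative assumptions rather than part of this diff:

```python
from transformers import AutoModelForImageClassification

from peft import LoraConfig, get_peft_model

# Illustrative base checkpoint; any compatible image-classification model should work.
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k")

config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
    target_modules=["query", "value"],  # ViT self-attention projections
    modules_to_save=["classifier"],     # keep the classification head trainable
)

model = get_peft_model(model, config)
model.print_trainable_parameters()
```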

104 changes: 104 additions & 0 deletions examples/int8_training/fine_tune_blip2_int8.py
@@ -0,0 +1,104 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForVision2Seq, AutoProcessor

from peft import LoraConfig, get_peft_model


# Let's define the LoraConfig
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="VISION_2_SEQ",
Reviewer comment (Contributor): Keeping this unspecified will automatically use the LoRA model via the `PeftModel` object, as a task-specific class isn't a requirement for LoRA.

)
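
Following the review comment above, a minimal sketch of the suggested alternative: leave `task_type` unset so that `get_peft_model` falls back to the generic `PeftModel` wrapper, which is all LoRA needs. The explicit `target_modules` are an illustrative choice for the OPT language model, and `model` refers to the BLIP-2 model loaded in the script.

```python
from peft import LoraConfig, get_peft_model

# Illustrative alternative: no task_type, so get_peft_model returns the
# generic PeftModel wrapper instead of a task-specific subclass.
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "v_proj"],  # assumption: OPT attention projections
)

model = get_peft_model(model, config)  # plain PeftModel; LoRA works without a task type
model.print_trainable_parameters()
```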

# We load our model and processor using `transformers`
model = AutoModelForVision2Seq.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0})
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")

# Get our peft model and print the number of trainable parameters
model = get_peft_model(model, config)
model.print_trainable_parameters()

# Let's load the dataset here!
dataset = load_dataset("ybelkada/football-dataset", split="train")


class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(images=item["image"], padding="max_length", return_tensors="pt")
        # remove batch dimension
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        encoding["text"] = item["text"]
        return encoding


def collator(batch):
    # pad the input_ids and attention_mask
    processed_batch = {}
    for key in batch[0].keys():
        if key != "text":
            processed_batch[key] = torch.stack([example[key] for example in batch])
        else:
            text_inputs = processor.tokenizer(
                [example["text"] for example in batch], padding=True, return_tensors="pt"
            )
            processed_batch["input_ids"] = text_inputs["input_ids"]
            processed_batch["attention_mask"] = text_inputs["attention_mask"]
    return processed_batch


train_dataset = ImageCaptioningDataset(dataset, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2, collate_fn=collator)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"

model.train()

for epoch in range(50):
    print("Epoch:", epoch)
    for idx, batch in enumerate(train_dataloader):
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device, torch.float16)

        outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=input_ids)

        loss = outputs.loss

        print("Loss:", loss.item())

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        if idx % 10 == 0:
            generated_output = model.generate(pixel_values=pixel_values)
            print(processor.batch_decode(generated_output, skip_special_tokens=True))
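
As a follow-up usage sketch, saving the LoRA adapter produced by the script above and reloading it for captioning; the output directory name and test image path are illustrative assumptions, and `model`/`processor` refer to the objects defined in the script.

```python
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

from peft import PeftModel

# Save only the small LoRA adapter weights (the 8-bit base model stays untouched).
model.save_pretrained("blip2-opt-2.7b-football-lora")  # illustrative directory name

# Later: reload the frozen base model and attach the trained adapter.
base_model = AutoModelForVision2Seq.from_pretrained(
    "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}
)
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
inference_model = PeftModel.from_pretrained(base_model, "blip2-opt-2.7b-football-lora")
inference_model.eval()

image = Image.open("example.jpg")  # illustrative test image
inputs = processor(images=image, return_tensors="pt").to(0, torch.float16)
generated_ids = inference_model.generate(**inputs)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip())
```
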
8 changes: 7 additions & 1 deletion src/peft/mapping.py
@@ -19,6 +19,7 @@
    PeftModelForSeq2SeqLM,
    PeftModelForSequenceClassification,
    PeftModelForTokenClassification,
    PeftModelForVision2Seq,
)
from .tuners import LoraConfig, PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig
from .utils import PromptLearningConfig
@@ -29,6 +30,7 @@
"SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
"CAUSAL_LM": PeftModelForCausalLM,
"TOKEN_CLS": PeftModelForTokenClassification,
"VISION_2_SEQ": PeftModelForVision2Seq,
}

PEFT_TYPE_TO_CONFIG_MAPPING = {
@@ -44,6 +46,7 @@
"bart": ["q_proj", "v_proj"],
"gpt2": ["c_attn"],
"bloom": ["query_key_value"],
"blip-2": ["q", "v", "q_proj", "v_proj"],
"opt": ["q_proj", "v_proj"],
"gptj": ["q_proj", "v_proj"],
"gpt_neox": ["query_key_value"],
@@ -134,9 +137,12 @@ def get_peft_model(model, peft_config):
        model ([`transformers.PreTrainedModel`]): Model to be wrapped.
        peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.
    """

    model_config = model.config.to_dict()
    peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)

    if peft_config.task_type == "VISION_2_SEQ" and not isinstance(peft_config, LoraConfig):
        raise ValueError("Vision2Seq task type is only supported with LORA")

Reviewer comment (Contributor): This isn't required if the task type is left unspecified. For unspecified task types, lines 146-148 already fall back to LoRA via `PeftModel`, as a task-specific sub-class isn't required for the LoRA method.

    if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
        peft_config = _prepare_lora_config(peft_config, model_config)
        return PeftModel(model, peft_config)
69 changes: 69 additions & 0 deletions src/peft/peft_model.py
@@ -1034,3 +1034,72 @@ def _prefix_tuning_forward(
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )


class PeftModelForVision2Seq(PeftModel):
    """
    Peft model for vision to text models.

    Args:
        model ([`~transformers.PreTrainedModel`]): Base transformer model.
        peft_config ([`PeftConfig`]): Peft config.


    Example:

        ```py
        >>> from transformers import AutoModelForVision2Seq
        >>> from peft import PeftModelForVision2Seq, get_peft_config

        >>> config = {
        ...     "peft_type": "LORA",
        ...     "task_type": "VISION_2_SEQ",
        ...     "inference_mode": False,
        ...     "r": 8,
        ...     "target_modules": ["q", "v"],
        ...     "lora_alpha": 32,
        ...     "lora_dropout": 0.1,
        ...     "merge_weights": False,
        ...     "fan_in_fan_out": False,
        ...     "enable_lora": None,
        ...     "bias": "none",
        ... }

        >>> peft_config = get_peft_config(config)
        >>> model = AutoModelForVision2Seq.from_pretrained("Salesforce/blip2-flan-t5-xl")
        >>> peft_model = PeftModelForVision2Seq(model, peft_config)
        >>> peft_model.print_trainable_parameters()
        trainable params: 1843200 || all params: 775873280 || trainable%: 0.23756456724479544
        ```
    """

    def __init__(self, model, peft_config: PeftConfig):
        super().__init__(model, peft_config)
        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation

    def forward(
        self,
        pixel_values=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        r"""
        A simple wrapper around the base model's forward method.
        """
        return self.base_model(
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
Reviewer comment (Contributor): Following the previous comment, this class isn't required if we aren't supporting methods apart from LoRA.