Add BLIP2 Example #260
@@ -0,0 +1,104 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForVision2Seq, AutoProcessor

from peft import LoraConfig, get_peft_model


# Let's define the LoraConfig
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="VISION_2_SEQ",
)

# We load our model and processor using `transformers`
model = AutoModelForVision2Seq.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0})
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")

# Get our peft model and print the number of trainable parameters
model = get_peft_model(model, config)
model.print_trainable_parameters()

# Let's load the dataset here!
dataset = load_dataset("ybelkada/football-dataset", split="train")


class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(images=item["image"], padding="max_length", return_tensors="pt")
        # remove batch dimension
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        encoding["text"] = item["text"]
        return encoding


def collator(batch):
    # pad the input_ids and attention_mask
    processed_batch = {}
    for key in batch[0].keys():
        if key != "text":
            processed_batch[key] = torch.stack([example[key] for example in batch])
        else:
            text_inputs = processor.tokenizer(
                [example["text"] for example in batch], padding=True, return_tensors="pt"
            )
            processed_batch["input_ids"] = text_inputs["input_ids"]
            processed_batch["attention_mask"] = text_inputs["attention_mask"]
    return processed_batch


train_dataset = ImageCaptioningDataset(dataset, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2, collate_fn=collator)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"

model.train()

for epoch in range(50):
    print("Epoch:", epoch)
    for idx, batch in enumerate(train_dataloader):
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device, torch.float16)

        outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=input_ids)

        loss = outputs.loss

        print("Loss:", loss.item())

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        if idx % 10 == 0:
            generated_output = model.generate(pixel_values=pixel_values)
            print(processor.batch_decode(generated_output, skip_special_tokens=True))
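A natural follow-up to the example script above (not part of this PR) is persisting and reloading the trained adapter. The sketch below assumes the training loop has finished; `save_pretrained` and `PeftModel.from_pretrained` are the standard PEFT entry points for this, and `"blip2-football-lora"` is a hypothetical output directory name.

```python
# Save only the LoRA adapter weights rather than the full BLIP-2 checkpoint.
# "blip2-football-lora" is a hypothetical output directory.
model.save_pretrained("blip2-football-lora")

# Later: rebuild the frozen base model and attach the saved adapter for inference.
from transformers import AutoModelForVision2Seq
from peft import PeftModel

base_model = AutoModelForVision2Seq.from_pretrained(
    "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}
)
inference_model = PeftModel.from_pretrained(base_model, "blip2-football-lora")
inference_model.eval()
```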
@@ -19,6 +19,7 @@
    PeftModelForSeq2SeqLM,
    PeftModelForSequenceClassification,
    PeftModelForTokenClassification,
    PeftModelForVision2Seq,
)
from .tuners import LoraConfig, PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig
from .utils import PromptLearningConfig

@@ -29,6 +30,7 @@
    "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
    "CAUSAL_LM": PeftModelForCausalLM,
    "TOKEN_CLS": PeftModelForTokenClassification,
    "VISION_2_SEQ": PeftModelForVision2Seq,
}

PEFT_TYPE_TO_CONFIG_MAPPING = {

@@ -44,6 +46,7 @@
    "bart": ["q_proj", "v_proj"],
    "gpt2": ["c_attn"],
    "bloom": ["query_key_value"],
    "blip-2": ["q", "v", "q_proj", "v_proj"],
    "opt": ["q_proj", "v_proj"],
    "gptj": ["q_proj", "v_proj"],
    "gpt_neox": ["query_key_value"],
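For reference, the new `"blip-2"` entry only supplies defaults. A `LoraConfig` can also name the target modules explicitly, which takes precedence over this mapping; the sketch below simply mirrors the entry added in this diff (the module names are taken from that entry, not from a separately tested configuration).

```python
from peft import LoraConfig

# Explicitly target the Q-Former ("q", "v") and language-model ("q_proj", "v_proj")
# attention projections, mirroring the "blip-2" mapping entry added above.
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q", "v", "q_proj", "v_proj"],
)
```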
@@ -134,9 +137,12 @@ def get_peft_model(model, peft_config):
        model ([`transformers.PreTrainedModel`]): Model to be wrapped.
        peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.
    """

    model_config = model.config.to_dict()
    peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)

    if peft_config.task_type == "VISION_2_SEQ" and not isinstance(peft_config, LoraConfig):
        raise ValueError("Vision2Seq task type is only supported with LORA")

Review comment: This isn't required if the task type is left unspecified. For unspecified tasks, lines 146-148 already use LoRA via:

    if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
        peft_config = _prepare_lora_config(peft_config, model_config)
        return PeftModel(model, peft_config)
@@ -1034,3 +1034,72 @@ def _prefix_tuning_forward(
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class PeftModelForVision2Seq(PeftModel):
    """
    Peft model for vision to text models.

    Args:
        model ([`~transformers.PreTrainedModel`]): Base transformer model.
        peft_config ([`PeftConfig`]): Peft config.

    Example:

        ```py
        >>> from transformers import AutoModelForVision2Seq
        >>> from peft import PeftModelForVision2Seq, get_peft_config

        >>> config = {
        ...     "peft_type": "LORA",
        ...     "task_type": "VISION_2_SEQ",
        ...     "inference_mode": False,
        ...     "r": 8,
        ...     "target_modules": ["q", "v"],
        ...     "lora_alpha": 32,
        ...     "lora_dropout": 0.1,
        ...     "merge_weights": False,
        ...     "fan_in_fan_out": False,
        ...     "enable_lora": None,
        ...     "bias": "none",
        ... }

        >>> peft_config = get_peft_config(config)
        >>> model = AutoModelForVision2Seq.from_pretrained("Salesforce/blip2-flan-t5-xl")
        >>> peft_model = PeftModelForVision2Seq(model, peft_config)
        >>> peft_model.print_trainable_parameters()
        trainable params: 1843200 || all params: 775873280 || trainable%: 0.23756456724479544
        ```
    """

    def __init__(self, model, peft_config: PeftConfig):
        super().__init__(model, peft_config)
        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation

    def forward(
        self,
        pixel_values=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        r"""
        A simple wrapper around the base model's forward method.
        """
        return self.base_model(
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
Review comment: Following the previous comment, this isn't required if we aren't supporting methods apart from LoRA.

Review comment: Keeping the task type unspecified will automatically use the LoRA model via the PeftModel object, as a task-specific class isn't a requirement for LoRA.
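As a rough illustration of the reviewers' suggestion (a sketch, not code from this PR): if `task_type` is left unset, the config never matches `MODEL_TYPE_TO_PEFT_MODEL_MAPPING`, so `get_peft_model` takes the generic `PeftModel` fallback path with LoRA layers injected, picking up the `"blip-2"` target modules from the mapping above.

```python
from transformers import AutoModelForVision2Seq
from peft import LoraConfig, get_peft_model

# No task_type: get_peft_model falls back to the plain PeftModel wrapper,
# which is all that LoRA fine-tuning of BLIP-2 needs.
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

model = AutoModelForVision2Seq.from_pretrained(
    "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}
)
model = get_peft_model(model, config)
model.print_trainable_parameters()
```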