Merge pull request #260 from younesbelkada/add-pix2struct
Add BLIP2 Example
pacman100 authored Apr 6, 2023
2 parents a7d5e51 + 7ed9ad0 commit 382b178
Showing 3 changed files with 111 additions and 1 deletion.
6 changes: 6 additions & 0 deletions README.md
@@ -274,6 +274,12 @@ An example is provided in `~examples/causal_language_modeling/peft_lora_clm_acce
| ViT | ✅ | | | |
| Swin | ✅ | | | |

### Image to text (Multi-modal models)

| Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning |
| --------- | ---- | ---- | ---- | ---- |
| Blip-2 | ✅ | | | |

___Note that we have tested LoRA for [ViT](https://huggingface.co/docs/transformers/model_doc/vit) and [Swin](https://huggingface.co/docs/transformers/model_doc/swin) for fine-tuning on image classification. However, it should be possible to use LoRA for any compatible model [provided](https://huggingface.co/models?pipeline_tag=image-classification&sort=downloads&search=vit) by 🤗 Transformers. Check out the respective
examples to learn more. If you run into problems, please open an issue.___
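
For illustration, a minimal sketch of applying LoRA to a ViT image-classification checkpoint might look like the following (the checkpoint name, LoRA hyperparameters, and `target_modules` are illustrative assumptions, not taken from this PR):

from transformers import AutoModelForImageClassification
from peft import LoraConfig, get_peft_model

# Checkpoint name and hyperparameters are illustrative; ViT attention blocks
# expose "query" and "value" projections, which is what LoRA targets here.
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k")
config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["classifier"],  # keep the new classification head trainable
)
model = get_peft_model(model, config)
model.print_trainable_parameters()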

103 changes: 103 additions & 0 deletions examples/int8_training/fine_tune_blip2_int8.py
@@ -0,0 +1,103 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForVision2Seq, AutoProcessor

from peft import LoraConfig, get_peft_model


# Let's define the LoraConfig
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# We load our model and processor using `transformers`
model = AutoModelForVision2Seq.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0})
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")

# Get our peft model and print the number of trainable parameters
model = get_peft_model(model, config)
model.print_trainable_parameters()

# Let's load the dataset here!
dataset = load_dataset("ybelkada/football-dataset", split="train")


class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(images=item["image"], padding="max_length", return_tensors="pt")
        # remove batch dimension
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        encoding["text"] = item["text"]
        return encoding


def collator(batch):
    # pad the input_ids and attention_mask
    processed_batch = {}
    for key in batch[0].keys():
        if key != "text":
            processed_batch[key] = torch.stack([example[key] for example in batch])
        else:
            text_inputs = processor.tokenizer(
                [example["text"] for example in batch], padding=True, return_tensors="pt"
            )
            processed_batch["input_ids"] = text_inputs["input_ids"]
            processed_batch["attention_mask"] = text_inputs["attention_mask"]
    return processed_batch


train_dataset = ImageCaptioningDataset(dataset, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2, collate_fn=collator)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"

model.train()

for epoch in range(50):
    print("Epoch:", epoch)
    for idx, batch in enumerate(train_dataloader):
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device, torch.float16)

        outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=input_ids)

        loss = outputs.loss

        print("Loss:", loss.item())

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        if idx % 10 == 0:
            generated_output = model.generate(pixel_values=pixel_values)
            print(processor.batch_decode(generated_output, skip_special_tokens=True))
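
The training loop above never persists the adapter; as a follow-up sketch (not part of this file), the trained LoRA weights could be saved and reloaded roughly as follows, with the output directory name being an illustrative placeholder:

# Save only the LoRA adapter weights; the 8-bit base model is not duplicated.
model.save_pretrained("blip2-lora-adapter")  # illustrative output path

# To reuse the adapter later, reload the base model and attach the saved weights.
from peft import PeftModel

base_model = AutoModelForVision2Seq.from_pretrained(
    "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}
)
model = PeftModel.from_pretrained(base_model, "blip2-lora-adapter")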
3 changes: 2 additions & 1 deletion src/peft/mapping.py
@@ -45,6 +45,7 @@
"bart": ["q_proj", "v_proj"],
"gpt2": ["c_attn"],
"bloom": ["query_key_value"],
"blip-2": ["q", "v", "q_proj", "v_proj"],
"opt": ["q_proj", "v_proj"],
"gptj": ["q_proj", "v_proj"],
"gpt_neox": ["query_key_value"],
@@ -164,9 +165,9 @@ def get_peft_model(model, peft_config):
        model ([`transformers.PreTrainedModel`]): Model to be wrapped.
        peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.
    """

    model_config = model.config.to_dict()
    peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)

    if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
        peft_config = _prepare_lora_config(peft_config, model_config)
        return PeftModel(model, peft_config)
