Add Blip2ForImageTextRetrieval
jpizarrom committed Sep 22, 2023
1 parent 5936c8c commit 0c20b66
Showing 9 changed files with 854 additions and 93 deletions.
7 changes: 6 additions & 1 deletion docs/source/en/model_doc/blip-2.md
@@ -87,4 +87,9 @@ If you're interested in submitting a resource to be included here, please feel f

[[autodoc]] Blip2ForConditionalGeneration
- forward
- generate
- generate

## Blip2ForImageTextRetrieval

[[autodoc]] Blip2ForImageTextRetrieval
- forward
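
The entry above only registers the new class with autodoc. As a rough illustration of what `Blip2ForImageTextRetrieval` is for, here is a minimal image-text matching sketch modelled on the calls made in the conversion script further down; the checkpoint id, image URL, and caption are placeholders and not part of this commit.

```python
import requests
import torch
from PIL import Image

from transformers import Blip2ForImageTextRetrieval, Blip2Processor

# Placeholder checkpoint id; point this at the output folder of the conversion script below.
checkpoint = "Salesforce/blip2-itm-vit-g"
processor = Blip2Processor.from_pretrained(checkpoint)
model = Blip2ForImageTextRetrieval.from_pretrained(checkpoint).eval()

# Any demo image and caption will do; these are purely illustrative.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
caption = "two cats sleeping on a couch"

inputs = processor(images=image, text=caption, return_tensors="pt")

with torch.no_grad():
    # ITM head: two logits per image-text pair (index 1 is the "match" class in the LAVIS convention)
    itm_logits = model(**inputs).itm_score
    itm_probs = torch.nn.functional.softmax(itm_logits, dim=1)
    # ITC head: contrastive similarity, mirroring the use_itm_head=False branch of the conversion script
    itc_score = model(**inputs, use_itm_head=False).itm_score

print("match probability:", itm_probs[:, 1].item())
print("itc score:", itc_score)
```

The two forward passes correspond to the `match_head="itm"` and `match_head="itc"` calls made against the original LAVIS model in the conversion script below.
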
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1334,6 +1334,7 @@
[
"BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Blip2ForConditionalGeneration",
"Blip2ForImageTextRetrieval",
"Blip2Model",
"Blip2PreTrainedModel",
"Blip2QFormerModel",
@@ -5346,6 +5347,7 @@
from .models.blip_2 import (
BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Blip2ForConditionalGeneration,
Blip2ForImageTextRetrieval,
Blip2Model,
Blip2PreTrainedModel,
Blip2QFormerModel,
2 changes: 2 additions & 0 deletions src/transformers/models/blip_2/__init__.py
@@ -38,6 +38,7 @@
"Blip2QFormerModel",
"Blip2PreTrainedModel",
"Blip2ForConditionalGeneration",
"Blip2ForImageTextRetrieval",
"Blip2VisionModel",
]

@@ -59,6 +60,7 @@
from .modeling_blip_2 import (
BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Blip2ForConditionalGeneration,
Blip2ForImageTextRetrieval,
Blip2Model,
Blip2PreTrainedModel,
Blip2QFormerModel,
13 changes: 12 additions & 1 deletion src/transformers/models/blip_2/configuration_blip_2.py
@@ -209,6 +209,7 @@ def __init__(
position_embedding_type="absolute",
cross_attention_frequency=2,
encoder_hidden_size=1408,
qformer_text_input=False,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -227,6 +228,7 @@
self.position_embedding_type = position_embedding_type
self.cross_attention_frequency = cross_attention_frequency
self.encoder_hidden_size = encoder_hidden_size
self.qformer_text_input = qformer_text_input

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
@@ -302,7 +304,15 @@ class Blip2Config(PretrainedConfig):

model_type = "blip-2"

def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
def __init__(
self,
vision_config=None,
qformer_config=None,
text_config=None,
num_query_tokens=32,
image_text_hidden_size=256,
**kwargs,
):
super().__init__(**kwargs)

if vision_config is None:
@@ -326,6 +336,7 @@ def __init__(self, vision_config=None, qformer_config=None, text_config=None, nu
self.is_encoder_decoder = self.text_config.is_encoder_decoder

self.num_query_tokens = num_query_tokens
self.image_text_hidden_size = image_text_hidden_size
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
self.initializer_factor = 1.0
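
The two new fields above — `qformer_text_input` on `Blip2QFormerConfig` and `image_text_hidden_size` on `Blip2Config` — are what the retrieval model builds on. Below is a minimal sketch of how they are wired together, mirroring the ITM branch of the conversion script; the concrete values are taken from that branch and are not mandatory.

```python
from transformers import Blip2Config, Blip2ForImageTextRetrieval, Blip2QFormerConfig

# Q-Former that also consumes tokenized text (the new qformer_text_input flag).
# vocab_size=30523 follows the conversion script: bert-base-uncased plus the added [DEC] token.
qformer_config = Blip2QFormerConfig(vocab_size=30523, qformer_text_input=True)

config = Blip2Config(
    qformer_config=qformer_config.to_dict(),
    text_config={},               # empty dict, as in the ITM branch of the conversion script
    image_text_hidden_size=256,   # new field: size of the shared image/text projection space
)

model = Blip2ForImageTextRetrieval(config)  # randomly initialized, just to show the wiring
```
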
187 changes: 134 additions & 53 deletions src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py
@@ -31,9 +31,12 @@

from transformers import (
AutoTokenizer,
BertTokenizer,
Blip2Config,
Blip2ForConditionalGeneration,
Blip2ForImageTextRetrieval,
Blip2Processor,
Blip2QFormerConfig,
Blip2VisionConfig,
BlipImageProcessor,
OPTConfig,
@@ -51,7 +54,7 @@ def load_demo_image():


# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config):
def create_rename_keys(config, model_name):
rename_keys = []
# fmt: off

@@ -77,8 +80,9 @@ def create_rename_keys(config):
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))

# QFormer
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight"))
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias"))
if "itm" not in model_name:
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight"))
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias"))

# fmt: on
return rename_keys
@@ -114,8 +118,19 @@ def get_blip2_config(model_name, eos_token_id):
text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict()
elif "t5-xxl" in model_name:
text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict()

config = Blip2Config(vision_config=vision_config, text_config=text_config)
elif "itm" in model_name:
text_config = {}
else:
raise ValueError("Model name not supported")

if "itm" in model_name:
config = Blip2Config(
vision_config=vision_config,
text_config=text_config,
qformer_config=Blip2QFormerConfig(vocab_size=30523, qformer_text_input=True).to_dict(),
)
else:
config = Blip2Config(vision_config=vision_config, text_config=text_config)

return config, image_size

@@ -125,15 +140,24 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_
"""
Copy/paste/tweak model's weights to Transformers design.
"""
tokenizer = (
AutoTokenizer.from_pretrained("facebook/opt-2.7b")
if "opt" in model_name
else AutoTokenizer.from_pretrained("google/flan-t5-xl")
)
eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0]
if "opt" in model_name:
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
elif "itm" in model_name:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
else:
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")

if "itm" in model_name:
eos_token_id = None
else:
eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0]
config, image_size = get_blip2_config(model_name, eos_token_id=eos_token_id)

hf_model = Blip2ForConditionalGeneration(config).eval()
if "itm" in model_name:
hf_model = Blip2ForImageTextRetrieval(config).eval()
else:
hf_model = Blip2ForConditionalGeneration(config).eval()

model_name_to_original = {
"blip2-opt-2.7b": ("blip2_opt", "pretrain_opt2.7b"),
@@ -143,6 +167,9 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_
"blip2-flan-t5-xl": ("blip2_t5", "pretrain_flant5xl"),
"blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"),
"blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"),
"blip2-itm-vit-g": ("blip2_image_text_matching", "pretrain"),
# "blip2-itm-vit-large": ("blip2_image_text_matching", "pretrain_vitL"),
"blip2-itm-vit-g-coco": ("blip2_image_text_matching", "coco"),
}

name, type = model_name_to_original[model_name]
@@ -163,13 +190,15 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_

# update state dict keys
state_dict = original_model.state_dict()
rename_keys = create_rename_keys(config)
rename_keys = create_rename_keys(config, model_name)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)

# some keys can be renamed efficiently
for key, val in state_dict.copy().items():
val = state_dict.pop(key)
if key.startswith("Qformer.cls"):
key = key.replace("Qformer.cls", "cls")
if key.startswith("Qformer.bert"):
key = key.replace("Qformer.bert", "qformer")
if "attention.self" in key:
@@ -193,7 +222,6 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_

image = load_demo_image()
original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)
input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)

# create processor
image_processor = BlipImageProcessor(
@@ -207,50 +235,100 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_

original_model.to(lavis_device)
hf_model.to(hf_model_device)
with torch.no_grad():
if "opt" in model_name:
original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
logits = hf_model(pixel_values, input_ids).logits
else:
original_logits = original_model(
{"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]}
).logits
labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
logits = hf_model(pixel_values, input_ids, labels=labels).logits

assert original_logits.shape == logits.shape
print("First values of original logits:", original_logits[0, :3, :3])
print("First values of HF logits:", logits[0, :3, :3])
if "itm" in model_name:
caption = "a large fountain spewing water into the air"
input_ids = tokenizer([caption], return_tensors="pt").input_ids.to(hf_model_device)
attention_mask = processor(text=caption, return_tensors="pt").attention_mask.to(hf_model_device)

with torch.no_grad():
original_logits = original_model(
{"image": original_pixel_values, "text_input": [caption]}, match_head="itm"
)
logits = hf_model(pixel_values=original_pixel_values, input_ids=input_ids, attention_mask=attention_mask)

# assert values
assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4)
print("Looks ok!")
assert original_logits.shape == logits.itm_score.shape
print("First values of original logits:", original_logits[0, :3])
print("First values of HF logits:", logits.itm_score[0, :3])

print("Generating a caption...")
prompt = "Question: what object is in this image? Answer:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device)
# assert values
# cast to same type
target_dtype = logits.itm_score.dtype
assert torch.allclose(original_logits.to(target_dtype), logits.itm_score, atol=1e-2)

set_seed(42)
original_itm_scores = torch.nn.functional.softmax(original_logits, dim=1)
itm_scores = torch.nn.functional.softmax(logits.itm_score, dim=1)
assert torch.allclose(original_itm_scores.to(target_dtype), itm_scores, atol=1e-2)
print("Looks ok!")

original_outputs = original_model.generate(
{"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True
)
outputs = hf_model.generate(
pixel_values,
input_ids,
do_sample=True,
num_beams=5,
max_length=30,
min_length=1,
top_p=0.9,
repetition_penalty=1.0,
length_penalty=1.0,
temperature=1,
)
output_text = processor.batch_decode(outputs, skip_special_tokens=True)
output_text = [text.strip() for text in output_text]
print("Original generation:", original_outputs)
print("HF generation:", output_text)
with torch.no_grad():
original_logits = original_model(
{"image": original_pixel_values, "text_input": [caption]}, match_head="itc"
)
logits = hf_model(
pixel_values=original_pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
use_itm_head=False,
)

assert original_logits.shape == logits.itm_score.shape
print("First values of original logits:", original_logits[0, :3])
print("First values of HF logits:", logits.itm_score[0, :3])

# assert values
# cast to same type
target_dtype = logits.itm_score.dtype
assert torch.allclose(original_logits.to(target_dtype), logits.itm_score, atol=1e-2)
print("Looks ok!")

else:
input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)

with torch.no_grad():
if "opt" in model_name:
original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
logits = hf_model(pixel_values, input_ids).logits
else:
original_logits = original_model(
{"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]}
).logits
labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
logits = hf_model(pixel_values, input_ids, labels=labels).logits

assert original_logits.shape == logits.shape
print("First values of original logits:", original_logits[0, :3, :3])
print("First values of HF logits:", logits[0, :3, :3])

# assert values
assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4)
print("Looks ok!")

print("Generating a caption...")
prompt = "Question: what object is in this image? Answer:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device)

set_seed(42)

original_outputs = original_model.generate(
{"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True
)
outputs = hf_model.generate(
pixel_values,
input_ids,
do_sample=True,
num_beams=5,
max_length=30,
min_length=1,
top_p=0.9,
repetition_penalty=1.0,
length_penalty=1.0,
temperature=1,
)
output_text = processor.batch_decode(outputs, skip_special_tokens=True)
output_text = [text.strip() for text in output_text]
print("Original generation:", original_outputs)
print("HF generation:", output_text)

if pytorch_dump_folder_path is not None:
processor.save_pretrained(pytorch_dump_folder_path)
@@ -271,6 +349,9 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_
"blip2-flan-t5-xl",
"blip2-flan-t5-xl-coco",
"blip2-flan-t5-xxl",
"blip2-itm-vit-g",
# "blip2-itm-vit-large",
"blip2-itm-vit-g-coco",
]
parser.add_argument(
"--model_name",
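
For completeness, a sketch of driving the updated conversion for one of the new ITM checkpoints. The direct function call and the output folder are assumptions not shown in the diff; it also assumes a source checkout of transformers with LAVIS installed, since the script fetches the original blip2_image_text_matching weights through LAVIS.

```python
# Hypothetical driver for the updated conversion script.
from transformers.models.blip_2.convert_blip_2_original_to_pytorch import convert_blip2_checkpoint

convert_blip2_checkpoint(
    model_name="blip2-itm-vit-g",                  # one of the two new ITM entries
    pytorch_dump_folder_path="./blip2-itm-vit-g",  # where the converted weights and processor are saved
)
```
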
(Diffs for the remaining 4 changed files were not loaded and are not shown.)
