From 6e3d7db20f810c3fc38504e48037f96552832b7d Mon Sep 17 00:00:00 2001 From: Peng Chen Date: Wed, 11 Oct 2023 17:32:05 -0700 Subject: [PATCH] add blip2 layer under torchmm/models (#484) Summary: Pull Request resolved: https://github.com/facebookresearch/multimodal/pull/484 as title Differential Revision: D50145708 fbshipit-source-id: e6253b044781b2eea81b05ab46041a7ded080937 --- tests/models/blip2/test_blip2.py | 137 +++++++++++++++ torchmultimodal/models/blip2/blip2.py | 157 ++++++++++++++++++ torchmultimodal/models/blip2/qformer_model.py | 2 +- 3 files changed, 295 insertions(+), 1 deletion(-) create mode 100644 tests/models/blip2/test_blip2.py create mode 100644 torchmultimodal/models/blip2/blip2.py diff --git a/tests/models/blip2/test_blip2.py b/tests/models/blip2/test_blip2.py new file mode 100644 index 00000000..e9a294c4 --- /dev/null +++ b/tests/models/blip2/test_blip2.py @@ -0,0 +1,137 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import pytest +import torch +import torch.nn as nn +from tests.test_utils import assert_expected, init_weights_with_constant +from torchmultimodal.models.blip2.blip2 import BLIP2 +from torchmultimodal.models.blip2.qformer_model import QformerForCLM +from torchmultimodal.modules.encoders.vision_transformer import VisionTransformer +from torchmultimodal.modules.layers.patch_embedding import PatchEmbeddings +from torchmultimodal.modules.layers.transformer import TransformerEncoder + + +@pytest.fixture +def dim_q(): + return 4 + + +@pytest.fixture +def dim_kv(): + return 2 + + +@pytest.fixture +def dim_feedforward(): + return 6 + + +@pytest.fixture +def num_hidden_layers(): + return 2 + + +@pytest.fixture +def num_heads(): + return 2 + + +@pytest.fixture +def vocab_size(): + return 20 + + +@pytest.fixture +def qformer_model_for_clm( + dim_q, + dim_kv, + dim_feedforward, + num_hidden_layers, + num_heads, + vocab_size, +): + qformer_for_clm = QformerForCLM( + dim_q=dim_q, + dim_kv=dim_kv, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=0.0, + dropout=0.0, + num_hidden_layers=num_hidden_layers, + max_position_embeddings=512, + vocab_size=vocab_size, + ) + return qformer_for_clm + + +@pytest.fixture +def vit(): + embedding = PatchEmbeddings(image_size=2, patch_size=1, hidden_size=2) + encoder = TransformerEncoder( + n_layer=1, + d_model=2, + n_head=1, + dim_feedforward=1, + activation=nn.GELU, + norm_first=True, + final_layer_norm_eps=1e-5, + ) + image_encoder = VisionTransformer( + embeddings=embedding, + encoder=encoder, + ) + init_weights_with_constant(image_encoder) + image_encoder.eval() + return image_encoder + + +@pytest.fixture +def blip2(dim_q, dim_kv, qformer_model_for_clm, vit): + blip2 = BLIP2( + dim_q=dim_q, + image_encoder_embedding_dim=dim_kv, + qformer=qformer_model_for_clm, + vision_encoder=vit, + embedding_dim=4, + decoder_bos_token_id=19, + ) + init_weights_with_constant(blip2) + blip2.eval() + return blip2 + + +@pytest.fixture +def attn_mask(): + return torch.Tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0]]) + + +class TestBLIP2: + def test_blip2(self, blip2, attn_mask): + image = torch.ones(2, 3, 2, 2) + input_ids = torch.ones(2, 4).long() + output = blip2(image, input_ids, attn_mask) + assert_expected( + output.image_features, torch.ones([2, 32, 4]) * 0.5, rtol=0, atol=1e-4 + ) + assert_expected( + output.text_features, torch.ones([2, 4]) * 0.5, rtol=0, atol=1e-4 + ) + assert_expected( + output.image_embeddings, torch.ones([2, 5, 2]), rtol=0, atol=1e-4 + ) + assert_expected( + output.prediction_scores, torch.ones([2, 4, 20]) * 5, rtol=0, atol=1e-4 + ) + + def test_blip2_scripting(self, blip2, attn_mask): + image = torch.ones(2, 3, 2, 2) + input_ids = torch.ones(2, 4).long() + scripted_model = torch.jit.script(blip2) + actual = scripted_model(image, input_ids, attn_mask) + expected = blip2(image, input_ids, attn_mask) + assert_expected(actual, expected) diff --git a/torchmultimodal/models/blip2/blip2.py b/torchmultimodal/models/blip2/blip2.py new file mode 100644 index 00000000..a3dfde05 --- /dev/null +++ b/torchmultimodal/models/blip2/blip2.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import NamedTuple, Optional + +import torch + +from torch import nn, Tensor +from torch.nn import functional as F +from torchmultimodal.modules.layers.transformer import TransformerOutput + + +class Blip2Output(NamedTuple): + """ + BLIP2 model output for loss computation. + + image_embeddings(Tensor): normalized image embeddings returned by the visual encoder + with shape [bsz x seq_len x embed_dim]. + image_features(Tensor): Image features after qformer and projection (for stage 1 training) + with shape [bsz, num_query_tokens, embed_dim] + image_qformer_output(Tensor) : last hidden state for qformer output by given image input + text_features(Optional[Tensor]): Text features after qformer and projection if text input is provided + with shape [bsz, embed_dim] + prediction_scores (Optional[Tensor]): computed for next word prediction + with shape of [bsz, seq_len, vocab_size] + """ + + image_embeddings: Tensor + image_features: Tensor + image_qformer_output: Tensor + text_features: Optional[Tensor] = None + prediction_scores: Optional[Tensor] = None + + +class BLIP2(nn.Module): + """ + BLIP2(https://arxiv.org/pdf/2301.12597.pdf) provides a pre-training strategy to bootstrap vision-language + pre-training from frozen image encoders and frozen large language models(LLM). BLIP-2 bridges the modality gap + and facilitates cross-modal alignment via Querying Transformer (Q-former). Q-former is a lightweight transformer + which has a set of learnable query vectors to extract visual features from the frozen image encoder. + + Args: + qformer(nn.Module): Querying Transformer (Q-former) + visual_encoder(nn.Module): Frozen image encoder + dim_q(int) : Dimension of query tensor, this value should be the same as dim_q in qformer. + image_encoder_embedding_dim(int): Embedding dimension for image encoder, + this value should be the same as dim_kv in qformer. + freeze_visual_encoder(bool): Whether to freeze the visual encoder, default to True + cross_attention_freq(int): Frequency of adding cross-attention block in Qformer, default to 2 + embedding_dim(int): Embedding dimension + num_query_token(int): Number of query tokens in Qformer, default to 32 + init_query_tokens(bool): whether init query token params, default to True + decoder_bos_token_id(Optional[int]): bos_token_id used in decoder, default to None + """ + + def __init__( + self, + qformer: nn.Module, + vision_encoder: nn.Module, + dim_q: int, + image_encoder_embedding_dim: int, + freeze_vision_encoder: bool = True, + cross_attention_freq: int = 2, + embedding_dim: int = 256, + num_query_token: int = 32, + init_query_tokens: bool = True, + decoder_bos_token_id: Optional[int] = None, + ): + super().__init__() + self.vision_encoder = vision_encoder + if freeze_vision_encoder: + for param in self.vision_encoder.parameters(): + param.requires_grad = False + self.vision_encoder = self.vision_encoder.eval() + + self.qformer = qformer + self.decoder_bos_token_id = decoder_bos_token_id + self.dim_q = dim_q + self.query_tokens = nn.Parameter(torch.zeros(1, num_query_token, self.dim_q)) + if init_query_tokens: + self.query_tokens.data.normal_(mean=0.0, std=0.02) + + self.vision_proj = nn.Linear(self.dim_q, embedding_dim) + self.text_proj = nn.Linear(self.dim_q, embedding_dim) + self.ln_vision = nn.LayerNorm(image_encoder_embedding_dim) + + def forward( + self, + image: Tensor, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + ) -> Blip2Output: + """ + Args: + image(Tensor): Image input tensor with shape [B, C, H, W] + input_ids(Optional[Tensor]): Text input tensor with shape [bsz, seq_len] + attention_mask(Optional[Tensor]): Attention mask tensor with shape [bsz, seq_len] + + Returns: + return BLIP2 model output(Blip2Output). + """ + vision_encoder_output = self.vision_encoder(image) + if isinstance(vision_encoder_output, TransformerOutput): + vision_encoder_output = vision_encoder_output.last_hidden_state + assert vision_encoder_output is not None + image_embeds = self.ln_vision(vision_encoder_output) + # query tokens: [batch_size, num_query_token, encoder_hidden_size] + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_output = self.qformer.model( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + use_cache=True, + ) + + # image_feats: [batch_size, num_query_token, embedding_dim] + image_feats = F.normalize(self.vision_proj(query_output[0]), dim=-1) + + text_feats: Optional[Tensor] = None + prediction_scores: Optional[Tensor] = None + if input_ids is not None: + text_output = self.qformer.model( + input_ids, + attention_mask=attention_mask, + use_cache=False, + ) + text_feats = F.normalize(self.text_proj(text_output[0][:, 0, :]), dim=-1) + + decoder_input_ids = input_ids.clone() + if self.decoder_bos_token_id is not None: + # pyre-ignore + decoder_input_ids[:, 0] = self.decoder_bos_token_id + + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to( + input_ids.device + ) + if attention_mask is not None: + attention_mask = torch.cat([query_atts, attention_mask], dim=1) + + # set use_cache = False since past_key_values should be cached in previous steps. + prediction_scores = self.qformer( + input_ids=decoder_input_ids, + attention_mask=attention_mask, + past_key_values=query_output[1], + use_cache=False, + ) + + return Blip2Output( + image_embeddings=image_embeds, + image_features=image_feats, + image_qformer_output=query_output[0], + text_features=text_feats, + prediction_scores=prediction_scores, + ) diff --git a/torchmultimodal/models/blip2/qformer_model.py b/torchmultimodal/models/blip2/qformer_model.py index 8ce3cf68..b143dd9c 100644 --- a/torchmultimodal/models/blip2/qformer_model.py +++ b/torchmultimodal/models/blip2/qformer_model.py @@ -50,7 +50,7 @@ def __init__( attn_dropout: float = 0.0, dropout: float = 0.0, cross_attention_freq=2, - ): + ) -> None: super().__init__() self.query_length = query_length self.embeddings = QformerEmbedding(