fix Blip2 loading and inference bugs #1902 #1904 #1905 (#1958)
Alemax067 authored Feb 27, 2025
1 parent 3ea3f4a commit 8c26b18
Showing 3 changed files with 55 additions and 23 deletions.
7 changes: 7 additions & 0 deletions mindnlp/transformers/generation/utils.py
@@ -1297,6 +1297,13 @@ def _prepare_generated_length(
and not self.config.is_encoder_decoder
):
generation_config.max_length -= inputs_tensor.shape[1]
# by default let's always generate 20 new tokens
elif has_default_max_length:
if generation_config.max_length == GenerationConfig().max_length:
generation_config.max_length = generation_config.max_length + input_ids_length
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
if max_position_embeddings is not None:
generation_config.max_length = min(generation_config.max_length, max_position_embeddings)

# same for min length
if generation_config.min_new_tokens is not None:
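A minimal plain-Python sketch of the new default-length handling above (the helper name `resolve_max_length` and the numbers are illustrative, not part of the patch): when the caller leaves `max_length` at the `GenerationConfig` default of 20, it is treated as a budget of new tokens, extended by the prompt length, and clamped to `max_position_embeddings`. Without this, a long BLIP-2 prompt exhausts the default 20-token budget before anything is generated.

DEFAULT_MAX_LENGTH = 20  # GenerationConfig().max_length default

def resolve_max_length(max_length, has_default_max_length, input_ids_length,
                       max_position_embeddings=None):
    # mirrors the elif branch added in _prepare_generated_length
    if has_default_max_length and max_length == DEFAULT_MAX_LENGTH:
        max_length = max_length + input_ids_length  # 20 new tokens on top of the prompt
        if max_position_embeddings is not None:
            max_length = min(max_length, max_position_embeddings)
    return max_length

# a 512-token prompt with default settings now allows 532 total tokens
assert resolve_max_length(20, True, 512, max_position_embeddings=2048) == 532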
52 changes: 33 additions & 19 deletions mindnlp/transformers/models/blip_2/modeling_blip_2.py
@@ -2881,31 +2881,45 @@ def generate(
*language_model_inputs.shape[:-1], dtype=mindspore.int64
)
if input_ids is None:
input_ids = (
mindspore.Tensor([[self.config.text_config.bos_token_id]])
.repeat(batch_size, 1)
)
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "image_token_index", None) is not None:
start_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens
input_ids = ops.tile(mindspore.Tensor([start_tokens]), (batch_size, 1))

inputs_embeds = self.get_input_embeddings()(input_ids)
if attention_mask is None:
attention_mask = ops.ones_like(input_ids)
attention_mask = ops.cat([language_attention_mask, attention_mask], dim=1)

# concatenate query embeddings with prompt embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds = ops.cat([language_model_inputs, inputs_embeds], dim=1)
# if the model already has "image_token_index" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating
if getattr(self.config, "image_token_index", None) is not None:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
else:
logger.warning_once(
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
)
inputs_embeds = ops.cat([language_model_inputs, inputs_embeds], dim=1)
attention_mask = ops.cat(
[language_attention_mask, attention_mask], dim=1
)

# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
# -1 accounts for the BOS prepended by `generate`.
# TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs
if not self.language_model.config.is_encoder_decoder:
generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
# -1 accounts for the BOS prepended by `generate`.
# TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs
if not self.language_model.config.is_encoder_decoder:
generate_kwargs["max_length"] = (
generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
)
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]

outputs = self.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
**generate_kwargs,
)
inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
if not self.language_model.config.is_encoder_decoder:
inputs["input_ids"] = input_ids

outputs = self.language_model.generate(**inputs, **generate_kwargs)
return outputs

__all__ = [
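A shape-level NumPy sketch of the new expansion path in `generate` (NumPy is used only to illustrate the indexing; the real code uses MindSpore tensors and `ops`, and all sizes below are made up): when the config defines `image_token_index`, the prompt starts with `num_query_tokens` image placeholders, and each placeholder slot of `inputs_embeds` is overwritten in place with one Q-Former output vector instead of concatenating `language_model_inputs` in front of the prompt. Configs without `image_token_index` keep the old concatenation path plus the max/min-length bookkeeping, and for decoder-only language models `input_ids` is now passed to `language_model.generate` together with `inputs_embeds`.

import numpy as np

batch_size, num_query_tokens, hidden = 2, 32, 8
image_token_index, bos_token_id = 50265, 2  # placeholder ids for the sketch

# prompt built as in the patch: num_query_tokens image placeholders + BOS
start_tokens = [image_token_index] * num_query_tokens + [bos_token_id]
input_ids = np.tile(np.array([start_tokens]), (batch_size, 1))                  # (B, Q+1)

inputs_embeds = np.random.randn(batch_size, input_ids.shape[1], hidden)         # text embeds
language_model_inputs = np.random.randn(batch_size, num_query_tokens, hidden)   # Q-Former output

# boolean mask covering every hidden dim of every image-placeholder position
special_image_mask = (input_ids == image_token_index)[..., None]
special_image_mask = np.broadcast_to(special_image_mask, inputs_embeds.shape)

# scatter the Q-Former outputs into the placeholder slots, in place
inputs_embeds[special_image_mask] = language_model_inputs.reshape(-1)

assert np.allclose(inputs_embeds[:, :num_query_tokens], language_model_inputs)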
19 changes: 15 additions & 4 deletions mindnlp/transformers/models/blip_2/processing_blip_2.py
@@ -20,7 +20,7 @@

from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...tokenization_utils_base import AddedToken, BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ....utils import TensorType


@@ -36,13 +36,16 @@ class Blip2Processor(ProcessorMixin):
An instance of [`BlipImageProcessor`]. The image processor is a required input.
tokenizer (`AutoTokenizer`):
An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
num_query_tokens (`int`, *optional*):
Number of tokens used by the Qformer as queries; should be the same as in the model's config.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["num_query_tokens"]
image_processor_class = "BlipImageProcessor"
tokenizer_class = "AutoTokenizer"

# Copied from transformers.models.blip.processing_blip.BlipProcessor.__init__
def __init__(self, image_processor, tokenizer):
def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs):
"""
Initializes a new instance of the Blip2Processor class.
@@ -53,16 +53,24 @@ def __init__(self, image_processor, tokenizer):
tokenizer: An object representing the tokenizer to be used.
It should have the necessary methods and attributes required for tokenization.
The 'return_token_type_ids' attribute of the tokenizer will be set to False.
num_query_tokens (`int`, *optional*):
Number of tokens used by the Qformer as queries; should be the same as in the model's config.
Returns:
None.
Raises:
None.
"""
tokenizer.return_token_type_ids = False
self.current_processor = image_processor
if not hasattr(tokenizer, "image_token"):
self.image_token = AddedToken("<image>", normalized=False, special=True)
tokenizer.add_tokens([self.image_token], special_tokens=True)
else:
self.image_token = tokenizer.image_token
self.num_query_tokens = num_query_tokens

super().__init__(image_processor, tokenizer)
self.current_processor = self.image_processor

# Copied from transformers.models.blip.processing_blip.BlipProcessor.__call__
def __call__(
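The `__call__` hunk is truncated above, but together with the model-side warning the intent is that the processor, not the model, expands the prompt: `__init__` now registers an `<image>` token when the tokenizer lacks one and records `num_query_tokens`, so the processor can prepend that many placeholders and the model's `special_image_mask` lines up with the Q-Former output. A rough, hypothetical sketch of that expansion (not the actual `__call__` body):

image_token = "<image>"   # token added in __init__ when the tokenizer lacks one
num_query_tokens = 32     # must match the model config, as the new kwarg documents

def expand_prompt(text: str) -> str:
    # one placeholder per Q-Former query token, followed by the user prompt
    return image_token * num_query_tokens + text

prompt = expand_prompt("Question: what is shown in the image? Answer:")
# the tokenizer then maps each "<image>" to the id that generate() compares
# against config.image_token_index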

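For reference, an end-to-end inference sketch against the fixed code path (the checkpoint name, image file, and prompt are illustrative; assumes the usual `from_pretrained`-style mindnlp API and `return_tensors="ms"` for MindSpore tensors):

from PIL import Image
from mindnlp.transformers import Blip2Processor, Blip2ForConditionalGeneration

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

image = Image.open("demo.jpg")  # any RGB image
inputs = processor(images=image,
                   text="Question: what is in the photo? Answer:",
                   return_tensors="ms")

generated_ids = model.generate(**inputs, max_new_tokens=30)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip())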