diff --git a/mindnlp/transformers/generation/utils.py b/mindnlp/transformers/generation/utils.py
index cc1dccbcc..f22e34f29 100644
--- a/mindnlp/transformers/generation/utils.py
+++ b/mindnlp/transformers/generation/utils.py
@@ -1297,6 +1297,13 @@ def _prepare_generated_length(
             and not self.config.is_encoder_decoder
         ):
             generation_config.max_length -= inputs_tensor.shape[1]
+        # by default let's always generate 20 new tokens
+        elif has_default_max_length:
+            if generation_config.max_length == GenerationConfig().max_length:
+                generation_config.max_length = generation_config.max_length + input_ids_length
+                max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
+                if max_position_embeddings is not None:
+                    generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
 
         # same for min length
         if generation_config.min_new_tokens is not None:
diff --git a/mindnlp/transformers/models/blip_2/modeling_blip_2.py b/mindnlp/transformers/models/blip_2/modeling_blip_2.py
index 5c1a90712..b6af1ba78 100644
--- a/mindnlp/transformers/models/blip_2/modeling_blip_2.py
+++ b/mindnlp/transformers/models/blip_2/modeling_blip_2.py
@@ -2881,31 +2881,45 @@ def generate(
             *language_model_inputs.shape[:-1], dtype=mindspore.int64
         )
         if input_ids is None:
-            input_ids = (
-                mindspore.Tensor([[self.config.text_config.bos_token_id]])
-                .repeat(batch_size, 1)
-            )
+            start_tokens = [self.config.text_config.bos_token_id]
+            if getattr(self.config, "image_token_index", None) is not None:
+                start_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens
+            input_ids = ops.tile(mindspore.Tensor([start_tokens]), (batch_size, 1))
+
+        inputs_embeds = self.get_input_embeddings()(input_ids)
         if attention_mask is None:
             attention_mask = ops.ones_like(input_ids)
-        attention_mask = ops.cat([language_attention_mask, attention_mask], dim=1)
 
-        # concatenate query embeddings with prompt embeddings
-        inputs_embeds = self.get_input_embeddings()(input_ids)
-        inputs_embeds = ops.cat([language_model_inputs, inputs_embeds], dim=1)
+        # if the model already has "image_token_index" then the input is expanded to account for image embeds
+        # otherwise we expand manually by concatenating
+        if getattr(self.config, "image_token_index", None) is not None:
+            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+            inputs_embeds[special_image_mask] = language_model_inputs.flatten()
+        else:
+            logger.warning_once(
+                "Expanding inputs for image tokens in BLIP-2 should be done in processing. "
+                "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
+                "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
+            )
+            inputs_embeds = ops.cat([language_model_inputs, inputs_embeds], dim=1)
+            attention_mask = ops.cat(
+                [language_attention_mask, attention_mask], dim=1
+            )
 
-        # add image_embeds length to max_length, so that the final max_length in counted only on token embeds
-        # -1 is to account for the prepended BOS after `generate.`
-        # TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs
-        if not self.language_model.config.is_encoder_decoder:
-            generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
-            generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
+            # add image_embeds length to max_length, so that the final max_length in counted only on token embeds
+            # -1 is to account for the prepended BOS after `generate.`
+            # TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs
+            if not self.language_model.config.is_encoder_decoder:
+                generate_kwargs["max_length"] = (
+                    generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
+                )
+                generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
 
-        outputs = self.language_model.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            **generate_kwargs,
-        )
+        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
+        if not self.language_model.config.is_encoder_decoder:
+            inputs["input_ids"] = input_ids
+        outputs = self.language_model.generate(**inputs, **generate_kwargs)
         return outputs
 
 
 __all__ = [
diff --git a/mindnlp/transformers/models/blip_2/processing_blip_2.py b/mindnlp/transformers/models/blip_2/processing_blip_2.py
index 8998ece23..be98736aa 100644
--- a/mindnlp/transformers/models/blip_2/processing_blip_2.py
+++ b/mindnlp/transformers/models/blip_2/processing_blip_2.py
@@ -20,7 +20,7 @@
 
 from ...image_utils import ImageInput
 from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from ...tokenization_utils_base import AddedToken, BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
 from ....utils import TensorType
 
 
@@ -36,13 +36,16 @@ class Blip2Processor(ProcessorMixin):
             An instance of [`BlipImageProcessor`]. The image processor is a required input.
         tokenizer (`AutoTokenizer`):
             An instance of ['PreTrainedTokenizer`]. The tokenizer is a required input.
+        num_query_tokens (`int`, *optional*):
+            Number of tokens used by the Qformer as queries, should be same as in model's config.
     """
 
     attributes = ["image_processor", "tokenizer"]
+    valid_kwargs = ["num_query_tokens"]
     image_processor_class = "BlipImageProcessor"
     tokenizer_class = "AutoTokenizer"
 
     # Copied from transformers.models.blip.processing_blip.BlipProcessor.__init__
-    def __init__(self, image_processor, tokenizer):
+    def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs):
         """
         Initializes a new instance of the Blip2Processor class.
@@ -53,7 +56,8 @@ def __init__(self, image_processor, tokenizer):
 
             tokenizer: An object representing the tokenizer to be used. t should have the necessary methods and attributes
                 required for tokenization. The 'return_token_type_ids' attribute of the tokenizer will be set to False.
-
+            num_query_tokens (`int`, *optional*):
+                Number of tokens used by the Qformer as queries, should be same as in model's config.
         Returns:
             None.
 
@@ -61,8 +65,15 @@ def __init__(self, image_processor, tokenizer):
             None.
         """
         tokenizer.return_token_type_ids = False
+        self.current_processor = image_processor
+        if not hasattr(tokenizer, "image_token"):
+            self.image_token = AddedToken("<image>", normalized=False, special=True)
+            tokenizer.add_tokens([self.image_token], special_tokens=True)
+        else:
+            self.image_token = tokenizer.image_token
+        self.num_query_tokens = num_query_tokens
+
         super().__init__(image_processor, tokenizer)
-        self.current_processor = self.image_processor
 
     # Copied from transformers.models.blip.processing_blip.BlipProcessor.__call__
     def __call__(
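
Reviewer note: below is a minimal usage sketch of the path this patch enables, not part of the change itself. It assumes the mindnlp auto classes mirror the upstream transformers API, that an updated checkpoint such as `Salesforce/blip2-opt-2.7b` (with `num_query_tokens` in the processor config and `image_token_index` in the model config) is reachable, and that `return_tensors="ms"` is the MindSpore tensor type; adjust names as needed. With an updated processor/model pair the prompt is expanded with `<image>` placeholder tokens at processing time and `generate` scatters the Q-Former embeddings into those positions; with an old checkpoint it falls back to the concatenation branch and emits the deprecation warning added above.

```python
# Hypothetical smoke test for the updated BLIP-2 processing/generation path.
from PIL import Image
from mindnlp.transformers import AutoProcessor, Blip2ForConditionalGeneration

model_id = "Salesforce/blip2-opt-2.7b"  # assumed checkpoint with updated processor/model configs
processor = AutoProcessor.from_pretrained(model_id)
model = Blip2ForConditionalGeneration.from_pretrained(model_id)

image = Image.open("demo.jpg").convert("RGB")  # placeholder local image
prompt = "Question: what is shown in the picture? Answer:"

# If the processor carries `num_query_tokens`, the returned `input_ids` already
# contain the <image> placeholder tokens that generate() fills with Q-Former outputs.
inputs = processor(images=image, text=prompt, return_tensors="ms")
generated_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip())
```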