@@ -594,6 +594,8 @@ class GotOcr2PreTrainedModel(PreTrainedModel):
     _supports_cache_class = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_quantized_cache = True
+    _supports_static_cache = True

     def _init_weights(self, module):
         # important: this ported version of GotOcr2 isn't meant for training from scratch - only
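
The two new flags advertise to `generate()` that this model works with transformers' static and quantized KV caches. A minimal usage sketch of what they unlock, assuming the public GOT-OCR2 checkpoint id and a local image path (neither appears in this diff):

    # Hedged sketch: checkpoint id, image path, and decoding kwargs are assumptions.
    from transformers import AutoProcessor, GotOcr2ForConditionalGeneration

    model = GotOcr2ForConditionalGeneration.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")  # assumed hub id
    processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
    inputs = processor("page.png", return_tensors="pt")  # hypothetical input image

    # _supports_static_cache=True lets generate() allocate a fixed-size,
    # torch.compile-friendly cache:
    out = model.generate(**inputs, max_new_tokens=64, cache_implementation="static")
    # _supports_quantized_cache=True enables a memory-saving quantized cache
    # (needs a quantization backend such as quanto installed):
    out = model.generate(**inputs, max_new_tokens=64, cache_implementation="quantized")
    print(processor.decode(out[0], skip_special_tokens=True))
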
@@ -748,89 +750,6 @@ def get_image_features(
         image_outputs = self.vision_tower(pixel_values).last_hidden_state
         return self.multi_modal_projector(image_outputs)

-    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
-        num_images, num_image_patches, embed_dim = image_features.shape
-        batch_size, sequence_length = input_ids.shape
-        left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
-        # 1. Create a mask to know where special image tokens are
-        special_image_token_mask = input_ids == self.config.image_token_index
-        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
-        # Compute the maximum embed dimension
-        max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
-        batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
-
-        # 2. Compute the positions where text should be written
-        # Calculate new positions for text tokens in merged image-text sequence.
-        # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
-        # `torch.cumsum` computes how each image token shifts subsequent text token positions.
-        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
-        new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
-        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
-        if left_padding:
-            new_token_positions += nb_image_pad[:, None]  # offset for left padding
-        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
-
-        # 3. Create the full embedding, already padded to the maximum position
-        final_embedding = torch.zeros(
-            batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
-        )
-        final_attention_mask = torch.zeros(
-            batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
-        )
-        if labels is not None:
-            final_labels = torch.full(
-                (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
-            )
-        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
-        # set the corresponding tensors into their correct target device.
-        target_device = inputs_embeds.device
-        batch_indices, non_image_indices, text_to_overwrite = (
-            batch_indices.to(target_device),
-            non_image_indices.to(target_device),
-            text_to_overwrite.to(target_device),
-        )
-        attention_mask = attention_mask.to(target_device)
-
-        # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
-        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
-        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
-        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
-        if labels is not None:
-            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
-
-        # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
-        image_to_overwrite = torch.full(
-            (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
-        )
-        image_to_overwrite[batch_indices, text_to_overwrite] = False
-        if left_padding:
-            image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
-        else:
-            mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
-            padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
-            image_to_overwrite &= padding_mask
-
-        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
-            raise ValueError(
-                f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
-                f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
-            )
-
-        final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
-        final_attention_mask |= image_to_overwrite
-        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-
-        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
-        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
-        indices_to_mask = new_token_positions[batch_indices, pad_indices]
-
-        final_embedding[batch_indices, indices_to_mask] = 0
-
-        if labels is None:
-            final_labels = None
-
-        return final_embedding, final_attention_mask, final_labels, position_ids
-
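
For readers tracing the deletion: the heart of step 2 above is the cumsum repositioning trick, which can be reproduced standalone. A toy sketch (the placeholder token id and patch count are made-up values, not the model's real ones):

    import torch

    image_token_id = 32000      # made-up placeholder id
    num_image_patches = 4       # toy value; real vision towers use hundreds
    input_ids = torch.tensor([[101, 32000, 17, 18]])  # "hey <image> how are"

    special = input_ids == image_token_id
    # each image token expands to num_image_patches slots, i.e. it adds
    # (num_image_patches - 1) positions; cumsum - 1 yields zero-based targets
    new_positions = torch.cumsum(special * (num_image_patches - 1) + 1, dim=-1) - 1
    print(new_positions)  # tensor([[0, 4, 5, 6]])
    # text tokens land at slots 0, 5, 6; the image's patches fill slots 1..4
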
     @add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
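
The deleted helper was the legacy merge path. The pattern that has replaced it in transformers multimodal models (not shown in this diff) is to have the processor emit one placeholder token per image patch up front, so the model only needs a single masked_scatter to splice features in. A hedged sketch of that pattern, with the function name invented for illustration:

    import torch

    def splice_image_features(inputs_embeds, input_ids, image_features, image_token_index):
        # boolean mask of placeholder positions, broadcast over the embed dim
        mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
        flat = image_features.reshape(-1, image_features.shape[-1])
        flat = flat.to(inputs_embeds.device, inputs_embeds.dtype)
        # copies patch embeddings, in order, into the masked slots
        return inputs_embeds.masked_scatter(mask, flat)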