From a5bb52847139bf6ad7489ac62a5fb6d0fa3d2ec6 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Date: Thu, 5 Dec 2024 18:15:48 +0100 Subject: [PATCH] Fix signatures for processing kwargs (#35105) * add conversion script * remove pg2 refs * fixup style * small update * get correct scaling * add back missing bos * fix missing config keys * might revert this pos_embeddings * fixup 9b config * fix 9b * fixup 9b conversion for good + add back num_hidden_layers * add correct query scaling for 2b, 9b, 27b * fixup 27b conversion * Additional variant: 27b-896 * Use CPU for conversion to reduce GPU RAM requirements * fix causal mask generation + formatting * fix in-training causal mask generation edge case * trigger CI * update config * update config * update config * update config * update config * update config * update config * update config * update config * move conversion file to main model dir * handle multi-images + bos token * address comments for input ids * revert ci fixes * [run-slow] paligemma * fix * [run-slow] paligemma * skip end 2 end * [run-slow] paligemma --------- Co-authored-by: Pedro Cuenca Co-authored-by: ydshieh --- .../convert_paligemma2_weights_to_hf.py | 415 ++++++++++++++++++ .../models/paligemma/modeling_paligemma.py | 37 +- .../models/paligemma/processing_paligemma.py | 15 +- .../paligemma/test_modeling_paligemma.py | 5 + .../paligemma/test_processor_paligemma.py | 6 +- 5 files changed, 459 insertions(+), 19 deletions(-) create mode 100644 src/transformers/models/paligemma/convert_paligemma2_weights_to_hf.py diff --git a/src/transformers/models/paligemma/convert_paligemma2_weights_to_hf.py b/src/transformers/models/paligemma/convert_paligemma2_weights_to_hf.py new file mode 100644 index 00000000000000..df869fcefb2ba4 --- /dev/null +++ b/src/transformers/models/paligemma/convert_paligemma2_weights_to_hf.py @@ -0,0 +1,415 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
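+#
+# Example invocation (illustrative only; the checkpoint path and output directory below are placeholders,
+# the flags match the argparse options defined at the bottom of this script):
+#
+#   python src/transformers/models/paligemma/convert_paligemma2_weights_to_hf.py \
+#       --checkpoint_path /path/to/paligemma2_checkpoint.npz \
+#       --pytorch_dump_folder_path /path/to/output_dir \
+#       --variant 2b-224 \
+#       --precision bfloat16 \
+#       --do_convert_weights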
+"""Convert PaliGemma2 checkpoints from the original repository.""" + +import argparse +import collections + +import jax.numpy as jnp +import ml_dtypes +import numpy as np +import torch + +from transformers import ( + AutoTokenizer, + Gemma2Config, + PaliGemmaConfig, + PaliGemmaForConditionalGeneration, + PaliGemmaProcessor, + SiglipImageProcessor, +) +from transformers.tokenization_utils_base import AddedToken +from transformers.utils import logging + + +device = "cpu" + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +# TODO add sequence length variations here + +PALIGEMMA2_VARIANTS = ["2b-224", "2b-448", "2b-896", "9b-224", "9b-448", "9b-896", "27b-224", "27b-448", "27b-896"] +VARIANT_CONFIGS = { + "2b": { + "num_positions": 256, + "hidden_size": 2304, + "num_hidden_layers": 26, + "intermediate_size": 9216, + "num_key_value_heads": 4, + "num_attention_heads": 8, + "head_dim": 256, + "query_pre_attn_scalar": 256, + }, + "9b": { + "num_positions": 1024, + "hidden_size": 3584, + "num_hidden_layers": 42, + "intermediate_size": 14336, + "num_key_value_heads": 8, + "num_attention_heads": 16, + "head_dim": 256, + "query_pre_attn_scalar": 256, + }, + "27b": { + "num_positions": 4096, + "hidden_size": 4608, + "num_hidden_layers": 46, + "intermediate_size": 36864, + "num_key_value_heads": 16, + "num_attention_heads": 32, + "head_dim": 128, + "query_pre_attn_scalar": 4608 // 32, # scaling is different for the 28b + }, +} + +DTYPES = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16} + + +def get_paligemma2_config(variant: str, precision: str): + config = { + "image_token_index": None, + "pad_token_id": 0, + "bos_token_id": 2, + "eos_token_id": 1, + } + base_variant = variant.split("-")[0] + + if variant in PALIGEMMA2_VARIANTS: + image_size = int(variant.split("-")[1]) + variant_config = VARIANT_CONFIGS[base_variant] + patch_size = 14 + num_image_tokens = (image_size**2) // (patch_size**2) + config["projection_dim"] = variant_config["hidden_size"] + config["image_token_index"] = 257152 + config["num_hidden_layers"] = variant_config["num_hidden_layers"] # For generate + text_config = Gemma2Config.from_pretrained("google/gemma-2-2b-it").to_dict() + sup_text_config = { + "model_type": "gemma2", + "vocab_size": 257152, + "num_hidden_layers": variant_config["num_hidden_layers"], + "num_key_value_heads": variant_config["num_key_value_heads"], + "head_dim": variant_config["head_dim"], + "torch_dtype": precision, + "hidden_size": variant_config["hidden_size"], + "hidden_activation": "gelu_pytorch_tanh", + "num_attention_heads": variant_config["num_attention_heads"], + "intermediate_size": variant_config["intermediate_size"], + "is_encoder_decoder": False, + "query_pre_attn_scalar": variant_config["query_pre_attn_scalar"], + } + text_config.update(sup_text_config) + + vision_config = { + "num_positions": variant_config["num_positions"], # not useful, to remove + "torch_dtype": precision, + "image_size": image_size, + "patch_size": patch_size, + "num_image_tokens": num_image_tokens, + "hidden_size": 1152, + "intermediate_size": 4304, + "num_hidden_layers": 27, + "num_attention_heads": 16, + "projection_dim": variant_config["hidden_size"], + "hidden_act": "gelu_pytorch_tanh", + "vision_use_head": False, + } + final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config) + else: + raise ValueError(f"Identifier {variant} not supported. 
Available: {PALIGEMMA2_VARIANTS}") + return final_config + + +def slice_state_dict(state_dict, config): + # fmt: off + # patch embeddings + state_dict["vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop("img/embedding/kernel").transpose( + 3, 2, 0, 1 + ) + state_dict["vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop("img/embedding/bias") + # positional embeddings + state_dict["vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop("img/pos_embedding").reshape( + -1, config.vision_config.hidden_size + ) + + # extract vision layers to be sliced at index 0. There are 27 layers in the base model. + encoderblock_layernorm0_scale = state_dict.pop("img/Transformer/encoderblock/LayerNorm_0/scale") + encoderblock_layernorm0_bias = state_dict.pop("img/Transformer/encoderblock/LayerNorm_0/bias") + encoderblock_layernorm1_scale = state_dict.pop("img/Transformer/encoderblock/LayerNorm_1/scale") + encoderblock_layernorm1_bias = state_dict.pop("img/Transformer/encoderblock/LayerNorm_1/bias") + + encoderblock_mlp_dense0_kernel= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel") + encoderblock_mlp_dense0_bias= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias") + encoderblock_mlp_dense1_kernel= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel") + encoderblock_mlp_dense1_bias= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias") + + encoderblock_attention_0_key_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel") + encoderblock_attention_0_key_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias") + encoderblock_attention_0_value_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel") + encoderblock_attention_0_value_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias") + encoderblock_attention_0_query_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel") + encoderblock_attention_0_query_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias") + encoderblock_attention_0_out_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel") + encoderblock_attention_0_out_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias") + + for i in range(config.vision_config.num_hidden_layers): + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i] + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i] + + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i] + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i] + 
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + + state_dict["vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop("img/Transformer/encoder_norm/scale").transpose() + state_dict["vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop("img/Transformer/encoder_norm/bias") + + # multimodal projector + + state_dict['multi_modal_projector.linear.weight'] = state_dict.pop("img/head/kernel").transpose() + state_dict['multi_modal_projector.linear.bias'] = state_dict.pop("img/head/bias") + + # text decoder (gemma) + + embedding_vector = state_dict.pop("llm/embedder/input_embedding") + state_dict["language_model.model.embed_tokens.weight"] = embedding_vector + + # pop the einsum attention + mlp representations. There are 26 layers in gemma2-2b. 
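+    # The popped arrays are stacked over layers (leading dimension = num_hidden_layers). Per layer, and assuming
+    # the shapes noted further down for the 2b variant: q_einsum is (num_heads, hidden_size, head_dim),
+    # kv_einsum packs k and v as (2, num_kv_heads, hidden_size, head_dim), and attn_vec_einsum (the output
+    # projection) is (num_heads, head_dim, hidden_size).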
+ + llm_attention_attn_vec_einsum = state_dict.pop("llm/layers/attn/attn_vec_einsum/w") + # (26, 2, 4, 2304, 256) for 2b-224, 4 kv heads and 26 layers + llm_attention_kv_einsum = state_dict.pop("llm/layers/attn/kv_einsum/w") + llm_attention_q_einsum = state_dict.pop("llm/layers/attn/q_einsum/w") + llm_mlp_gating_einsum = state_dict.pop("llm/layers/mlp/gating_einsum") + llm_mlp_linear = state_dict.pop("llm/layers/mlp/linear") + # TODO verify correctness of layer norm loading + llm_input_layernorm = state_dict.pop("llm/layers/pre_attention_norm/scale") + llm_pre_feedforward_layernorm = state_dict.pop("llm/layers/pre_ffw_norm/scale") + + llm_post_attention_layernorm = state_dict.pop("llm/layers/post_attention_norm/scale") + llm_post_feedforward_layernorm = state_dict.pop("llm/layers/post_ffw_norm/scale") + + for i in range(config.text_config.num_hidden_layers): + # llm_attention_q_einsum[i].shape = (8, 2048, 256) + # q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) + + """ + q shape (8, 2304, 256) + k shape (4, 2304, 256) + v shape (4, 2304, 256) + o shape (8, 256, 2304) + + """ + q_transpose = (0, 2, 1) + k_transpose = (0, 2, 1) + v_transpose = (0, 2, 1) + o_transpose = (2, 0, 1) + + q_weight_matrices = llm_attention_q_einsum[i].transpose(*q_transpose) + q_proj_weight_reshaped = q_weight_matrices + q_proj_weight_reshaped = q_proj_weight_reshaped.reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) + state_dict[f"language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped + # Shape: (4, 2304, 256) + k_weight_matrices = llm_attention_kv_einsum[i, 0].transpose(*k_transpose) + k_proj_weight_reshaped = k_weight_matrices.reshape( + config.text_config.num_key_value_heads * config.text_config.head_dim, + config.text_config.hidden_size + ) + state_dict[f"language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped + # llm_attention_kv_einsum[i, 1].shape = (num_key_value_heads, hidden_size, head_dim) + v_weight_matrices = llm_attention_kv_einsum[i, 1].transpose(*v_transpose) # Shape: (4, 2304, 256) + v_proj_weight_reshaped = v_weight_matrices.reshape( + config.text_config.num_key_value_heads * config.text_config.head_dim, + config.text_config.hidden_size + ) + state_dict[f"language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped + + # output projection. 
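+        # attn_vec_einsum is (num_heads, head_dim, hidden_size) per layer; moving hidden_size to the front
+        # (o_transpose = (2, 0, 1) above) and flattening the head dimensions gives the
+        # (hidden_size, num_heads * head_dim) layout expected for o_proj.weight.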
+
+        # llm_attention_attn_vec_einsum[i].shape = (8, 256, 2304)
+        o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(*o_transpose).reshape(config.text_config.hidden_size, config.text_config.num_attention_heads * config.text_config.head_dim)
+        state_dict[f"language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
+        # mlp layers
+        gate_proj_weight = llm_mlp_gating_einsum[i, 0]
+        state_dict[f"language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
+        up_proj_weight = llm_mlp_gating_einsum[i, 1]
+        state_dict[f"language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
+        state_dict[f"language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
+        state_dict[f"language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
+        state_dict[f"language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
+        state_dict[f"language_model.model.layers.{i}.pre_feedforward_layernorm.weight"] = llm_pre_feedforward_layernorm[i]
+        state_dict[f"language_model.model.layers.{i}.post_feedforward_layernorm.weight"] = llm_post_feedforward_layernorm[i]
+    state_dict["language_model.model.norm.weight"] = state_dict.pop("llm/final_norm/scale")
+    state_dict["language_model.lm_head.weight"] = embedding_vector  # weights are tied.
+    [k for k in state_dict.keys() if not k.startswith('vision') and not k.startswith('language')]
+    # fmt: on
+    for key, value in state_dict.items():
+        if not isinstance(value, torch.Tensor):
+            try:
+                if value.dtype == jnp.bfloat16:
+                    value = jnp.array(value).astype(jnp.float32)
+                    value = np.array(value)
+                    state_dict[key] = torch.from_numpy(value).to(torch.bfloat16)
+                else:
+                    state_dict[key] = torch.from_numpy(value)
+            except Exception as initial_exception:
+                raise ValueError(f"Conversion failed from jax weights with {initial_exception}. Check your inputs.")
+    return state_dict
+
+
+def flatten_nested_dict(params, parent_key="", sep="/", precision: str = "float32"):
+    items = []
+
+    for k, v in params.items():
+        k = k.removeprefix("params/")
+        new_key = parent_key + sep + k if parent_key else k
+
+        if isinstance(v, collections.abc.MutableMapping):
+            items.extend(flatten_nested_dict(v, parent_key=new_key, sep=sep, precision=precision).items())
+        else:
+            if precision == "bfloat16":
+                try:
+                    v = v.view(ml_dtypes.bfloat16)
+                except Exception as initial_exception:
+                    raise ValueError(f"Conversion failed from bfloat16 with {initial_exception}, check your inputs.")
+            items.append((new_key, v))
+    return dict(items)
+
+
+@torch.no_grad()
+def convert_paligemma2_checkpoint(
+    checkpoint_path,
+    pytorch_dump_folder_path,
+    variant: str,
+    precision: str,
+    do_convert_weights=False,
+):
+    """
+    Read checkpoints from flax npz files, rename/reshape, send result to state dict and verify logits if needed.
+    """
+    config = get_paligemma2_config(variant, precision=precision)
+    if do_convert_weights:
+        tokenizer_id = "google/paligemma-3b-pt-224"  # same tokenizer as paligemma 1
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
+        image_token = AddedToken("<image>", normalized=False, special=True)
+        tokens_to_add = {"additional_special_tokens": [image_token]}
+        tokenizer.add_special_tokens(tokens_to_add)
+
+        # tokenizer.padding_side = 'right' # uncomment for testing purposes only.
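+
+        # The SigLIP image processor below is reused from PaliGemma 1: its resolution is overridden to the
+        # variant's image_size, and image_seq_length is set to num_image_tokens so the processor expands the
+        # `<image>` placeholder to the expected number of tokens.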
+ + image_processor = SiglipImageProcessor.from_pretrained("google/paligemma-3b-pt-224") + image_processor.size = {"width": config.vision_config.image_size, "height": config.vision_config.image_size} + image_processor.image_seq_length = config.vision_config.num_image_tokens + + processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer) + data = jnp.load(checkpoint_path) + state_dict = flatten_nested_dict(data, precision=precision) + del data + state_dict_transformers = slice_state_dict(state_dict, config) + del state_dict + del config.hidden_size # this key is unused + model = PaliGemmaForConditionalGeneration(config).to(device).eval() + model.load_state_dict(state_dict_transformers) + del state_dict_transformers + model.config.text_config._attn_implementation = "sdpa" + + # model expansion to get random embeds of image tokens + pad_shape = 64 # for performance reasons + pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data + mu = torch.mean(pre_expansion_embeddings, dim=0).float() + n = pre_expansion_embeddings.size()[0] + sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n + dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma) + + # We add an image token so we resize the model + model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape) + model.language_model.model.embed_tokens.weight.data[257152:] = torch.stack( + tuple( + (dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[257152:].shape[0])) + ), + dim=0, + ) + model.language_model.lm_head.weight.data[257152:] = torch.stack( + tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[257152:].shape[0]))), + dim=0, + ) + # convert to needed precision + + model.to(DTYPES[precision]) + model.save_pretrained(pytorch_dump_folder_path, safe_serialization=True) + processor.save_pretrained(pytorch_dump_folder_path) + + else: + processor = PaliGemmaProcessor.from_pretrained(pytorch_dump_folder_path, do_rescale=False) + model = ( + PaliGemmaForConditionalGeneration.from_pretrained(pytorch_dump_folder_path, attn_implementation="sdpa") + .to(device) + .eval() + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--checkpoint_path", + required=True, + type=str, + help="Path to the .npz checkpoint", + ) + + parser.add_argument( + "--pytorch_dump_folder_path", + required=True, + type=str, + help="Path to the output directory where model and processor will be saved.", + ) + + parser.add_argument( + "--precision", + choices=["float32", "bfloat16", "float16"], + type=str, + help="Precision identifier for model conversion - should match the base checkpoint precision.", + ) + + parser.add_argument( + "--variant", + default="2b-224", + choices=PALIGEMMA2_VARIANTS, + type=str, + help="String identifier of the paligemma2 variant to convert.", + ) + + parser.add_argument( + "--do_convert_weights", action="store_true", help="Whether or not to reload and convert the weights." 
+ ) + + args = parser.parse_args() + convert_paligemma2_checkpoint( + checkpoint_path=args.checkpoint_path, + pytorch_dump_folder_path=args.pytorch_dump_folder_path, + variant=args.variant, + precision=args.precision, + do_convert_weights=args.do_convert_weights, + ) diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 63cfbb6e6a5ff0..b4a231561ba791 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -21,7 +21,7 @@ import torch.utils.checkpoint from torch import nn -from ...cache_utils import Cache, StaticCache +from ...cache_utils import Cache, HybridCache, StaticCache from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -341,7 +341,14 @@ def tie_weights(self): return self.language_model.tie_weights() def _update_causal_mask( - self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False + self, + attention_mask, + token_type_ids, + past_key_values, + cache_position, + input_ids=None, + inputs_embeds=None, + is_training: bool = False, ): if self.config.text_config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: @@ -349,11 +356,13 @@ def _update_causal_mask( return None using_static_cache = isinstance(past_key_values, StaticCache) - dtype = inputs_embeds.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = inputs_embeds.shape[1] + min_dtype = torch.finfo(self.dtype).min + inputs_lead_dim = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + sequence_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() + elif isinstance(past_key_values, HybridCache): + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -366,7 +375,7 @@ def _update_causal_mask( return attention_mask causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device ) # Causal diagonal mask only if training, otherwise attend to the whole prefix. 
Training-specific attn for prefix is handled below
         if sequence_length != 1:
@@ -376,7 +385,7 @@ def _update_causal_mask(
                 causal_mask[:, :sequence_length] = 0.0
 
         causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
-        causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
+        causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
         if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
@@ -405,7 +414,7 @@ def get_image_features(self, pixel_values: torch.FloatTensor):
         image_outputs = self.vision_tower(pixel_values)
         selected_image_feature = image_outputs.last_hidden_state
         image_features = self.multi_modal_projector(selected_image_feature)
-        image_features = image_features / (self.config.hidden_size**0.5)
+        image_features = image_features / (self.config.text_config.hidden_size**0.5)
         return image_features
 
     @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
@@ -516,9 +525,8 @@ def forward(
                 labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
 
         causal_mask = self._update_causal_mask(
-            attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training
+            attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
         )
-
         outputs = self.language_model(
             attention_mask=causal_mask,
             position_ids=position_ids,
@@ -579,6 +587,7 @@ def prepare_inputs_for_generation(
         token_type_ids=None,
         use_cache=True,
         num_logits_to_keep=None,
+        labels=None,
         **kwargs,
     ):
         # Overwritten -- custom `position_ids` and `pixel_values` handling
@@ -598,10 +607,14 @@
         # position_ids in Paligemma are 1-indexed
         if model_inputs.get("position_ids") is not None:
             model_inputs["position_ids"] += 1
-
         # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
        # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
         if cache_position[0] == 0:
             model_inputs["pixel_values"] = pixel_values
-
+        is_training = token_type_ids is not None and labels is not None
+        if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
+            causal_mask = self._update_causal_mask(
+                attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
+            )
+            model_inputs["attention_mask"] = causal_mask
         return model_inputs
diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py
index dab0d60ad56d15..cb35aab66cba49 100644
--- a/src/transformers/models/paligemma/processing_paligemma.py
+++ b/src/transformers/models/paligemma/processing_paligemma.py
@@ -269,7 +269,7 @@ def __call__(
                 logger.warning(
                     "You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special "
                     "image tokens in the text, as many tokens as there are images per each text. It is recommended to "
-                    "add `<image>` tokens in the very beginning of your text and `<bos>` token after that. For this call, we will infer how many images "
+                    "add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images "
                     "each text has and add special tokens."
) @@ -304,9 +304,16 @@ def __call__( ] images = make_batched_images(images) else: - text = [sample.replace(IMAGE_TOKEN, IMAGE_TOKEN * self.image_seq_length) for sample in text] - input_strings = [f"{sample}\n" for sample in text] - + expanded_samples = [] + for sample in text: + expanded_sample = sample.replace(IMAGE_TOKEN, IMAGE_TOKEN * self.image_seq_length) + bos_rfind_index = expanded_sample.rfind(IMAGE_TOKEN) + bos_index = bos_rfind_index + len(IMAGE_TOKEN) if bos_rfind_index != -1 else 0 + expanded_sample = ( + expanded_sample[:bos_index] + self.tokenizer.bos_token + expanded_sample[bos_index:] + ) + expanded_samples.append(expanded_sample) + input_strings = [f"{sample}\n" for sample in expanded_samples] pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"] # max_length has to account for the image tokens diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py index ce44436a20ad2c..5ffea7ffe55087 100644 --- a/tests/models/paligemma/test_modeling_paligemma.py +++ b/tests/models/paligemma/test_modeling_paligemma.py @@ -347,6 +347,11 @@ def test_flash_attn_2_fp32_ln(self): def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): pass + # TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow + @unittest.skip("PaliGemma is not compatible with end-to-end generation compilation") + def test_generate_compile_fullgraph(self): + pass + @slow @require_torch diff --git a/tests/models/paligemma/test_processor_paligemma.py b/tests/models/paligemma/test_processor_paligemma.py index 245aff594125cf..e301bf304b1ed8 100644 --- a/tests/models/paligemma/test_processor_paligemma.py +++ b/tests/models/paligemma/test_processor_paligemma.py @@ -63,8 +63,8 @@ def test_text_with_image_tokens(self): tokenizer = self.get_component("tokenizer") processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) - text_multi_images = "Dummy text!" - text_single_image = "Dummy text!" + text_multi_images = "Dummy text!" + text_single_image = "Dummy text!" text_no_image = "Dummy text!" image = self.prepare_image_inputs() @@ -85,7 +85,7 @@ def test_text_with_image_tokens(self): self.assertTrue(out_noimage[k].tolist() == out_multiimages[k].tolist()) text_batched = ["Dummy text!", "Dummy text!"] - text_batched_with_image = ["Dummy text!", "Dummy text!"] + text_batched_with_image = ["Dummy text!", "Dummy text!"] out_images = processor(text=text_batched_with_image, images=[image, image], return_tensors="np") out_noimage_nested = processor(text=text_batched, images=[[image], [image]], return_tensors="np") out_noimage = processor(text=text_batched, images=[image, image], return_tensors="np")