diff --git a/docs/source/en/model_doc/chameleon.md b/docs/source/en/model_doc/chameleon.md index 323b8381316..8aa90cf71af 100644 --- a/docs/source/en/model_doc/chameleon.md +++ b/docs/source/en/model_doc/chameleon.md @@ -50,6 +50,8 @@ The original code can be found [here](https://github.com/facebookresearch/chamel - We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to set `processor.tokenizer.padding_side = "left"` before generating. +- When generating images, we advice users to load the model in `bfloat16` for better results. Simply make sure to set `torch_dtype=torch.bfloat16` when loading the model. + - Note that Chameleon was tuned for safety alignment. If the model is refusing to answer, consider asking a more concrete question, instead of an open question. - Chameleon generates in chat format which means that the generated text will always be the "assistant's turn". You can enable a text completion generation by passing `return_for_text_completion=True` when calling the processor. @@ -57,6 +59,9 @@ The original code can be found [here](https://github.com/facebookresearch/chamel > [!NOTE] > Chameleon implementation in Transformers uses a special image token to indicate where to merge image embeddings. For special image token we didn't add a new one but used one of the reserved tokens: ``. You have to add `` to your prompt in the place where the image should be embedded for correct generation. +> [!NOTE] +> The official model checkpoint currently only supports text generation. To generate images and interleaved text-image responses, you can use finetuned versions such as [Anole](https://arxiv.org/abs/2407.06135). Note however that Anole has a bias for "empty" or background patches, so it is recommended to use sampling when generating images (i.e. setting `do_sample=True` during generation) to reduce the likelihood of generating a blank image. + ## Usage example ### Single image inference @@ -117,13 +122,154 @@ prompts = [ # We can simply feed images in the order they have to be used in the text prompt # Each "" token uses one image leaving the next for the subsequent "" tokens -inputs = processor(text=prompts, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(device="cuda", dtype=torch.bfloat16) +inputs = processor( + text=prompts, + images=[image_stop, image_cats, image_snowman], + padding=True, + return_tensors="pt", +).to(device="cuda", dtype=torch.bfloat16) # Generate generate_ids = model.generate(**inputs, max_new_tokens=50) processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) ``` +### Text to image generation + +Chameleon can also generate images. However, the official model checkpoint currently only supports text generation. We need to use finetuned versions such as [Anole](https://arxiv.org/abs/2407.06135) to do image generation. Here is how you can do it: + +```python +import torch +from transformers import ChameleonProcessor, ChameleonForConditionalGeneration + +processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf") +model = ChameleonForConditionalGeneration.from_pretrained( + "leloy/Anole-7b-v0.1-hf", + device_map="auto", + torch_dtype=torch.bfloat16, +) + +# Prepare a prompt +prompt = "Generate an image of a snowman." 
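+# Note: unlike the image-conditioned prompts in the earlier examples, no image placeholder token
+# is needed here, because the prompt does not contain any input image to embed.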
+ +# Preprocess the prompt +inputs = processor(prompt, padding=True, return_tensors="pt").to(model.device, dtype=model.dtype) + +# Generate discrete image tokens +generate_ids = model.generate( + **inputs, + multimodal_generation_mode="image-only", + # Note: We need to set `max_new_tokens` to 1026 since the model generates the `image_start_token` marker token first, then 1024 image tokens, and finally the `image_end_token` marker token. + max_new_tokens=1026, + # This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image. + do_sample=True, +) + +# Only keep the tokens from the response +response_ids = generate_ids[:, inputs["input_ids"].shape[-1]:] + +# Decode the generated image tokens +pixel_values = model.decode_image_tokens(response_ids[:, 1:-1]) +images = processor.postprocess_pixel_values(pixel_values) + +# Save the image +images[0].save("snowman.png") +``` + +### Text-image to image generation + +We can also interleave text and images in the prompt to generate images. Here is how you can do it: + +```python +import requests + +import torch +from PIL import Image +from transformers import ChameleonProcessor, ChameleonForConditionalGeneration +from transformers.image_transforms import to_pil_image + +processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf") +model = ChameleonForConditionalGeneration.from_pretrained( + "leloy/Anole-7b-v0.1-hf", + device_map="auto", + torch_dtype=torch.bfloat16, +) + +# Get image of a snowman +url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg" +image_snowman = Image.open(requests.get(url, stream=True).raw) + +# Prepare a prompt +prompt = "Generate a variation of this image." + +# Preprocess the prompt +inputs = processor( + prompt, + images=[image_snowman], + padding=True, + return_tensors="pt", +).to(model.device, dtype=model.dtype) + +# Generate discrete image tokens +generate_ids = model.generate( + **inputs, + multimodal_generation_mode="image-only", + # This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image. + do_sample=True, +) + +# Only keep the tokens from the response +response_ids = generate_ids[:, inputs["input_ids"].shape[-1]:] + +# The generated image tokens are wrapped by the `image_start_token` and `image_end_token` tokens. We need to remove them before decoding the image tokens. +image_token_ids = response_ids[:, 1:-1] + +# Decode the generated image tokens +pixel_values = model.decode_image_tokens(image_token_ids) +pixel_values = processor.postprocess_pixel_values(pixel_values) + +# Save the image +image = to_pil_image(pixel_values[0].detach().cpu()) +image.save("snowman.png") +``` + +### Interleaved text-image generation + +We can also generate interleaved text and images in the output. Here is how you can do it: + +```python +import torch +from transformers import ChameleonProcessor, ChameleonForConditionalGeneration + +processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf") +model = ChameleonForConditionalGeneration.from_pretrained( + "leloy/Anole-7b-v0.1-hf", + device_map="auto", + torch_dtype=torch.bfloat16, +) + +# Prepare a prompt +prompt = "Can you draw a snowman and explain how to build one?" 
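+# Note: the model decides on its own where to place images in its response. Each generated image
+# segment is wrapped in `image_start_token`/`image_end_token` markers, which is what allows the
+# response to be split into text and image chunks afterwards.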
+ +# Preprocess the prompt +inputs = processor(prompt, padding=True, return_tensors="pt").to(model.device, dtype=model.dtype) + +# Generate interleaved text and discrete image tokens +generate_ids = model.generate( + **inputs, + multimodal_generation_mode="interleaved-text-image", + # Note: We will need a larger `max_new_tokens` value since we are generating both text and image tokens. + max_new_tokens=4096, + # This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image. + do_sample=True, +) + +# Only keep the tokens from the response +response_ids = generate_ids[:, inputs["input_ids"].shape[-1]:] +``` + +From here, you can split the response tokens into text and image token segments, decode them separately as shown in the previous examples, and finally render the resulting text and images together. You can also use [MMSG](https://github.com/leloykun/mmsg) to do this more easily. + ## Model optimization ### Quantization using Bitsandbytes diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index e9ba4560682..94018fa882e 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1750,7 +1750,38 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores_processed -class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor): +class SuppressTokensInIndexRangeLogitsProcessor(LogitsProcessor): + r""" + [`SuppressTokensInIndexRangeLogitsProcessor`] supresses a list of tokens from `start_index` to `end_index` (exclusive) + + Args: + suppress_tokens (`List[int]`): + List of token ids to suppress during generation. + start_index (`int`): + The index at which to start suppressing tokens. + end_index (`int`, *optional*): + The index at which to end suppressing tokens. If `None`, it will suppress tokens indefinitely. + device (`str`, *optional*, defaults to `"cpu"`): + The device to allocate the tensors. + """ + + def __init__( + self, suppress_tokens: List[int], start_index: int, end_index: Optional[int] = None, device: str = "cpu" + ): + self.suppress_tokens = torch.tensor(suppress_tokens, device=device) + self.start_index = start_index + self.end_index = end_index if end_index is not None else math.inf + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + current_index = input_ids.shape[1] + if self.start_index > current_index or current_index > self.end_index: + return scores + suppress_tokens_mask = torch.zeros_like(scores, dtype=torch.bool) + suppress_tokens_mask[:, self.suppress_tokens] = True + return scores.masked_fill(suppress_tokens_mask, torch.finfo(scores.dtype).min) + + +class SuppressTokensAtBeginLogitsProcessor(SuppressTokensInIndexRangeLogitsProcessor): r""" [`SuppressTokensAtBeginLogitsProcessor`] supresses a list of tokens as soon as the `generate` function starts generating using `begin_index` tokens. 
This should ensure that the tokens defined by `begin_suppress_tokens` are @@ -1786,24 +1817,17 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor): """ def __init__(self, begin_suppress_tokens, begin_index, device: str = "cpu"): - self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens), device=device) + super().__init__(begin_suppress_tokens, begin_index, begin_index + 1, device=device) self.begin_index = begin_index def set_begin_index(self, begin_index): + self.start_index = begin_index + self.end_index = begin_index + 1 + # Keeping this here for backwards compatibility self.begin_index = begin_index - @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens) - scores_processed = scores - if input_ids.shape[-1] == self.begin_index: - scores_processed = torch.where(suppress_token_mask, -float("inf"), scores) - - return scores_processed - -class SuppressTokensLogitsProcessor(LogitsProcessor): +class SuppressTokensLogitsProcessor(SuppressTokensInIndexRangeLogitsProcessor): r""" This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so that they are not generated. Originally created for @@ -1833,14 +1857,7 @@ class SuppressTokensLogitsProcessor(LogitsProcessor): """ def __init__(self, suppress_tokens, device: str = "cpu"): - self.suppress_tokens = torch.tensor(list(suppress_tokens), device=device) - - @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens) - scores = torch.where(suppress_token_mask, -float("inf"), scores) - return scores + super().__init__(suppress_tokens, 0, device=device) class WhisperTimeStampLogitsProcessor(LogitsProcessor): @@ -2449,3 +2466,108 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to scores_processed[b_idx, greenlist_ids] = scores_processed[b_idx, greenlist_ids] + self.bias return scores_processed + + +class AllowOnlyTokensAtRelativeOffsetLogitsProcessor(LogitsProcessor): + r""" + [`AllowOnlyTokensAtRelativeOffsetLogitsProcessor`] suppresses the logits of tokens aside from a specific set of tokens + that can be generated at a relative offset from a trigger token (e.g. begin image token). If `exclusive` is set to + `True`, the set of tokens allowed at this offset will not be allowed anywhere else. This is useful for enforcing + multimodal generation constraints with begin and end marker tokens. + + Originally created for [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon). + + Args: + trigger_token_id (`int`): + The token id that triggers the offset check. + allowed_token_ids (`List[int]`): + The list of token ids that are allowed at the specified offset. + offset (`int`): + The relative offset from the trigger token. + exclusive (`bool`, *optional*, defaults to `False`): + If `True`, the set of tokens allowed at this offset will not be allowed anywhere else. + device (`str`, *optional*, defaults to `cpu`): + The device to allocate the util tensor on. 
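+
+    For example, Chameleon uses this processor with the `begin-of-image` token as the trigger, the
+    `end-of-image` token as the only allowed token, and an offset of `image_seq_length + 1`, so that
+    every image segment is closed right after its `image_seq_length` image tokens.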
+ """ + + def __init__( + self, + trigger_token_id: int, + allowed_token_ids: List[int], + offset: int, + exclusive: bool = False, + device: str = "cpu", + ): + self.trigger_token_id = trigger_token_id + self.allowed_token_ids = torch.tensor(allowed_token_ids, device=device) + self.offset = offset + self.exclusive = exclusive + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if input_ids.shape[1] < self.offset and not self.exclusive: + return scores + + disallowed_tokens_mask = torch.ones_like(scores, dtype=torch.bool) + disallowed_tokens_mask[:, self.allowed_token_ids] = False + + if input_ids.shape[1] < self.offset: + return scores.masked_fill(~disallowed_tokens_mask, torch.finfo(scores.dtype).min) + + trigger_positions = (input_ids[:, -self.offset] == self.trigger_token_id).unsqueeze(-1) + + if self.exclusive: + return scores.masked_fill(~(disallowed_tokens_mask ^ trigger_positions), torch.finfo(scores.dtype).min) + return scores.masked_fill(disallowed_tokens_mask & trigger_positions, torch.finfo(scores.dtype).min) + + +class AllowOnlyTokensInRelativeWindowLogitsProcessor(LogitsProcessor): + r""" + [`AllowOnlyTokensInRelativeWindowLogitsProcessor`] suppresses the logits of tokens aside from a specific set of tokens + that can be generated at a relative window from a trigger token (e.g. begin image token). If `exclusive` is set to + `True`, the set of tokens allowed at this window will not be allowed anywhere else. This is useful for enforcing + multimodal generation constraints. + + Originally created for [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon). + + Args: + trigger_token_id (`int`): + The token id that triggers the window check. + allowed_token_ids (`List[int]`): + The list of token ids that are allowed at the specified relative window. + window_width (`int`): + The window_width of the window from the trigger token. + exclusive (`bool`, *optional*, defaults to `False`): + If `True`, the set of tokens allowed at this window will not be allowed anywhere else. + device (`str`, *optional*, defaults to `cpu`): + The device to allocate the util tensor on. 
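+
+    For example, Chameleon uses this processor with the `begin-of-image` token as the trigger, the
+    image token ids as the allowed tokens, and a window width of `image_seq_length`, so that only
+    image tokens can be generated inside an image segment and, with `exclusive=True`, nowhere else.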
+ """ + + def __init__( + self, + trigger_token_id: int, + allowed_token_ids: List[int], + window_width: int, + exclusive: bool = False, + device: str = "cpu", + ): + self.trigger_token_id = trigger_token_id + self.allowed_token_ids = torch.tensor(allowed_token_ids, device=device).unsqueeze(0) + self.window_width = window_width + self.exclusive = exclusive + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + window_width = min(self.window_width, input_ids.shape[1]) + trigger_positions = (input_ids[:, -window_width:] == self.trigger_token_id).any(dim=1).unsqueeze(-1) + + disallowed_tokens_mask = torch.ones_like(scores, dtype=torch.bool) + disallowed_tokens_mask[:, self.allowed_token_ids] = False + + if self.exclusive: + return scores.masked_fill( + ~(disallowed_tokens_mask ^ trigger_positions), + torch.finfo(scores.dtype).min, + ) + return scores.masked_fill( + disallowed_tokens_mask & trigger_positions, + torch.finfo(scores.dtype).min, + ) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e3c70ac1098..a1cf3f04c65 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1276,13 +1276,13 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de def _prepare_generated_length( self, - generation_config, - has_default_max_length, - has_default_min_length, - model_input_name, - input_ids_length, - inputs_tensor, - ): + generation_config: GenerationConfig, + has_default_max_length: bool, + has_default_min_length: bool, + model_input_name: str, + input_ids_length: int, + inputs_tensor: torch.Tensor, + ) -> GenerationConfig: """Prepared max and min length in generaion configs to avoid clashes between similar attributes""" if generation_config.max_new_tokens is not None: diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index b094f50b5e9..4fe6649309e 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -109,10 +109,7 @@ ("canine", ("CanineTokenizer", None)), ( "chameleon", - ( - "LlamaTokenizer" if is_sentencepiece_available() else None, - "LlamaTokenizerFast" if is_tokenizers_available() else None, - ), + (None, "LlamaTokenizerFast" if is_tokenizers_available() else None), ), ("chinese_clip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ( diff --git a/src/transformers/models/chameleon/configuration_chameleon.py b/src/transformers/models/chameleon/configuration_chameleon.py index 67de37f2d01..55f93de0d95 100644 --- a/src/transformers/models/chameleon/configuration_chameleon.py +++ b/src/transformers/models/chameleon/configuration_chameleon.py @@ -45,6 +45,8 @@ class ChameleonVQVAEConfig(PretrainedConfig): Resolution of the input images. in_channels (`int`, *optional*, defaults to 3): Number of input channels. + out_channels (`int`, *optional*, defaults to 3): + Number of output channels. base_channels (`int`, *optional*, defaults to 128): Base channel count. 
channel_multiplier (`List[int]`, *optional*, defaults to `[1, 1, 2, 2, 4]`): @@ -71,6 +73,7 @@ def __init__( latent_channels: int = 256, resolution: int = 512, in_channels: int = 3, + out_channels: int = 3, base_channels: int = 128, channel_multiplier: List[int] = [1, 1, 2, 2, 4], num_res_blocks: int = 2, @@ -87,6 +90,7 @@ def __init__( self.latent_channels = latent_channels self.resolution = resolution self.in_channels = in_channels + self.out_channels = out_channels self.base_channels = base_channels self.channel_multiplier = channel_multiplier self.num_res_blocks = num_res_blocks @@ -169,6 +173,12 @@ class ChameleonConfig(PretrainedConfig): ChameleonVQConfig instance containing the configuration for the VQ-VAE model. vocabulary_map (`dict`, *optional*): A dictionary containing the vocabulary map from the tokenizer. Used to obtain tokens from the image inputs. + image_token_id (`int`, *optional*, defaults to 8711): + The ID for the token used to represent the image in the input sequence. + boi_token_id (`int`, *optional*, defaults to 8197): + Beginning of image token stream id. + eoi_token_id (`int`, *optional*, defaults to 8196): + End of image token stream id. mlp_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. @@ -214,6 +224,9 @@ def __init__( swin_norm=False, vq_config=None, vocabulary_map=None, + image_token_id=8711, + boi_token_id=8197, + eoi_token_id=8196, mlp_bias=False, **kwargs, ): @@ -245,6 +258,9 @@ def __init__( self.vq_config = ChameleonVQVAEConfig(**vq_config) self.vocabulary_map = vocabulary_map + self.image_token_id = image_token_id + self.boi_token_id = boi_token_id + self.eoi_token_id = eoi_token_id super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py b/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py index 1aebeb0f0bb..4203f1a7c4e 100644 --- a/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py +++ b/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py @@ -24,7 +24,7 @@ from transformers import ( ChameleonConfig, - ChameleonForCausalLM, + ChameleonForConditionalGeneration, ChameleonImageProcessor, ChameleonProcessor, ) @@ -49,10 +49,10 @@ Thereafter, models can be loaded via: ```py -from transformers import ChameleonForCausalLM, LlamaTokenizer +from transformers import ChameleonForConditionalGeneration, LlamaTokenizerFast -model = ChameleonForCausalLM.from_pretrained("/output/path") -tokenizer = LlamaTokenizer.from_pretrained("/output/path") +model = ChameleonForConditionalGeneration.from_pretrained("/output/path") +tokenizer = LlamaTokenizerFast.from_pretrained("/output/path") ``` Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions @@ -81,7 +81,7 @@ def write_json(text, path): json.dump(text, f) -def write_model(model_path, input_base_path, model_size, chameleon_version=1): +def write_model(model_path, input_base_path, model_size, chameleon_version=1, vqvae_path=None): os.makedirs(model_path, exist_ok=True) input_model_path = os.path.join(input_base_path, "models", model_size.lower()) params_path = os.path.join(input_model_path, "params.json") @@ -316,8 +316,6 @@ def permute(w, n_heads, dim1=dim, dim2=dim): vqgan_path = os.path.join(input_base_path, "tokenizer/vqgan.ckpt") vqgan_state_dict = torch.load(vqgan_path, map_location="cpu")["state_dict"] for k, v in vqgan_state_dict.items(): - if 
"decoder" in k: - continue # we dont do image generation yet state_dict[f"model.vqmodel.{k}"] = v # Write configs @@ -370,13 +368,19 @@ def permute(w, n_heads, dim1=dim, dim2=dim): swin_norm=swin_norm, vq_config=vq_config, vocabulary_map=vocabulary_map, + image_token_id=vocabulary_map[""], + boi_token_id=vocabulary_map[""], + eoi_token_id=vocabulary_map[""], ) with init_empty_weights(): - model = ChameleonForCausalLM(config) + model = ChameleonForConditionalGeneration(config) model.load_state_dict(state_dict, assign=True, strict=False) model.save_pretrained(model_path, safe_serialization=True) + if vqvae_path is not None: + model.model.vqmodel.save_pretrained(vqvae_path, safe_serialization=True) + # Load and save the processor tokenizer = LlamaTokenizerFast( tokenizer_file=os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), legacy=False @@ -397,7 +401,7 @@ def permute(w, n_heads, dim1=dim, dim2=dim): # taken from https://github.com/facebookresearch/chameleon/blob/7a72f40aa5f462965c8374f25257f55b65b25ff4/data/prompts_for_human_evaluations.jsonl print("Loading the checkpoint in a Chameleon model...") print("*" * 100) - model = ChameleonForCausalLM.from_pretrained( + model = ChameleonForConditionalGeneration.from_pretrained( model_path, attn_implementation="eager", torch_dtype=torch.bfloat16, device_map="auto" ) processor = ChameleonProcessor.from_pretrained(model_path) @@ -463,12 +467,18 @@ def main(): type=int, help="Version of the Chameleon model to convert", ) + parser.add_argument( + "--vqvae_path", + default=None, + help="Location to write VQ-VAE model", + ) args = parser.parse_args() write_model( model_path=args.output_dir, input_base_path=args.input_dir, model_size=args.model_size, chameleon_version=args.chameleon_version, + vqvae_path=args.vqvae_path, ) diff --git a/src/transformers/models/chameleon/image_processing_chameleon.py b/src/transformers/models/chameleon/image_processing_chameleon.py index a23fdbed028..0a619d9294d 100644 --- a/src/transformers/models/chameleon/image_processing_chameleon.py +++ b/src/transformers/models/chameleon/image_processing_chameleon.py @@ -14,7 +14,7 @@ # limitations under the License. """Image processor class for Chameleon.""" -from typing import Dict, List, Optional, Union +from typing import Dict, Iterable, List, Optional, Union import numpy as np @@ -35,11 +35,13 @@ valid_images, validate_preprocess_arguments, ) -from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging +from ...utils import TensorType, filter_out_non_signature_kwargs, is_torch_available, is_vision_available, logging logger = logging.get_logger(__name__) +if is_torch_available(): + import torch if is_vision_available(): import PIL @@ -209,7 +211,8 @@ def preprocess( return_tensors: Optional[Union[str, TensorType]] = None, data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, - ) -> PIL.Image.Image: + **kwargs, + ) -> BatchFeature: """ Preprocess an image or batch of images. 
@@ -368,3 +371,91 @@ def blend_rgba(self, image: ImageInput) -> ImageInput: alpha = img_rgba[:, :, 3] / 255.0 img_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[:, :, np.newaxis] * img_rgba[:, :, :3] return PIL.Image.fromarray(img_rgb.astype("uint8"), "RGB") + + def postprocess( + self, + pixel_values: "torch.Tensor", + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_unnormalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + ) -> "torch.Tensor": + """ + Postprocess a batch of pixel values to images. + + Args: + pixel_values (`torch.Tensor` of shape `(batch_size, num_channels, image_size, image_size)` or `(num_channels, image_size, image_size)`): + A batch or single tensor of pixel values to postprocess. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_unnormalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to unnormalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for unnormalization. Only has an effect if `do_unnormalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for unnormalization. Only has an effect if `do_unnormalize` is set to + `True`. + + Returns: + `torch.Tensor`: A batch or a single tensor of pixel values. + """ + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = 1.0 / self.rescale_factor if rescale_factor is None else rescale_factor + do_unnormalize = do_unnormalize if do_unnormalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + if do_unnormalize: + pixel_values = self.unnormalize(pixel_values, mean=image_mean, std=image_std) + + if do_rescale: + pixel_values *= rescale_factor + + return torch.clip(pixel_values, 0, 255).to(dtype=torch.uint8) + + def unnormalize( + self, + pixel_values: "torch.Tensor", + mean: Union[float, Iterable[float]], + std: Union[float, Iterable[float]], + ) -> "torch.Tensor": + """ + Unnormalizes `pixel_values` using the mean and standard deviation specified by `mean` and `std`. + + pixel_values = (pixel_values * std) + mean + + Args: + pixel_values (`torch.Tensor` of shape `(batch_size, num_channels, image_size, image_size)` or `(num_channels, image_size, image_size)`): + Batch of pixel values to postprocess. + mean (`float` or `Iterable[float]`): + The mean to use for unnormalization. + std (`float` or `Iterable[float]`): + The standard deviation to use for unnormalization. + data_format (`ChannelDimension`, *optional*): + The channel dimension format of the output image. If unset, will use the inferred format from the input. 
+ """ + channel_axis = 1 if pixel_values.ndim == 4 else 0 + num_channels = pixel_values.shape[channel_axis] + + if isinstance(mean, Iterable): + if len(mean) != num_channels: + raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}") + else: + mean = [mean] * num_channels + mean = torch.tensor(mean, dtype=pixel_values.dtype, device=pixel_values.device) + + if isinstance(std, Iterable): + if len(std) != num_channels: + raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}") + else: + std = [std] * num_channels + std = torch.tensor(std, dtype=pixel_values.dtype, device=pixel_values.device) + + if pixel_values.ndim == 4: + pixel_values = (pixel_values * std.view(1, -1, 1, 1)) + mean.view(1, -1, 1, 1) + else: + pixel_values = (pixel_values * std.view(-1, 1, 1)) + mean.view(-1, 1, 1) + return pixel_values diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 955c5ed6839..062808e1aca 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -15,8 +15,9 @@ """PyTorch Chameleon model.""" import math +import warnings from functools import cached_property -from typing import Optional, Tuple, Union +from typing import Dict, Literal, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -26,6 +27,16 @@ from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache +from ...generation.configuration_utils import GenerationConfig +from ...generation.logits_process import ( + AllowOnlyTokensAtRelativeOffsetLogitsProcessor, + AllowOnlyTokensInRelativeWindowLogitsProcessor, + LogitsProcessorList, + SuppressTokensAtBeginLogitsProcessor, + SuppressTokensInIndexRangeLogitsProcessor, + SuppressTokensLogitsProcessor, +) +from ...generation.utils import GenerateOutput from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( @@ -803,12 +814,14 @@ def __init__(self, config): super().__init__() self.num_embeddings = config.num_embeddings self.embedding_dim = config.embed_dim + self.quant_state_dims = [config.resolution // 2 ** (len(config.channel_multiplier) - 1)] * 2 self.beta = getattr(config, "beta", 0.25) self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) self.re_embed = self.num_embeddings - def forward(self, hidden_state: torch.Tensor): + def forward(self, hidden_state: torch.FloatTensor): + batch_size = hidden_state.shape[0] hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous() hidden_state_flattened = hidden_state.view(-1, self.embedding_dim) @@ -833,7 +846,30 @@ def forward(self, hidden_state: torch.Tensor): # reshape back to match original input shape hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous() - return hidden_state_quant, loss, min_encoding_indices + return hidden_state_quant, loss, min_encoding_indices.view(batch_size, -1) + + def get_codebook_entry(self, image_tokens: torch.LongTensor) -> torch.FloatTensor: + batch_size = image_tokens.shape[0] + emb_dim: int = self.embedding.weight.shape[-1] + # get quantized latent vectors + hidden_state_quant = self.embedding(image_tokens) + + # reshape back to match original input shape + hidden_state_quant = hidden_state_quant.view((batch_size, *self.quant_state_dims, emb_dim)) + hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 
2).contiguous() + + return hidden_state_quant + + +class ChameleonVQVAEDecoderConvUpsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, hidden_states): + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = self.conv(hidden_states) + return hidden_states class ChameleonVQVAEEncoderConvDownsample(nn.Module): @@ -848,7 +884,7 @@ def forward(self, hidden_states): return hidden_states -class ChameleonVQVAEEncoderResnetBlock(nn.Module): +class ChameleonVQVAEResnetBlock(nn.Module): def __init__( self, config, @@ -892,7 +928,7 @@ def forward(self, hidden_states): return residual + hidden_states -class ChameleonVQVAEEncoderAttnBlock(nn.Module): +class ChameleonVQVAEAttnBlock(nn.Module): def __init__(self, in_channels): super().__init__() self.in_channels = in_channels @@ -903,7 +939,7 @@ def __init__(self, in_channels): self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: residual = hidden_states hidden_states = self.norm(hidden_states) query_states = self.q(hidden_states) @@ -953,7 +989,7 @@ def __init__(self, config): block_out = base_channels * channel_multiplier[i_level] for i_block in range(self.num_res_blocks): block.append( - ChameleonVQVAEEncoderResnetBlock( + ChameleonVQVAEResnetBlock( config=config, in_channels=block_in, out_channels=block_out, @@ -965,7 +1001,7 @@ def __init__(self, config): and curr_res in config.attn_resolutions and config.attn_type == "vanilla" ): - attn.append(ChameleonVQVAEEncoderAttnBlock(block_in)) + attn.append(ChameleonVQVAEAttnBlock(block_in)) down = nn.Module() down.block = block @@ -976,13 +1012,13 @@ def __init__(self, config): self.down.append(down) self.mid = nn.Module() - self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock( + self.mid.block_1 = ChameleonVQVAEResnetBlock( config=config, in_channels=block_in, out_channels=block_in, ) - self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(block_in) if config.attn_type == "vanilla" else nn.Identity() - self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock( + self.mid.attn_1 = ChameleonVQVAEAttnBlock(block_in) if config.attn_type == "vanilla" else nn.Identity() + self.mid.block_2 = ChameleonVQVAEResnetBlock( config=config, in_channels=block_in, out_channels=block_in, @@ -997,7 +1033,7 @@ def __init__(self, config): padding=1, ) - def forward(self, pixel_values: torch.LongTensor): + def forward(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor: # downsampling hidden_states = [self.conv_in(pixel_values)] for i_level in range(self.num_resolutions): @@ -1024,6 +1060,95 @@ def forward(self, pixel_values: torch.LongTensor): return last_hidden_state +class ChameleonVQVAEDecoder(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + base_channels = config.base_channels + resolution = config.resolution + latent_channels = config.latent_channels + out_channels = config.out_channels + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = base_channels * config.channel_multiplier[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + 
self.z_shape = (1, latent_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = torch.nn.Conv2d(latent_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ChameleonVQVAEResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_in, + ) + self.mid.attn_1 = ChameleonVQVAEAttnBlock(block_in) if config.attn_type == "vanilla" else nn.Identity() + self.mid.block_2 = ChameleonVQVAEResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_in, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = base_channels * config.channel_multiplier[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ChameleonVQVAEResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_out, + ) + ) + block_in = block_out + if ( + config.attn_resolutions is not None + and curr_res in config.attn_resolutions + and config.attn_type == "vanilla" + ): + attn.append(ChameleonVQVAEAttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = ChameleonVQVAEDecoderConvUpsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, hidden_state: torch.FloatTensor) -> torch.FloatTensor: + hidden_state = self.conv_in(hidden_state) + + # middle + hidden_state = self.mid.block_1(hidden_state) + hidden_state = self.mid.attn_1(hidden_state) + hidden_state = self.mid.block_2(hidden_state) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + hidden_state = self.up[i_level].block[i_block](hidden_state) + if len(self.up[i_level].attn) > 0: + hidden_state = self.up[i_level].attn[i_block](hidden_state) + if i_level != 0: + hidden_state = self.up[i_level].upsample(hidden_state) + + hidden_state = self.norm_out(hidden_state) + hidden_state *= torch.sigmoid(hidden_state) + hidden_state = self.conv_out(hidden_state) + return hidden_state + + CHAMELEON_VQ_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -1068,33 +1193,79 @@ def __init__(self, config: ChameleonVQVAEConfig): super().__init__(config) self.encoder = ChameleonVQVAEEncoder(config) + self.decoder = ChameleonVQVAEDecoder(config) self.quantize = ChameleonVQVAEVectorQuantizer(config) self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1) self.eval() # Chameleon's VQ model is frozen - def encode(self, pixel_values: torch.LongTensor): + def encode(self, pixel_values: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]: + """ + Encodes pixel values into quantized tokens. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. 
+ + Returns: + quant (`torch.FloatTensor` of shape `(batch_size, embed_dim, quantize.quant_state_dims[0], quantize.quant_state_dims[1])`): + Embeddings of quantized tokens. + emb_loss (`torch.FloatTensor`): + Embedding loss. + indices (`torch.LongTensor` of shape `(batch_size, quantize.quant_state_dims[0] * quantize.quant_state_dims[1])`): + Token IDs + """ hidden_states = self.encoder(pixel_values) hidden_states = self.quant_conv(hidden_states) quant, emb_loss, indices = self.quantize(hidden_states) return quant, emb_loss, indices + def decode(self, image_tokens: torch.LongTensor) -> torch.FloatTensor: + """ + Decodes quantized token IDs into pixel values. + + Args: + image_tokens (`torch.LongTensor` of shape `(batch_size, quantize.quant_state_dims[0] * quantize.quant_state_dims[1])`): + Batch of token IDs. + + Returns: + (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + Pixel values decoded from the token IDs. + """ + if image_tokens.shape[1] != self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]: + raise ValueError( + f"Expected `image_tokens` to have shape `(batch_size, {self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]})`, " + f"but got shape `{image_tokens.shape}`." + ) + codebook_entry = self.quantize.get_codebook_entry(image_tokens) + hidden_states = self.post_quant_conv(codebook_entry) + pixel_values = self.decoder(hidden_states) + return pixel_values + class ChameleonImageVocabularyMapping: """ A class for mapping discrete image tokens from VQGAN to BPE tokens. """ - def __init__(self, vocab_map): + def __init__( + self, + vocab_map: Dict[str, int], + image_token_id: int, + boi_token_id: int, + eoi_token_id: int, + ): self.vocab_map = vocab_map - self.image_token_id = vocab_map.get("") + self.image_token_id = image_token_id + self.boi_token_id = boi_token_id + self.eoi_token_id = eoi_token_id @cached_property def val2name(self): return {v: k for k, v in self.vocab_map.items()} @cached_property - def image_tokens(self): + def image_token_ids(self): return sorted([val for name, val in self.vocab_map.items() if name.startswith("IMGIMG")]) @cached_property @@ -1104,15 +1275,18 @@ def bpe2img(self): def remap(old_name: str) -> str: return "".join(img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1]) - return {tok: int(remap(self.val2name[tok])) for tok in self.image_tokens} + return {tok: int(remap(self.val2name[tok])) for tok in self.image_token_ids} @cached_property def img2bpe(self): return {v: k for k, v in self.bpe2img.items()} @cached_property - def bpe2img_search_tensors(self): - return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(sorted(self.bpe2img.values())) + def bpe2img_mapping_tensor(self): + mapping = torch.zeros(max(self.bpe2img.keys()) + 1, dtype=torch.int) + for k, v in self.bpe2img.items(): + mapping[k] = v + return mapping @cached_property def img2bpe_mapping_tensor(self): @@ -1121,11 +1295,6 @@ def img2bpe_mapping_tensor(self): mapping[k] = v return mapping - def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor: - device = img_batch.device - img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")] - return img_tokens.to(device) - CHAMELEON_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the @@ -1152,7 +1321,7 @@ class ChameleonPreTrainedModel(PreTrainedModel): config_class = ChameleonConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"] + _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer", "ChameleonVQVAE"] _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn_2 = True _supports_sdpa = True @@ -1262,7 +1431,22 @@ def __init__(self, config: ChameleonConfig): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map) + self.vocabulary_mapping = ChameleonImageVocabularyMapping( + config.vocabulary_map, + config.image_token_id, + config.boi_token_id, + config.eoi_token_id, + ) + self.register_buffer( + "img2bpe_mapping_tensor", + self.vocabulary_mapping.img2bpe_mapping_tensor, + persistent=False, + ) + self.register_buffer( + "bpe2img_mapping_tensor", + self.vocabulary_mapping.bpe2img_mapping_tensor, + persistent=False, + ) decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm else ChameleonSwinDecoderLayer self.layers = nn.ModuleList( [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] @@ -1274,12 +1458,65 @@ def __init__(self, config: ChameleonConfig): # Initialize weights and apply final processing self.post_init() + @property + def image_seq_length(self) -> int: + return self.vqmodel.quantize.quant_state_dims[0] * self.vqmodel.quantize.quant_state_dims[1] + def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value + def convert_img2bpe_tokens(self, img_batch: torch.LongTensor) -> torch.LongTensor: + """ + Converts image tokens generated by the VQVAE model into BPE tokens compatible with the text tokenizer. + + Notes: + - It is important to move the `img_batch` tensor to the same device as the `img2bpe_mapping_tensor` buffer + as Accelerate may move the buffer to a different device when loading the model with `device_map="auto"`. + - Accelerate up to version 0.33.0 (and also maybe later versions) has a bug where buffers in downstream modules + may be ignored when inferring the proper device map. See: https://github.com/huggingface/accelerate/blob/79ca85c27df292dbf64cfa2bcc12dbb62fbe9267/src/accelerate/utils/modeling.py#L1273 + This causes the `img2bpe_mapping_tensor` buffer to be placed on the CPU by default, which may cause a performance + loss--especially with prompts that contain many images. No action needs to be done when this bug is fixed. + + Args: + img_batch (`torch.Tensor` of shape `(batch_size, image_seq_length)`): + The image tokens generated by the VQVAE model. + + Returns: + `torch.Tensor` of shape `(batch_size, image_seq_length)`: + The image tokens converted to be compatible with the text tokenizer's BPE tokens. + """ + device = img_batch.device + img_tokens = self.img2bpe_mapping_tensor[img_batch.to(self.img2bpe_mapping_tensor.device)] + return img_tokens.to(device) + + def convert_bpe2img_tokens(self, bpe_batch: torch.LongTensor) -> torch.LongTensor: + """ + Converts image tokens that are compatible with the text tokenizer into image tokens compatible with the VQVAE + model. 
+ + Notes: + - It is important to move the `img_batch` tensor to the same device as the `img2bpe_mapping_tensor` buffer + as Accelerate may move the buffer to a different device when loading the model with `device_map="auto"`. + - Accelerate up to version 0.33.0 (and also maybe later versions) has a bug where buffers in downstream modules + may be ignored when inferring the proper device map. See: https://github.com/huggingface/accelerate/blob/79ca85c27df292dbf64cfa2bcc12dbb62fbe9267/src/accelerate/utils/modeling.py#L1273 + This causes the `img2bpe_mapping_tensor` buffer to be placed on the CPU by default, which may cause a performance + loss--especially when generating interleaved text & images. No action needs to be done when this bug is fixed. + + Args: + bpe_batch (`torch.Tensor` of shape `(batch_size, image_seq_length)`): + The image tokens compatible with the text tokenizer. + + Returns: + `torch.Tensor` of shape `(batch_size, image_seq_length)`: + The image tokens converted to be compatible with the VQVAE model. + """ + device = bpe_batch.device + img_tokens = self.bpe2img_mapping_tensor[bpe_batch.to(self.bpe2img_mapping_tensor.device)] + return img_tokens.to(device) + def get_image_tokens(self, pixel_values: torch.FloatTensor): """ Tokenizes images into discrete tokens with VQGAN module. Converts @@ -1289,12 +1526,30 @@ def get_image_tokens(self, pixel_values: torch.FloatTensor): Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images. + + Returns: + `torch.Tensor` of shape `(batch_size, image_seq_length)`: + The BPE tokens generated by the model. """ - batch_size = pixel_values.shape[0] _, _, image_toks = self.vqmodel.encode(pixel_values) - bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks) - bpe_toks = bpe_toks.view(batch_size, -1) - return bpe_toks + return self.convert_img2bpe_tokens(image_toks) + + def decode_image_tokens(self, bpe_tokens: torch.LongTensor) -> torch.LongTensor: + """ + Converts BPE tokens generated by the model into discrete image tokens + compatible with the VQGAN module, then decodes them into pixel values. + + Args: + bpe_tokens (`torch.tensor` of shape `(batch, image_seq_length)`): + The BPE tokens generated by the model. 
+ + Returns: + `torch.Tensor` of shape `(batch, num_channels, 512, 512)`: + """ + if bpe_tokens.shape[1] != self.image_seq_length: + raise ValueError(f"All batches must have {self.image_seq_length} tokens.") + image_tensor = self.convert_bpe2img_tokens(bpe_tokens) + return self.vqmodel.decode(image_tensor) @add_start_docstrings_to_model_forward(CHAMELEON_INPUTS_DOCSTRING) @add_code_sample_docstrings( @@ -1504,9 +1759,10 @@ def _update_causal_mask( class ChameleonForConditionalGeneration(ChameleonPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] - def __init__(self, config): + def __init__(self, config: ChameleonConfig): super().__init__(config) self.model = ChameleonModel(config) + self.vocabulary_mapping = self.model.vocabulary_mapping self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -1531,6 +1787,169 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + def _prepare_generation_config( + self, + generation_config: Optional[GenerationConfig] = None, + multimodal_generation_mode: Optional[ + Literal["text-only", "image-only", "interleaved-text-image", "unrestricted"] + ] = None, + **kwargs, + ): + if ( + multimodal_generation_mode == "image-only" + and kwargs.get("max_length") is None + and kwargs.get("max_new_tokens") is None + and ( + generation_config is None + or (generation_config.max_length is None and generation_config.max_new_tokens is None) + ) + ): + kwargs["max_new_tokens"] = self.model.image_seq_length + 2 + generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs) + if multimodal_generation_mode is not None: + generation_config.multimodal_generation_mode = multimodal_generation_mode + if ( + not hasattr(generation_config, "multimodal_generation_mode") + or generation_config.multimodal_generation_mode is None + ): + generation_config.multimodal_generation_mode = "text-only" + return generation_config, model_kwargs + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + multimodal_generation_mode: Optional[ + Literal["text-only", "image-only", "interleaved-text-image", "unrestricted"] + ] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + generation_config, model_kwargs = self._prepare_generation_config( + generation_config, multimodal_generation_mode, **kwargs + ) + + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") + + # Prepare `max_length` depending on other stopping criteria. 
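+        # The resolved `max_length` is needed before calling `super().generate()` because the
+        # image-token logits processors below use it to decide when image generation must be disallowed.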
+ input_ids_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=inputs_tensor, + input_ids_length=input_ids_length, + ) + + if logits_processor is None: + logits_processor = LogitsProcessorList() + if generation_config.multimodal_generation_mode == "text-only": + logits_processor.append( + SuppressTokensLogitsProcessor( + suppress_tokens=self.vocabulary_mapping.image_token_ids + + [ + self.vocabulary_mapping.boi_token_id, + self.vocabulary_mapping.eoi_token_id, + ], + device=self.device, + ) + ) + elif generation_config.multimodal_generation_mode == "image-only": + inferred_max_new_tokens = generation_config.max_length - input_ids_length + if inferred_max_new_tokens < self.model.image_seq_length + 2: + warnings.warn( + f"The VQVAE decoder expects to receive {self.model.image_seq_length} image tokens to generate an image." + "And Chameleon wraps the image tokens with the `beginning-of-image` and `end-of-image` tokens when on image generation mode." + f"Therefore, the `max_new_tokens` must be at least {self.model.image_seq_length + 2}." + f"However, the inferred `max_new_tokens` from the generation config is only {inferred_max_new_tokens}." + "You would need to pad the output tokens with dummy image tokens before passing them to the VQVAE decoder." + f"To avoid this warning, set `max_new_tokens` to at least {self.model.image_seq_length + 2}." + ) + allowed_tokens = self.vocabulary_mapping.image_token_ids + [ + self.config.eos_token_id, + self.vocabulary_mapping.boi_token_id, + self.vocabulary_mapping.eoi_token_id, + ] + suppress_tokens = [token_id for token_id in range(self.vocab_size) if token_id not in allowed_tokens] + logits_processor.extend( + [ + AllowOnlyTokensAtRelativeOffsetLogitsProcessor( + trigger_token_id=self.vocabulary_mapping.boi_token_id, + allowed_token_ids=[self.vocabulary_mapping.eoi_token_id], + offset=self.model.image_seq_length + 1, + exclusive=True, + device=self.device, + ), + AllowOnlyTokensInRelativeWindowLogitsProcessor( + trigger_token_id=self.vocabulary_mapping.boi_token_id, + allowed_token_ids=self.vocabulary_mapping.image_token_ids, + window_width=self.model.image_seq_length, + exclusive=True, + device=self.device, + ), + # Don't start generating an image if there aren't enough space for the + # rest of the image tokens. 
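+                    # (`start_index = max_length - image_seq_length - 1` is the first step at which a newly
+                    # generated `begin-of-image` token could no longer be followed by all `image_seq_length`
+                    # image tokens plus the `end-of-image` token within `max_length`.)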
+ SuppressTokensInIndexRangeLogitsProcessor( + suppress_tokens=[self.vocabulary_mapping.boi_token_id], + start_index=generation_config.max_length - self.model.image_seq_length - 1, + device=self.device, + ), + # Allow only image tokens + SuppressTokensLogitsProcessor(suppress_tokens=suppress_tokens, device=self.device), + # Force image generation + SuppressTokensAtBeginLogitsProcessor( + begin_suppress_tokens=[self.config.eos_token_id], + begin_index=input_ids_length, + device=self.device, + ), + ] + ) + elif generation_config.multimodal_generation_mode == "interleaved-text-image": + logits_processor.extend( + [ + AllowOnlyTokensAtRelativeOffsetLogitsProcessor( + trigger_token_id=self.vocabulary_mapping.boi_token_id, + allowed_token_ids=[self.vocabulary_mapping.eoi_token_id], + offset=self.model.image_seq_length + 1, + exclusive=True, + device=self.device, + ), + AllowOnlyTokensInRelativeWindowLogitsProcessor( + trigger_token_id=self.vocabulary_mapping.boi_token_id, + allowed_token_ids=self.vocabulary_mapping.image_token_ids, + window_width=self.model.image_seq_length, + exclusive=True, + device=self.device, + ), + # Don't start generating an image if there aren't enough space for the + # rest of the image tokens. + SuppressTokensInIndexRangeLogitsProcessor( + suppress_tokens=[self.vocabulary_mapping.boi_token_id], + start_index=generation_config.max_length - self.model.image_seq_length - 1, + device=self.device, + ), + ] + ) + elif generation_config.multimodal_generation_mode == "unrestricted": + pass + else: + raise ValueError( + f"Unknown multimodal generation mode: {generation_config.multimodal_generation_mode}. Please choose one of 'unrestricted', 'text-only', 'image-only', or 'interleaved-text-image'." + ) + return super().generate( + inputs=inputs, + generation_config=generation_config, + logits_processor=logits_processor, + **kwargs, + ) + @add_start_docstrings_to_model_forward(CHAMELEON_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1583,10 +2002,19 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if pixel_values is not None: + batch_size, sequence_length = input_ids.shape + input_ids = input_ids.view(batch_size * sequence_length) + image_tokens = self.model.get_image_tokens(pixel_values) + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) + input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) + input_ids = input_ids.view(batch_size, sequence_length) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, - pixel_values=pixel_values, + pixel_values=None, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, @@ -1602,12 +2030,10 @@ def forward( logits = self.lm_head(hidden_states) logits = logits.float() - # Disallow image tokens which does not include special begin-image and end-image tokens - image_tokens = self.model.vocabulary_mapping.image_tokens - logits[:, :, image_tokens] = torch.finfo(logits.dtype).min - loss = None if labels is not None: + mask = labels != -100 + labels = torch.where(mask, input_ids, labels) # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1680,3 +2106,17 @@ def prepare_inputs_for_generation( } ) return model_inputs + + def 
+    def decode_image_tokens(self, bpe_tokens: torch.Tensor):
+        """
+        Converts BPE tokens generated by the model into discrete image tokens
+        compatible with the VQVAE module, then decodes them into pixel values.
+
+        Args:
+            bpe_tokens (`torch.Tensor` of shape `(batch, image_seq_length)`):
+                The BPE tokens generated by the model.
+
+        Returns:
+            `torch.Tensor` of shape `(batch, num_channels, 512, 512)`: The decoded pixel values.
+        """
+        return self.model.decode_image_tokens(bpe_tokens)
diff --git a/src/transformers/models/chameleon/processing_chameleon.py b/src/transformers/models/chameleon/processing_chameleon.py
index 1480808336d..f5764078df1 100644
--- a/src/transformers/models/chameleon/processing_chameleon.py
+++ b/src/transformers/models/chameleon/processing_chameleon.py
@@ -16,13 +16,41 @@
 Processor class for Chameleon.
 """

+import sys
 from typing import List, Optional, Union

 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import is_torch_available
+
+
+if sys.version_info >= (3, 11):
+    from typing import Unpack
+else:
+    from typing_extensions import Unpack
+
+
+if is_torch_available():
+    import torch
+
+
+class ChameleonTextKwargs(TextKwargs, total=False):
+    return_for_text_completion: bool
+
+
+class ChameleonProcessorKwargs(ProcessingKwargs, total=False):
+    text_kwargs: ChameleonTextKwargs
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+            "return_for_text_completion": False,
+        },
+        "common_kwargs": {
+            "return_tensors": "pt",
+        },
+    }


 class ChameleonProcessor(ProcessorMixin):
@@ -45,7 +73,7 @@ class ChameleonProcessor(ProcessorMixin):
     """

     attributes = ["image_processor", "tokenizer"]
-    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
+    tokenizer_class = "LlamaTokenizerFast"
     image_processor_class = "ChameleonImageProcessor"

     def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = ""):
@@ -57,13 +85,9 @@ def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, ima

     def __call__(
         self,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        images: ImageInput = None,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: int = None,
-        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
-        return_for_text_completion: bool = False,
+        text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+        images: Optional[ImageInput] = None,
+        **kwargs: Unpack[ChameleonProcessorKwargs],
     ) -> BatchFeature:
         """
         Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
@@ -80,26 +104,6 @@ def __call__(
             images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                 The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                 tensor. Both channels-first and channels-last formats are supported.
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`, *optional*): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -114,6 +118,15 @@ def __call__( text = [text] elif not isinstance(text, list) and not isinstance(text[0], str): raise TypeError("Invalid input text. Please provide a string, or a list of strings") + if text is None and images is None: + raise ValueError("You must provide either text or images") + + output_kwargs = self._merge_kwargs( + ChameleonProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + return_for_text_completion = output_kwargs["text_kwargs"].pop("return_for_text_completion", False) # Replace the image token with the expanded image token sequence prompt_strings = [] @@ -124,19 +137,12 @@ def __call__( sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode prompt_strings.append(sample) - data = self.tokenizer( - prompt_strings, - return_tensors=return_tensors, - padding=padding, - truncation=truncation, - max_length=max_length, - ) + data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) if images is not None: - pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] - data["pixel_values"] = pixel_values + data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"] - return BatchFeature(data=data, tensor_type=return_tensors) + return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"]["return_tensors"]) # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama def batch_decode(self, *args, **kwargs): @@ -160,3 +166,17 @@ def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names image_processor_input_names = self.image_processor.model_input_names return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + def postprocess_pixel_values(self, pixel_values: "torch.FloatTensor") -> "torch.Tensor": + """ + Postprocess a batch of pixel values to images. + + Args: + pixel_values (`np.ndarray` of shape `(batch_size, num_channels, image_size, image_size)` or `(num_channels, image_size, image_size)`): + A batch or a single tensor of pixel values to postprocess. 
+ + Returns: + `torch.Tensor` of shape `(batch_size, num_channels, image_size, image_size)`: + The postprocessed images. + """ + return self.image_processor.postprocess(pixel_values) diff --git a/tests/models/chameleon/test_image_processing_chameleon.py b/tests/models/chameleon/test_image_processing_chameleon.py index 4a5c8c54679..f729e68529c 100644 --- a/tests/models/chameleon/test_image_processing_chameleon.py +++ b/tests/models/chameleon/test_image_processing_chameleon.py @@ -204,3 +204,16 @@ def test_nested_input(self): # Image processor should return same pixel values, independently of input format self.assertTrue((encoded_images_nested == encoded_images).all()) + + def test_postprocessing(self): + image_processing = self.image_processing_class(**self.image_processor_dict) + # Pixel values for an image with 3 channels and 32x32 resolution + pixel_values_single = torch.zeros((3, 32, 32)) + # Pixel values for a batch of 2 images with 3 channels and 32x32 resolution + pixel_values_batch = torch.zeros((2, 3, 32, 32)) + + for pixel_values in [pixel_values_single, pixel_values_batch]: + unnormalized_pixel_values = image_processing.postprocess(pixel_values) + self.assertEqual(unnormalized_pixel_values.shape, pixel_values.shape) + expected_pixel_values = torch.full_like(pixel_values, 128) + self.assertTrue(torch.equal(unnormalized_pixel_values, expected_pixel_values)) diff --git a/tests/models/chameleon/test_modeling_chameleon.py b/tests/models/chameleon/test_modeling_chameleon.py index 16e0a548e6d..f495bc967bd 100644 --- a/tests/models/chameleon/test_modeling_chameleon.py +++ b/tests/models/chameleon/test_modeling_chameleon.py @@ -27,6 +27,7 @@ require_read_token, require_torch, require_torch_gpu, + require_torch_multi_gpu, slow, torch_device, ) @@ -61,6 +62,8 @@ def __init__( use_labels=True, vocab_size=99, image_token_id=98, + boi_token_id=97, + eoi_token_id=96, hidden_size=32, num_hidden_layers=2, num_attention_heads=2, @@ -90,6 +93,8 @@ def __init__( self.use_labels = use_labels self.vocab_size = vocab_size self.image_token_id = image_token_id + self.boi_token_id = boi_token_id + self.eoi_token_id = eoi_token_id self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads @@ -141,7 +146,9 @@ def get_config(self): start = self.vq_img_token_start_id end = self.vq_img_token_start_id + self.vq_num_embeds for i in range(start, end): - vocab_map[i] = f"IMGIMGBS{i}" # dummy str for each token, anything starting with IMGIMG + image_token_infix = "".join(chr(ord("A") + int(c)) for c in str(i)) + # dummy str for each image token, anything starting with IMGIMG + vocab_map[i] = f"IMGIMG{image_token_infix}Z" return ChameleonConfig( vocab_size=self.vocab_size, @@ -160,6 +167,9 @@ def get_config(self): pad_token_id=self.pad_token_id, vocabulary_map={v: k for k, v in vocab_map.items()}, vq_config=self.get_vq_config(), + image_token_id=self.image_token_id, + boi_token_id=self.boi_token_id, + eoi_token_id=self.eoi_token_id, ) def get_vq_config(self): @@ -457,3 +467,29 @@ def test_model_7b_multi_image(self): generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False) text = processor.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @slow + @require_bitsandbytes + @require_read_token + @require_torch_multi_gpu + def test_model_7b_multi_gpu(self): + model = ChameleonForConditionalGeneration.from_pretrained( + "facebook/chameleon-7b", + load_in_4bit=True, + device_map="auto", 
+ max_memory={0: "1GB"}, + ) + processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b") + + image = Image.open( + requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw + ) + prompt = "Describe what do you see here and tell me about the history behind it?" + + inputs = processor(prompt, images=image, return_tensors="pt").to(model.device, torch.float16) + + # greedy generation outputs + EXPECTED_TEXT_COMPLETION = ['Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue line extending across the center of the image. The line is labeled "390 light years" and is accompanied by a small black and'] # fmt: skip + generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False) + text = processor.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) diff --git a/tests/models/chameleon/test_processor_chameleon.py b/tests/models/chameleon/test_processor_chameleon.py new file mode 100644 index 00000000000..74314e3d4c1 --- /dev/null +++ b/tests/models/chameleon/test_processor_chameleon.py @@ -0,0 +1,16 @@ +import tempfile +import unittest + +from transformers import ChameleonProcessor + +from ...test_processing_common import ProcessorTesterMixin + + +class ChameleonProcessorTest(ProcessorTesterMixin, unittest.TestCase): + from_pretrained_id = "leloy/Anole-7b-v0.1-hf" + processor_class = ChameleonProcessor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + processor = self.processor_class.from_pretrained(self.from_pretrained_id) + processor.save_pretrained(self.tmpdirname) diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index a30c6363b9d..691cea12b03 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -49,6 +49,8 @@ @require_torch class ProcessorTesterMixin: processor_class = None + text_data_arg_name = "input_ids" + images_data_arg_name = "pixel_values" def prepare_processor_dict(self): return {} @@ -136,14 +138,14 @@ def test_tokenizer_defaults_preserved_by_kwargs(self): image_input = self.prepare_image_inputs() inputs = processor(text=input_str, images=image_input, return_tensors="pt") - self.assertEqual(len(inputs["input_ids"][0]), 117) + self.assertEqual(len(inputs[self.text_data_arg_name][0]), 117) @require_torch @require_vision def test_image_processor_defaults_preserved_by_image_kwargs(self): if "image_processor" not in self.processor_class.attributes: self.skipTest(f"image_processor attribute not present in {self.processor_class}") - image_processor = self.get_component("image_processor", size=(234, 234)) + image_processor = self.get_component("image_processor", size=(234, 234), crop_size=(234, 234)) tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) @@ -153,7 +155,7 @@ def test_image_processor_defaults_preserved_by_image_kwargs(self): image_input = self.prepare_image_inputs() inputs = processor(text=input_str, images=image_input) - self.assertEqual(len(inputs["pixel_values"][0][0]), 234) + self.assertEqual(len(inputs[self.images_data_arg_name][0][0]), 234) @require_vision @require_torch @@ -161,7 +163,7 @@ def test_kwargs_overrides_default_tokenizer_kwargs(self): if "image_processor" not in self.processor_class.attributes: self.skipTest(f"image_processor attribute not present in {self.processor_class}") 
image_processor = self.get_component("image_processor") - tokenizer = self.get_component("tokenizer", padding="longest") + tokenizer = self.get_component("tokenizer", padding="max_length", max_length=117) processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) self.skip_processor_without_typed_kwargs(processor) @@ -171,7 +173,7 @@ def test_kwargs_overrides_default_tokenizer_kwargs(self): inputs = processor( text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length" ) - self.assertEqual(len(inputs["input_ids"][0]), 112) + self.assertEqual(len(inputs[self.text_data_arg_name][0]), 112) @require_torch @require_vision @@ -187,8 +189,8 @@ def test_kwargs_overrides_default_image_processor_kwargs(self): input_str = "lower newer" image_input = self.prepare_image_inputs() - inputs = processor(text=input_str, images=image_input, size=[224, 224]) - self.assertEqual(len(inputs["pixel_values"][0][0]), 224) + inputs = processor(text=input_str, images=image_input, size=[224, 224], crop_size=(224, 224)) + self.assertEqual(len(inputs[self.images_data_arg_name][0][0]), 224) @require_torch @require_vision @@ -208,12 +210,13 @@ def test_unstructured_kwargs(self): images=image_input, return_tensors="pt", size={"height": 214, "width": 214}, + crop_size={"height": 214, "width": 214}, padding="max_length", max_length=76, ) - self.assertEqual(inputs["pixel_values"].shape[2], 214) - self.assertEqual(len(inputs["input_ids"][0]), 76) + self.assertEqual(inputs[self.images_data_arg_name].shape[2], 214) + self.assertEqual(len(inputs[self.text_data_arg_name][0]), 76) @require_torch @require_vision @@ -233,13 +236,14 @@ def test_unstructured_kwargs_batched(self): images=image_input, return_tensors="pt", size={"height": 214, "width": 214}, + crop_size={"height": 214, "width": 214}, padding="longest", max_length=76, ) - self.assertEqual(inputs["pixel_values"].shape[2], 214) + self.assertEqual(inputs[self.images_data_arg_name].shape[2], 214) - self.assertEqual(len(inputs["input_ids"][0]), 6) + self.assertEqual(len(inputs[self.text_data_arg_name][0]), 6) @require_torch @require_vision @@ -260,6 +264,7 @@ def test_doubly_passed_kwargs(self): images=image_input, images_kwargs={"size": {"height": 222, "width": 222}}, size={"height": 214, "width": 214}, + crop_size={"height": 214, "width": 214}, ) @require_torch @@ -279,16 +284,19 @@ def test_structured_kwargs_nested(self): # Define the kwargs for each modality all_kwargs = { "common_kwargs": {"return_tensors": "pt"}, - "images_kwargs": {"size": {"height": 214, "width": 214}}, + "images_kwargs": { + "size": {"height": 214, "width": 214}, + "crop_size": {"height": 214, "width": 214}, + }, "text_kwargs": {"padding": "max_length", "max_length": 76}, } inputs = processor(text=input_str, images=image_input, **all_kwargs) self.skip_processor_without_typed_kwargs(processor) - self.assertEqual(inputs["pixel_values"].shape[2], 214) + self.assertEqual(inputs[self.images_data_arg_name].shape[2], 214) - self.assertEqual(len(inputs["input_ids"][0]), 76) + self.assertEqual(len(inputs[self.text_data_arg_name][0]), 76) @require_torch @require_vision @@ -307,14 +315,17 @@ def test_structured_kwargs_nested_from_dict(self): # Define the kwargs for each modality all_kwargs = { "common_kwargs": {"return_tensors": "pt"}, - "images_kwargs": {"size": {"height": 214, "width": 214}}, + "images_kwargs": { + "size": {"height": 214, "width": 214}, + "crop_size": {"height": 214, "width": 214}, + }, "text_kwargs": {"padding": "max_length", 
"max_length": 76}, } inputs = processor(text=input_str, images=image_input, **all_kwargs) - self.assertEqual(inputs["pixel_values"].shape[2], 214) + self.assertEqual(inputs[self.images_data_arg_name].shape[2], 214) - self.assertEqual(len(inputs["input_ids"][0]), 76) + self.assertEqual(len(inputs[self.text_data_arg_name][0]), 76) class MyProcessor(ProcessorMixin): diff --git a/utils/check_repo.py b/utils/check_repo.py index acd6662cc2f..b68d1b30768 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -128,7 +128,7 @@ "SeamlessM4TTextToUnitModel", # Building part of bigger (tested) model. "SeamlessM4TCodeHifiGan", # Building part of bigger (tested) model. "SeamlessM4TTextToUnitForConditionalGeneration", # Building part of bigger (tested) model. - "ChameleonVQVAE", # VQVAE here is used only for encoding (discretizing) and is tested as part of bigger model + "ChameleonVQVAE", # Building part of bigger (tested) model. ] # Update this list with test files that don't have a tester with a `all_model_classes` variable and which don't