diff --git a/src/transformers/models/align/processing_align.py b/src/transformers/models/align/processing_align.py
index a5846a87d23696..7cfe14e52b44f9 100644
--- a/src/transformers/models/align/processing_align.py
+++ b/src/transformers/models/align/processing_align.py
@@ -18,16 +18,11 @@
 
 from typing import List, Union
 
-
-try:
-    from typing import Unpack
-except ImportError:
-    from typing_extensions import Unpack
-
 from ...image_utils import ImageInput
 from ...processing_utils import (
     ProcessingKwargs,
     ProcessorMixin,
+    Unpack,
 )
 from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
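# ---------------------------------------------------------------------------
# [Editor's illustration, not part of the diff] The try/except removed above
# existed because `typing.Unpack` only ships with Python 3.11; the shim now
# lives once in `processing_utils` (see its hunk further down). A minimal
# sketch of what `Unpack` buys, using hypothetical names:
from typing import TypedDict

from typing_extensions import Unpack  # stdlib `typing.Unpack` on Python >= 3.11


class _TextKwargs(TypedDict, total=False):
    padding: str
    max_length: int


def encode(text: str, **kwargs: Unpack[_TextKwargs]) -> None:
    # A type checker now validates each key, e.g. it rejects max_length="long".
    ...


encode("x", padding="max_length", max_length=16)
# ---------------------------------------------------------------------------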
diff --git a/src/transformers/models/altclip/processing_altclip.py b/src/transformers/models/altclip/processing_altclip.py
index 5343498842832c..153ecc2e2bfc87 100644
--- a/src/transformers/models/altclip/processing_altclip.py
+++ b/src/transformers/models/altclip/processing_altclip.py
@@ -16,10 +16,16 @@
 Image/Text processor class for AltCLIP
 """
-import warnings
+from typing import List, Union
 
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
+from ...utils.deprecation import deprecate_kwarg
+
+
+class AltClipProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {}
 
 
 class AltCLIPProcessor(ProcessorMixin):
@@ -41,17 +47,8 @@ class AltCLIPProcessor(ProcessorMixin):
     image_processor_class = "CLIPImageProcessor"
     tokenizer_class = ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast")
 
-    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
-        feature_extractor = None
-        if "feature_extractor" in kwargs:
-            warnings.warn(
-                "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
-                " instead.",
-                FutureWarning,
-            )
-            feature_extractor = kwargs.pop("feature_extractor")
-
-        image_processor = image_processor if image_processor is not None else feature_extractor
+    @deprecate_kwarg(old_name="feature_extractor", version="5.0.0", new_name="image_processor")
+    def __init__(self, image_processor=None, tokenizer=None):
         if image_processor is None:
             raise ValueError("You need to specify an `image_processor`.")
         if tokenizer is None:
@@ -59,7 +56,14 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
 
         super().__init__(image_processor, tokenizer)
 
-    def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
+    def __call__(
+        self,
+        images: ImageInput = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[AltClipProcessorKwargs],
+    ) -> BatchEncoding:
         """
         Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
        and `kwargs` arguments to XLMRobertaTokenizerFast's [`~XLMRobertaTokenizerFast.__call__`] if `text` is not
@@ -68,22 +72,20 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
         of the above two methods for more information.
 
         Args:
-            text (`str`, `List[str]`, `List[List[str]]`):
+
+            images (`ImageInput`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`):
                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. Both channels-first and channels-last formats are supported.
-
             return_tensors (`str` or [`~utils.TensorType`], *optional*):
                 If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-
+                    - `'tf'`: Return TensorFlow `tf.constant` objects.
+                    - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                    - `'np'`: Return NumPy `np.ndarray` objects.
+                    - `'jax'`: Return JAX `jnp.ndarray` objects.
 
         Returns:
             [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
@@ -95,13 +97,22 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
         """
 
         if text is None and images is None:
-            raise ValueError("You have to specify either text or images. Both cannot be none.")
+            raise ValueError("You must specify either text or images.")
 
-        if text is not None:
-            encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
+        output_kwargs = self._merge_kwargs(
+            AltClipProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        if text is not None:
+            encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
 
         if images is not None:
-            image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
+            image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
+
+        # BC for explicit return_tensors
+        if "return_tensors" in output_kwargs["common_kwargs"]:
+            return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)
 
         if text is not None and images is not None:
             encoding["pixel_values"] = image_features.pixel_values
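# ---------------------------------------------------------------------------
# [Editor's illustration, not part of the diff] A hedged usage sketch of the
# unified kwargs API the AltCLIP changes above enable. It assumes the
# "BAAI/AltCLIP" checkpoint used by the new tests below; flat kwargs and
# nested per-modality dicts both go through `_merge_kwargs` and should be
# equivalent.
import numpy as np
from PIL import Image

from transformers import AltCLIPProcessor

processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP")
image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))

# Flat (unstructured) kwargs: each key is routed to the right component.
flat = processor(images=image, text="a photo", padding="max_length", max_length=16, return_tensors="pt")

# Nested (structured) kwargs: the same call, grouped per modality.
nested = processor(
    images=image,
    text="a photo",
    text_kwargs={"padding": "max_length", "max_length": 16},
    common_kwargs={"return_tensors": "pt"},
)
assert flat["input_ids"].shape == nested["input_ids"].shape
# ---------------------------------------------------------------------------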
""" + do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size size = get_size_dict(size, default_to_square=False) diff --git a/src/transformers/models/chinese_clip/processing_chinese_clip.py b/src/transformers/models/chinese_clip/processing_chinese_clip.py index 1f44fc50aed576..2cfd314c649866 100644 --- a/src/transformers/models/chinese_clip/processing_chinese_clip.py +++ b/src/transformers/models/chinese_clip/processing_chinese_clip.py @@ -17,9 +17,15 @@ """ import warnings +from typing import List, Union -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput + + +class ChineseClipProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} class ChineseCLIPProcessor(ProcessorMixin): @@ -60,7 +66,14 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs): super().__init__(image_processor, tokenizer) self.current_processor = self.image_processor - def __call__(self, text=None, images=None, return_tensors=None, **kwargs): + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + images: ImageInput = None, + audio=None, + videos=None, + **kwargs: Unpack[ChineseClipProcessorKwargs], + ) -> BatchEncoding: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode @@ -79,12 +92,10 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs): return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. - + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: @@ -97,12 +108,20 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs): if text is None and images is None: raise ValueError("You have to specify either text or images. 
Both cannot be none.") + output_kwargs = self._merge_kwargs( + ChineseClipProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) if text is not None: - encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs) - + encoding = self.tokenizer(text, **output_kwargs["text_kwargs"]) if images is not None: - image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs) + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + + # BC for explicit return_tensors + if "return_tensors" in output_kwargs["common_kwargs"]: + return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None) if text is not None and images is not None: encoding["pixel_values"] = image_features.pixel_values diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 8b33c456924ded..53e83613a07c8f 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -20,11 +20,14 @@ import inspect import json import os +import sys +import typing import warnings from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union import numpy as np +import typing_extensions from .dynamic_module_utils import custom_object_save from .image_utils import ChannelDimension, is_valid_image, is_vision_available @@ -67,6 +70,11 @@ "AutoImageProcessor": "ImageProcessingMixin", } +if sys.version_info >= (3, 11): + Unpack = typing.Unpack +else: + Unpack = typing_extensions.Unpack + class TextKwargs(TypedDict, total=False): """ @@ -151,6 +159,8 @@ class methods and docstrings. Standard deviation to use if normalizing the image. do_pad (`bool`, *optional*): Whether to pad the image to the `(max_height, max_width)` of the images in the batch. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. do_center_crop (`bool`, *optional*): Whether to center crop the image. data_format (`ChannelDimension` or `str`, *optional*): @@ -170,6 +180,7 @@ class methods and docstrings. image_mean: Optional[Union[float, List[float]]] image_std: Optional[Union[float, List[float]]] do_pad: Optional[bool] + pad_size: Optional[Dict[str, int]] do_center_crop: Optional[bool] data_format: Optional[ChannelDimension] input_data_format: Optional[Union[str, ChannelDimension]] @@ -814,7 +825,8 @@ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwarg # check if this key was passed as a flat kwarg. if kwarg_value != "__empty__" and modality_key in non_modality_kwargs: raise ValueError( - f"Keyword argument {modality_key} was passed two times: in a dictionary for {modality} and as a **kwarg." + f"Keyword argument {modality_key} was passed two times:\n" + f"in a dictionary for {modality} and as a **kwarg." ) elif modality_key in kwargs: kwarg_value = kwargs.pop(modality_key, "__empty__") diff --git a/tests/models/altclip/test_processor_altclip.py b/tests/models/altclip/test_processor_altclip.py new file mode 100644 index 00000000000000..1aca2280969404 --- /dev/null +++ b/tests/models/altclip/test_processor_altclip.py @@ -0,0 +1,165 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
diff --git a/tests/models/altclip/test_processor_altclip.py b/tests/models/altclip/test_processor_altclip.py
new file mode 100644
index 00000000000000..1aca2280969404
--- /dev/null
+++ b/tests/models/altclip/test_processor_altclip.py
@@ -0,0 +1,165 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import tempfile
+import unittest
+
+from transformers import XLMRobertaTokenizer, XLMRobertaTokenizerFast
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_vision_available
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+if is_vision_available():
+    from transformers import AltCLIPProcessor, CLIPImageProcessor
+
+
+@require_vision
+class AltClipProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = AltCLIPProcessor
+
+    def setUp(self):
+        self.model_id = "BAAI/AltCLIP"
+        self.tmpdirname = tempfile.mkdtemp()
+        image_processor = CLIPImageProcessor()
+        tokenizer = XLMRobertaTokenizer.from_pretrained(self.model_id)
+
+        processor = self.processor_class(image_processor, tokenizer)
+
+        processor.save_pretrained(self.tmpdirname)
+
+    def get_tokenizer(self, **kwargs):
+        return XLMRobertaTokenizer.from_pretrained(self.model_id, **kwargs)
+
+    def get_rust_tokenizer(self, **kwargs):
+        return XLMRobertaTokenizerFast.from_pretrained(self.model_id, **kwargs)
+
+    def get_image_processor(self, **kwargs):
+        return CLIPImageProcessor.from_pretrained(self.model_id, **kwargs)
+
+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs_batched(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = ["lower newer", "upper older longer string"]
+        image_input = self.prepare_image_inputs() * 2
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="pt",
+            crop_size={"height": 214, "width": 214},
+            padding="longest",
+            max_length=76,
+        )
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        self.assertEqual(len(inputs["input_ids"][0]), 7)
+
+    def test_structured_kwargs_nested(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        # Define the kwargs for each modality
+        all_kwargs = {
+            "common_kwargs": {"return_tensors": "pt"},
+            "images_kwargs": {"crop_size": {"height": 214, "width": 214}},
+            "text_kwargs": {"padding": "max_length", "max_length": 76},
+        }
+
+        inputs = processor(text=input_str, images=image_input, **all_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        self.assertEqual(len(inputs["input_ids"][0]), 76)
+
+    def test_structured_kwargs_nested_from_dict(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        # Define the kwargs for each modality
+        all_kwargs = {
+            "common_kwargs": {"return_tensors": "pt"},
+            "images_kwargs": {"crop_size": {"height": 214, "width": 214}},
+            "text_kwargs": {"padding": "max_length", "max_length": 76},
+        }
+
+        inputs = processor(text=input_str, images=image_input, **all_kwargs)
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        self.assertEqual(len(inputs["input_ids"][0]), 76)
+
+    def test_unstructured_kwargs(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="pt",
+            crop_size={"height": 214, "width": 214},
+            padding="max_length",
+            max_length=76,
+        )
+
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+        self.assertEqual(len(inputs["input_ids"][0]), 76)
+
+    def test_image_processor_defaults_preserved_by_image_kwargs(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor", crop_size=(234, 234))
+        tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        inputs = processor(text=input_str, images=image_input)
+        self.assertEqual(len(inputs["pixel_values"][0][0]), 234)
diff --git a/tests/models/chinese_clip/test_processor_chinese_clip.py b/tests/models/chinese_clip/test_processor_chinese_clip.py
index e433c38f789104..5b191ce2df0894 100644
--- a/tests/models/chinese_clip/test_processor_chinese_clip.py
+++ b/tests/models/chinese_clip/test_processor_chinese_clip.py
@@ -206,3 +206,129 @@ def test_model_input_names(self):
         inputs = processor(text=input_str, images=image_input)
 
         self.assertListEqual(list(inputs.keys()), processor.model_input_names)
+
+    def test_unstructured_kwargs_batched(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = ["lower newer", "upper older longer string"]
+        image_input = self.prepare_image_inputs() * 2
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="pt",
+            crop_size={"height": 214, "width": 214},
+            padding="longest",
+            max_length=76,
+        )
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        self.assertEqual(len(inputs["input_ids"][0]), 6)
+
+    def test_structured_kwargs_nested(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        # Define the kwargs for each modality
+        all_kwargs = {
+            "common_kwargs": {"return_tensors": "pt"},
+            "images_kwargs": {"crop_size": {"height": 214, "width": 214}},
+            "text_kwargs": {"padding": "max_length", "max_length": 76},
+        }
+
+        inputs = processor(text=input_str, images=image_input, **all_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        self.assertEqual(len(inputs["input_ids"][0]), 76)
+
+    def test_structured_kwargs_nested_from_dict(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        # Define the kwargs for each modality
+        all_kwargs = {
+            "common_kwargs": {"return_tensors": "pt"},
+            "images_kwargs": {"crop_size": {"height": 214, "width": 214}},
+            "text_kwargs": {"padding": "max_length", "max_length": 76},
+        }
+
+        inputs = processor(text=input_str, images=image_input, **all_kwargs)
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        self.assertEqual(len(inputs["input_ids"][0]), 76)
+
+    def test_unstructured_kwargs(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="pt",
+            crop_size={"height": 214, "width": 214},
+            padding="max_length",
+            max_length=76,
+        )
+
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+        self.assertEqual(len(inputs["input_ids"][0]), 76)
+
+    def test_image_processor_defaults_preserved_by_image_kwargs(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor", crop_size=(234, 234))
+        tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        inputs = processor(text=input_str, images=image_input)
+        self.assertEqual(len(inputs["pixel_values"][0][0]), 234)
self.assertEqual(len(inputs["pixel_values"][0][0]), 234) + + def test_kwargs_overrides_default_image_processor_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor", crop_size=(234, 234)) + tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input, crop_size=[224, 224]) + self.assertEqual(len(inputs["pixel_values"][0][0]), 224) diff --git a/tests/models/instructblip/test_processor_instructblip.py b/tests/models/instructblip/test_processor_instructblip.py index cc929e3575a6d8..e03e555fed0857 100644 --- a/tests/models/instructblip/test_processor_instructblip.py +++ b/tests/models/instructblip/test_processor_instructblip.py @@ -409,3 +409,31 @@ def test_structured_kwargs_nested_from_dict(self): self.assertEqual(inputs["pixel_values"].shape[2], 214) self.assertEqual(len(inputs["input_ids"][0]), 76) + + def test_overlapping_text_kwargs_handling(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_kwargs = {} + processor_kwargs["image_processor"] = self.get_component("image_processor") + processor_kwargs["tokenizer"] = tokenizer = self.get_component("tokenizer") + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + if "video_processor" in self.processor_class.attributes: + processor_kwargs["video_processor"] = self.get_component("video_processor") + + qformer_tokenizer = self.get_component("qformer_tokenizer") + + processor = self.processor_class(**processor_kwargs, qformer_tokenizer=qformer_tokenizer) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + with self.assertRaises(ValueError): + _ = processor( + text=input_str, + images=image_input, + return_tensors="pt", + padding="max_length", + text_kwargs={"padding": "do_not_pad"}, + ) diff --git a/tests/models/instructblipvideo/test_processor_instructblipvideo.py b/tests/models/instructblipvideo/test_processor_instructblipvideo.py index 945fe004d12e02..8b29c771759217 100644 --- a/tests/models/instructblipvideo/test_processor_instructblipvideo.py +++ b/tests/models/instructblipvideo/test_processor_instructblipvideo.py @@ -423,3 +423,31 @@ def test_structured_kwargs_nested_from_dict(self): self.assertEqual(inputs["pixel_values"].shape[2], 214) self.assertEqual(len(inputs["input_ids"][0]), 76) + + def test_overlapping_text_kwargs_handling(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_kwargs = {} + processor_kwargs["image_processor"] = self.get_component("image_processor") + processor_kwargs["tokenizer"] = tokenizer = self.get_component("tokenizer") + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + if "video_processor" in self.processor_class.attributes: + processor_kwargs["video_processor"] = self.get_component("video_processor") + + qformer_tokenizer = self.get_component("qformer_tokenizer") + + processor = self.processor_class(**processor_kwargs, 
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        with self.assertRaises(ValueError):
+            _ = processor(
+                text=input_str,
+                images=image_input,
+                return_tensors="pt",
+                padding="max_length",
+                text_kwargs={"padding": "do_not_pad"},
+            )
diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py
index 306f3100fb8d74..fe17de4eeb5cdf 100644
--- a/tests/test_processing_common.py
+++ b/tests/test_processing_common.py
@@ -146,7 +146,6 @@ def test_tokenizer_defaults_preserved_by_kwargs(self):
         self.skip_processor_without_typed_kwargs(processor)
         input_str = "lower newer"
         image_input = self.prepare_image_inputs()
-
         inputs = processor(text=input_str, images=image_input, return_tensors="pt")
 
         self.assertEqual(len(inputs["input_ids"][0]), 117)
@@ -175,7 +174,6 @@ def test_kwargs_overrides_default_tokenizer_kwargs(self):
         self.skip_processor_without_typed_kwargs(processor)
         input_str = "lower newer"
         image_input = self.prepare_image_inputs()
-
         inputs = processor(
             text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length"
         )
@@ -238,7 +236,6 @@ def test_unstructured_kwargs_batched(self):
             padding="longest",
             max_length=76,
         )
-
         self.assertEqual(inputs["pixel_values"].shape[2], 214)
 
         self.assertEqual(len(inputs["input_ids"][0]), 6)
@@ -311,3 +308,30 @@ def test_structured_kwargs_nested_from_dict(self):
         self.assertEqual(inputs["pixel_values"].shape[2], 214)
 
         self.assertEqual(len(inputs["input_ids"][0]), 76)
+
+    # TODO: the same test, but for audio + text processors that have strong overlap in kwargs
+    # TODO (molbap) use the same structure of attribute kwargs for other tests to avoid duplication
+    def test_overlapping_text_kwargs_handling(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        processor_kwargs = {}
+        processor_kwargs["image_processor"] = self.get_component("image_processor")
+        processor_kwargs["tokenizer"] = tokenizer = self.get_component("tokenizer")
+        if not tokenizer.pad_token:
+            tokenizer.pad_token = "[TEST_PAD]"
+        if "video_processor" in self.processor_class.attributes:
+            processor_kwargs["video_processor"] = self.get_component("video_processor")
+        processor = self.processor_class(**processor_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        with self.assertRaises(ValueError):
+            _ = processor(
+                text=input_str,
+                images=image_input,
+                return_tensors="pt",
+                padding="max_length",
+                text_kwargs={"padding": "do_not_pad"},
+            )
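# ---------------------------------------------------------------------------
# [Editor's illustration, not part of the diff] The precedence the tests above
# pin down, summarized as a hedged sketch (checkpoint and values mirror the
# new AltCLIP tests): component init kwargs act as defaults, and call-time
# kwargs, flat or nested, override them.
import numpy as np
from PIL import Image

from transformers import AltCLIPProcessor, CLIPImageProcessor, XLMRobertaTokenizer

tokenizer = XLMRobertaTokenizer.from_pretrained("BAAI/AltCLIP", max_length=117, padding="max_length")
image_processor = CLIPImageProcessor.from_pretrained("BAAI/AltCLIP", crop_size=(234, 234))
processor = AltCLIPProcessor(image_processor=image_processor, tokenizer=tokenizer)

image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))

# Init-time kwargs act as defaults: 234x234 crops, 117-token padded text.
defaults = processor(text="lower newer", images=image, return_tensors="pt")
assert len(defaults["input_ids"][0]) == 117

# Call-time kwargs take priority over those defaults.
overridden = processor(
    text="lower newer", images=image, return_tensors="pt", crop_size={"height": 224, "width": 224}
)
assert overridden["pixel_values"].shape[-1] == 224
# ---------------------------------------------------------------------------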