add uniform processors for altclip + chinese_clip #31198

Merged: 84 commits from `uniform_processors_2` into `main`, Sep 19, 2024
Commits
84 commits
b85036f
add initial design for uniform processors + align model
molbap Jun 3, 2024
1336931
add uniform processors for altclip + chinese_clip
molbap Jun 3, 2024
bb8ac70
fix mutable default :eyes:
molbap Jun 3, 2024
cd8c601
add configuration test
molbap Jun 3, 2024
f00c852
handle structured kwargs w defaults + add test
molbap Jun 3, 2024
693036f
protect torch-specific test
molbap Jun 3, 2024
766da3a
fix style
molbap Jun 3, 2024
844394d
fix
molbap Jun 3, 2024
7d860a0
rebase
molbap Jun 3, 2024
7cb9925
update processor to generic kwargs + test
molbap Jun 3, 2024
ad4cbf7
fix style
molbap Jun 3, 2024
def56cd
add sensible kwargs merge
molbap Jun 3, 2024
2e6b7e1
update test
molbap Jun 3, 2024
c19bbc6
fix assertEqual
molbap Jun 4, 2024
3c38119
move kwargs merging to processing common
molbap Jun 4, 2024
81ae819
rework kwargs for type hinting
molbap Jun 5, 2024
ce4abcd
just get Unpack from extensions
molbap Jun 7, 2024
3acdf28
run-slow[align]
molbap Jun 7, 2024
404239f
handle kwargs passed as nested dict
molbap Jun 7, 2024
603be40
add from_pretrained test for nested kwargs handling
molbap Jun 7, 2024
71c9d6c
[run-slow]align
molbap Jun 7, 2024
26383c5
update documentation + imports
molbap Jun 7, 2024
4521f4f
update audio inputs
molbap Jun 7, 2024
b96eb64
protect audio types, silly
molbap Jun 7, 2024
9c5c01c
try removing imports
molbap Jun 7, 2024
3ccb505
make things simpler
molbap Jun 7, 2024
142acf3
simplerer
molbap Jun 7, 2024
60a5730
move out kwargs test to common mixin
molbap Jun 10, 2024
be6c141
[run-slow]align
molbap Jun 10, 2024
84135d7
skip tests for old processors
molbap Jun 10, 2024
ce967ac
[run-slow]align, clip
molbap Jun 10, 2024
f78ec52
!$#@!! protect imports, darn it
molbap Jun 10, 2024
52fd5ad
[run-slow]align, clip
molbap Jun 10, 2024
8f21abe
Merge branch 'main' into uniform_processors_1
molbap Jun 10, 2024
d510030
[run-slow]align, clip
molbap Jun 10, 2024
b2f0336
fix conflicts
molbap Jun 10, 2024
40c8a0b
update common processor testing
molbap Jun 10, 2024
2e19860
add altclip
molbap Jun 10, 2024
06b7ae2
add chinese_clip
molbap Jun 10, 2024
2e58518
add pad_size
molbap Jun 10, 2024
aa7a68c
[run-slow]align, clip, chinese_clip, altclip
molbap Jun 10, 2024
f0ca955
remove duplicated tests
molbap Jun 10, 2024
7f61246
fix
molbap Jun 10, 2024
fd43bcd
update doc
molbap Jun 11, 2024
b2cd7c9
improve documentation for default values
molbap Jun 11, 2024
bcbd646
add model_max_length testing
molbap Jun 11, 2024
39c1587
Raise if kwargs are specified in two places
molbap Jun 11, 2024
1f73bdf
fix
molbap Jun 11, 2024
934e612
Merge branch 'uniform_processors_1' into uniform_processors_2
molbap Jun 11, 2024
ee57813
Merge branch 'main' into uniform_processors_2
molbap Jun 11, 2024
bab441f
Merge branch 'main' into uniform_processors_2
molbap Jun 14, 2024
4fd60cf
match defaults
molbap Jun 17, 2024
34d0b61
force padding
molbap Jun 17, 2024
10d727b
fix tokenizer test
molbap Jun 17, 2024
986ed9f
clean defaults
molbap Jun 17, 2024
3c265d1
move tests to common
molbap Jun 17, 2024
24bef68
remove try/catch block
molbap Jun 18, 2024
82e72cf
deprecate kwarg
molbap Jun 18, 2024
5485a5c
format
molbap Jun 18, 2024
8ad6587
add copyright + remove unused method
molbap Jun 18, 2024
76fff78
Merge branch 'main' into uniform_processors_2
molbap Jun 20, 2024
e447062
[run-slow]altclip, chinese_clip
molbap Jun 20, 2024
0ee79aa
clean imports
molbap Jun 25, 2024
4cfd018
fix version
molbap Jun 25, 2024
520240b
clean up deprecation
molbap Jun 26, 2024
823ce00
fix style
molbap Jun 26, 2024
134b589
add corner case test on kwarg overlap
molbap Jul 15, 2024
9978621
resume processing - add Unpack as importable
molbap Aug 9, 2024
182a9ec
Merge branch 'main' into uniform_processors_2
molbap Aug 9, 2024
514aae9
Merge branch 'main' into uniform_processors_2
molbap Aug 14, 2024
f5e2326
add tmpdirname
molbap Aug 14, 2024
357e8ff
fix altclip
molbap Aug 14, 2024
402445b
fix up
molbap Aug 14, 2024
8212d8e
Merge branch 'main' into uniform_processors_2
molbap Sep 18, 2024
eb6a933
add back crop_size to specific tests
molbap Sep 18, 2024
468541c
Merge branch 'main' into uniform_processors_2
molbap Sep 19, 2024
1e950e3
generalize tests to possible video_processor
molbap Sep 19, 2024
ae2e605
add back crop_size arg
molbap Sep 19, 2024
4b3c4f3
fixup overlapping kwargs test for qformer_tokenizer
molbap Sep 19, 2024
5f50aeb
remove copied from
molbap Sep 19, 2024
de5980a
fixup chinese_clip tests values
molbap Sep 19, 2024
61e2664
fixup tests - qformer tokenizers
molbap Sep 19, 2024
58211f5
[run-slow] altclip, chinese_clip
molbap Sep 19, 2024
416e10a
remove prepare_image_inputs
molbap Sep 19, 2024
5 changes: 3 additions & 2 deletions src/transformers/models/align/processing_align.py
Expand Up @@ -16,12 +16,13 @@
Image/Text processor class for ALIGN
"""

+import sys
 from typing import List, Union


-try:
+if sys.version_info >= (3, 11):
     from typing import Unpack
-except ImportError:
+else:
     from typing_extensions import Unpack

from ...image_utils import ImageInput
Expand Down
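The version-gated import in the hunk above is the pattern this PR rolls out across processors. A minimal, self-contained sketch of how `Unpack` gives a processor-style `**kwargs` a typed shape — the class and function names below are illustrative, not the actual transformers ones:

```python
import sys
from typing import TypedDict

if sys.version_info >= (3, 11):
    from typing import Unpack  # added to the stdlib in Python 3.11
else:
    from typing_extensions import Unpack  # backport for older interpreters


class TextKwargs(TypedDict, total=False):
    # total=False: every key is optional, mirroring ProcessingKwargs.
    padding: str
    max_length: int


def encode(text: str, **kwargs: Unpack[TextKwargs]) -> dict:
    # Static checkers now flag typos like `paddng=` at call sites,
    # while runtime behavior is unchanged.
    return {"text": text, **kwargs}
```

A `try`/`except ImportError` also works at runtime, but the explicit `sys.version_info` branch is friendlier to static analyzers, which understand version checks.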
80 changes: 50 additions & 30 deletions src/transformers/models/altclip/processing_altclip.py
Expand Up @@ -16,10 +16,23 @@
Image/Text processor class for AltCLIP
"""

-import warnings
+import sys
 from typing import List, Union

-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding
+if sys.version_info >= (3, 11):
+    from typing import Unpack
+else:
+    from typing_extensions import Unpack
+
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin
+from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
+from ...utils.deprecation import deprecate_kwarg
+
+
+class AltClipProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {}

class AltCLIPProcessor(ProcessorMixin):
Expand All @@ -41,25 +54,23 @@ class AltCLIPProcessor(ProcessorMixin):
image_processor_class = "CLIPImageProcessor"
tokenizer_class = ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast")

-    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
-        feature_extractor = None
-        if "feature_extractor" in kwargs:
-            warnings.warn(
-                "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
-                " instead.",
-                FutureWarning,
-            )
-            feature_extractor = kwargs.pop("feature_extractor")
-
-        image_processor = image_processor if image_processor is not None else feature_extractor
+    @deprecate_kwarg(old_name="feature_extractor", version="5.0.0", new_name="image_processor")
+    def __init__(self, image_processor=None, tokenizer=None):
         if image_processor is None:
             raise ValueError("You need to specify an `image_processor`.")
         if tokenizer is None:
             raise ValueError("You need to specify a `tokenizer`.")

         super().__init__(image_processor, tokenizer)

-    def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
+    def __call__(
+        self,
+        images: ImageInput = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[AltClipProcessorKwargs],
+    ) -> BatchEncoding:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to XLMRobertaTokenizerFast's [`~XLMRobertaTokenizerFast.__call__`] if `text` is not
Expand All @@ -68,22 +79,20 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
of the above two methods for more information.

         Args:
-            text (`str`, `List[str]`, `List[List[str]]`):
+            images (`ImageInput`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`):
                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. Both channels-first and channels-last formats are supported.
             return_tensors (`str` or [`~utils.TensorType`], *optional*):
                 If set, will return tensors of a particular framework. Acceptable values are:
-
                 - `'tf'`: Return TensorFlow `tf.constant` objects.
                 - `'pt'`: Return PyTorch `torch.Tensor` objects.
                 - `'np'`: Return NumPy `np.ndarray` objects.
                 - `'jax'`: Return JAX `jnp.ndarray` objects.
-
Returns:
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:

Expand All @@ -95,13 +104,24 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
"""

         if text is None and images is None:
-            raise ValueError("You have to specify either text or images. Both cannot be none.")
+            raise ValueError("You must specify either text or images.")
+
+        output_kwargs = self._merge_kwargs(
+            AltClipProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )

         if text is not None:
-            encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
+            encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
         if images is not None:
-            image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
+            image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
+
+        # BC for explicit return_tensors
+        if "return_tensors" in output_kwargs["common_kwargs"]:
+            return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)

         if text is not None and images is not None:
             encoding["pixel_values"] = image_features.pixel_values
Expand Down
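The `@deprecate_kwarg(old_name="feature_extractor", ...)` decorator replaces the hand-written warning block deleted above. A rough sketch of what such a decorator can do — the real implementation lives in `src/transformers/utils/deprecation.py` and handles more cases, so treat the behavior below as illustrative:

```python
import functools
import warnings


def deprecate_kwarg(old_name, version, new_name=None):
    """Sketch: remap a deprecated keyword argument to its new name,
    emitting a FutureWarning when the old name is used."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                message = f"`{old_name}` is deprecated and will be removed in version {version}."
                if new_name is not None:
                    message += f" Use `{new_name}` instead."
                warnings.warn(message, FutureWarning)
                value = kwargs.pop(old_name)
                # Only remap if the caller did not also pass the new name.
                if new_name is not None and new_name not in kwargs:
                    kwargs[new_name] = value
            return func(*args, **kwargs)
        return wrapper
    return decorator


# Hypothetical stand-in for AltCLIPProcessor.__init__:
@deprecate_kwarg(old_name="feature_extractor", version="5.0.0", new_name="image_processor")
def init_processor(image_processor=None, tokenizer=None):
    return image_processor, tokenizer
```

The decorator keeps the deprecation policy in one place instead of a per-processor `warnings.warn` block, which is the point of this part of the diff.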
Expand Up @@ -249,6 +249,7 @@ def preprocess(
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""

do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
Expand Down
51 changes: 39 additions & 12 deletions src/transformers/models/chinese_clip/processing_chinese_clip.py
Expand Up @@ -16,10 +16,24 @@
Image/Text processor class for Chinese-CLIP
"""

+import sys
 import warnings
 from typing import List, Union

-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin
+
+
+if sys.version_info >= (3, 11):
+    from typing import Unpack
+else:
+    from typing_extensions import Unpack
+
+from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
+
+
+class ChineseClipProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {}


class ChineseCLIPProcessor(ProcessorMixin):
Expand Down Expand Up @@ -60,7 +74,14 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
super().__init__(image_processor, tokenizer)
self.current_processor = self.image_processor

-    def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
+    def __call__(
+        self,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        images: ImageInput = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[ChineseClipProcessorKwargs],
+    ) -> BatchEncoding:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode
Expand All @@ -79,12 +100,10 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):

             return_tensors (`str` or [`~utils.TensorType`], *optional*):
                 If set, will return tensors of a particular framework. Acceptable values are:
-
                 - `'tf'`: Return TensorFlow `tf.constant` objects.
                 - `'pt'`: Return PyTorch `torch.Tensor` objects.
                 - `'np'`: Return NumPy `np.ndarray` objects.
                 - `'jax'`: Return JAX `jnp.ndarray` objects.
-
Returns:
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:

Expand All @@ -97,12 +116,20 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):

         if text is None and images is None:
             raise ValueError("You have to specify either text or images. Both cannot be none.")
+
+        output_kwargs = self._merge_kwargs(
+            ChineseClipProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )

         if text is not None:
-            encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
+            encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
         if images is not None:
-            image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
+            image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
+
+        # BC for explicit return_tensors
+        if "return_tensors" in output_kwargs["common_kwargs"]:
+            return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)

         if text is not None and images is not None:
             encoding["pixel_values"] = image_features.pixel_values
Expand Down
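Both processors now route their kwargs through `self._merge_kwargs(...)`. A simplified, hypothetical sketch of the precedence that call implements — processor defaults, then tokenizer init kwargs, then call-time kwargs (flat or nested per modality). The real method on `ProcessorMixin` infers modality keys from the `ProcessingKwargs` type hints rather than hard-coding them as done here:

```python
def merge_kwargs(defaults, tokenizer_init_kwargs, **kwargs):
    """Simplified sketch of per-modality kwarg merging.

    Precedence, lowest to highest: processor defaults, tokenizer init
    kwargs, call-time kwargs (flat, or nested as e.g. images_kwargs={...}).
    """
    # Hard-coded for illustration; the real code reads these from type hints.
    modality_keys = {
        "text_kwargs": {"padding", "truncation", "max_length"},
        "images_kwargs": {"crop_size", "size", "do_resize"},
        "common_kwargs": {"return_tensors"},
    }
    output = {modality: {} for modality in modality_keys}

    # Defaults first, then tokenizer init kwargs on top.
    for source in (defaults, tokenizer_init_kwargs):
        for key, value in source.items():
            for modality, keys in modality_keys.items():
                if key in keys:
                    output[modality][key] = value

    # Call-time kwargs win, whether nested per modality or flat.
    for modality, keys in modality_keys.items():
        nested = kwargs.get(modality, {})
        for key in keys:
            if key in nested:
                output[modality][key] = nested[key]
            elif key in kwargs:
                output[modality][key] = kwargs[key]
    return output


# Call-time kwargs override tokenizer init kwargs, which override defaults:
merged = merge_kwargs(
    {"padding": "longest"},  # processor defaults
    {"max_length": 77},      # tokenizer.init_kwargs
    padding="max_length",    # flat call-time kwarg
    images_kwargs={"crop_size": {"height": 214, "width": 214}},
)
```

This layering is why the new tests can pass `crop_size` and `padding` as flat kwargs and still see them land in `images_kwargs` and `text_kwargs` respectively.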
6 changes: 5 additions & 1 deletion src/transformers/processing_utils.py
Expand Up @@ -150,6 +150,8 @@ class methods and docstrings.
Standard deviation to use if normalizing the image.
         do_pad (`bool`, *optional*):
             Whether to pad the image to the `(max_height, max_width)` of the images in the batch.
+        pad_size (`Dict[str, int]`, *optional*):
+            The size `{"height": int, "width": int}` to pad the images to.
         do_center_crop (`bool`, *optional*):
             Whether to center crop the image.
         data_format (`ChannelDimension` or `str`, *optional*):
Expand All @@ -169,6 +171,7 @@
     image_mean: Optional[Union[float, List[float]]]
     image_std: Optional[Union[float, List[float]]]
     do_pad: Optional[bool]
+    pad_size: Optional[Dict[str, int]]
     do_center_crop: Optional[bool]
     data_format: Optional[ChannelDimension]
     input_data_format: Optional[Union[str, ChannelDimension]]
Expand Down Expand Up @@ -753,7 +756,8 @@ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwarg
                 # check if this key was passed as a flat kwarg.
                 if kwarg_value != "__empty__" and modality_key in non_modality_kwargs:
                     raise ValueError(
-                        f"Keyword argument {modality_key} was passed two times: in a dictionary for {modality} and as a **kwarg."
+                        f"Keyword argument {modality_key} was passed two times:\n"
+                        f"in a dictionary for {modality} and as a **kwarg."
                     )
                 elif modality_key in kwargs:
                     kwarg_value = kwargs.pop(modality_key, "__empty__")
Expand Down
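The `raise` above guards against ambiguity: the same key arriving both inside a per-modality dict and as a flat `**kwarg`. Stripped of the surrounding merging machinery, the check amounts to the following sketch (function and argument names are hypothetical):

```python
def check_no_duplicate(modality, modality_dict_kwargs, flat_kwargs):
    """Sketch of the duplicate-kwarg guard: a key given both inside a
    per-modality dict and as a flat **kwarg is ambiguous, so the
    processor raises instead of silently picking one value."""
    for modality_key, kwarg_value in modality_dict_kwargs.items():
        if kwarg_value != "__empty__" and modality_key in flat_kwargs:
            raise ValueError(
                f"Keyword argument {modality_key} was passed two times:\n"
                f"in a dictionary for {modality} and as a **kwarg."
            )
```

Raising here (rather than letting one spelling win) surfaces caller bugs early, which is what the "Raise if kwargs are specified in two places" commit in this PR is about.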
82 changes: 82 additions & 0 deletions tests/models/altclip/test_processor_altclip.py
@@ -0,0 +1,82 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

import numpy as np

from transformers import XLMRobertaTokenizer, XLMRobertaTokenizerFast
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
from PIL import Image

from transformers import AltCLIPProcessor, CLIPImageProcessor


@require_vision
class AltClipProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    processor_class = AltCLIPProcessor

    def setUp(self):
        self.model_id = "BAAI/AltCLIP"

    def get_tokenizer(self, **kwargs):
        return XLMRobertaTokenizer.from_pretrained(self.model_id, **kwargs)

    def get_rust_tokenizer(self, **kwargs):
        return XLMRobertaTokenizerFast.from_pretrained(self.model_id, **kwargs)

    def get_image_processor(self, **kwargs):
        return CLIPImageProcessor.from_pretrained(self.model_id, **kwargs)

    def prepare_image_inputs(self):
        """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
        or a list of PyTorch tensors if one specifies torchify=True.
        """
        image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
        image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
        return image_inputs

    @require_torch
    @require_vision
    def test_unstructured_kwargs_batched(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        image_processor = self.get_component("image_processor")
        tokenizer = self.get_component("tokenizer")

        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
        self.skip_processor_without_typed_kwargs(processor)

        input_str = ["lower newer", "upper older longer string"]
        image_input = self.prepare_image_inputs() * 2
        inputs = processor(
            text=input_str,
            images=image_input,
            return_tensors="pt",
            crop_size={"height": 214, "width": 214},
            padding="longest",
            max_length=76,
        )
        self.assertEqual(inputs["pixel_values"].shape[2], 214)
        self.assertEqual(len(inputs["input_ids"][0]), 7)
6 changes: 5 additions & 1 deletion tests/models/chinese_clip/test_processor_chinese_clip.py
Expand Up @@ -26,6 +26,8 @@
from transformers.testing_utils import require_vision
from transformers.utils import FEATURE_EXTRACTOR_NAME, is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
from PIL import Image
Expand All @@ -34,7 +36,9 @@


@require_vision
-class ChineseCLIPProcessorTest(unittest.TestCase):
+class ChineseCLIPProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = ChineseCLIPProcessor

def setUp(self):
self.tmpdirname = tempfile.mkdtemp()

Expand Down