add uniform processors for altclip + chinese_clip (huggingface#31198)
* add initial design for uniform processors + align model

* add uniform processors for altclip + chinese_clip

* fix mutable default 👀

* add configuration test

* handle structured kwargs w defaults + add test

* protect torch-specific test

* fix style

* fix

* rebase

* update processor to generic kwargs + test

* fix style

* add sensible kwargs merge

* update test

* fix assertEqual

* move kwargs merging to processing common

* rework kwargs for type hinting

* just get Unpack from extensions

* run-slow[align]

* handle kwargs passed as nested dict

* add from_pretrained test for nested kwargs handling

* [run-slow]align

* update documentation + imports

* update audio inputs

* protect audio types, silly

* try removing imports

* make things simpler

* simplerer

* move out kwargs test to common mixin

* [run-slow]align

* skip tests for old processors

* [run-slow]align, clip

* !$#@!! protect imports, darn it

* [run-slow]align, clip

* [run-slow]align, clip

* update common processor testing

* add altclip

* add chinese_clip

* add pad_size

* [run-slow]align, clip, chinese_clip, altclip

* remove duplicated tests

* fix

* update doc

* improve documentation for default values

* add model_max_length testing

This parameter depends on the tokenizers received.

* Raise if kwargs are specified in two places

* fix

* match defaults

* force padding

* fix tokenizer test

* clean defaults

* move tests to common

* remove try/catch block

* deprecate kwarg

* format

* add copyright + remove unused method

* [run-slow]altclip, chinese_clip

* clean imports

* fix version

* clean up deprecation

* fix style

* add corner case test on kwarg overlap

* resume processing - add Unpack as importable

* add tmpdirname

* fix altclip

* fix up

* add back crop_size to specific tests

* generalize tests to possible video_processor

* add back crop_size arg

* fixup overlapping kwargs test for qformer_tokenizer

* remove copied from

* fixup chinese_clip tests values

* fixup tests - qformer tokenizers

* [run-slow] altclip, chinese_clip

* remove prepare_image_inputs
molbap authored and BernardZach committed Dec 6, 2024
1 parent 0f0ba12 commit 22118b9
Showing 10 changed files with 463 additions and 52 deletions.
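
The commit trail above converges on a single calling convention: every option a processor accepts can be passed either flat or grouped per modality (`text_kwargs`, `images_kwargs`, `common_kwargs`), with tokenizer defaults merged in by `_merge_kwargs`. A minimal sketch of both spellings, assuming the public `BAAI/AltCLIP` checkpoint:

```python
# Sketch of the uniform kwargs convention; assumes the BAAI/AltCLIP
# checkpoint. The two calls are intended to be equivalent.
import numpy as np
from PIL import Image
from transformers import AltCLIPProcessor

processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP")

image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))  # dummy image
texts = ["a photo of a cat", "a photo of a dog"]

# Flat kwargs: each keyword is routed to the modality that declares it.
batch = processor(images=image, text=texts, padding=True, return_tensors="pt")

# Structured kwargs: the same options grouped per modality.
batch = processor(
    images=image,
    text=texts,
    text_kwargs={"padding": True},
    common_kwargs={"return_tensors": "pt"},
)
print(batch.input_ids.shape, batch.pixel_values.shape)
```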
7 changes: 1 addition & 6 deletions src/transformers/models/align/processing_align.py
@@ -18,16 +18,11 @@
 
 from typing import List, Union
 
-
-try:
-    from typing import Unpack
-except ImportError:
-    from typing_extensions import Unpack
-
 from ...image_utils import ImageInput
 from ...processing_utils import (
     ProcessingKwargs,
     ProcessorMixin,
+    Unpack,
 )
 from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
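
The hunk above drops the per-file `try`/`except` and imports `Unpack` from `processing_utils` instead. For reference, a standalone sketch of the `Unpack`-typed `**kwargs` pattern this enables; the class and function names are illustrative, not from the commit:

```python
# Standalone sketch of the Unpack-typed **kwargs pattern; MyTextKwargs
# and encode() are illustrative names, not part of this commit.
import sys
from typing import List, Optional, TypedDict

if sys.version_info >= (3, 11):
    from typing import Unpack
else:
    from typing_extensions import Unpack


class MyTextKwargs(TypedDict, total=False):
    padding: Optional[bool]
    max_length: Optional[int]


def encode(text: List[str], **kwargs: Unpack[MyTextKwargs]) -> dict:
    # Static type checkers can now flag unknown keys and wrong value types.
    return {"text": text, **kwargs}


print(encode(["hello"], padding=True, max_length=16))
```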
73 changes: 43 additions & 30 deletions src/transformers/models/altclip/processing_altclip.py
@@ -16,10 +16,16 @@
 Image/Text processor class for AltCLIP
 """
 
-import warnings
+from typing import List, Union
 
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
+from ...utils.deprecation import deprecate_kwarg
+
+
+class AltClipProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {}
 
 
 class AltCLIPProcessor(ProcessorMixin):
@@ -41,25 +47,23 @@ class AltCLIPProcessor(ProcessorMixin):
     image_processor_class = "CLIPImageProcessor"
     tokenizer_class = ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast")
 
-    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
-        feature_extractor = None
-        if "feature_extractor" in kwargs:
-            warnings.warn(
-                "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
-                " instead.",
-                FutureWarning,
-            )
-            feature_extractor = kwargs.pop("feature_extractor")
-
-        image_processor = image_processor if image_processor is not None else feature_extractor
+    @deprecate_kwarg(old_name="feature_extractor", version="5.0.0", new_name="image_processor")
+    def __init__(self, image_processor=None, tokenizer=None):
         if image_processor is None:
             raise ValueError("You need to specify an `image_processor`.")
         if tokenizer is None:
             raise ValueError("You need to specify a `tokenizer`.")
 
         super().__init__(image_processor, tokenizer)
 
-    def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
+    def __call__(
+        self,
+        images: ImageInput = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[AltClipProcessorKwargs],
+    ) -> BatchEncoding:
         """
         Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
         and `kwargs` arguments to XLMRobertaTokenizerFast's [`~XLMRobertaTokenizerFast.__call__`] if `text` is not
@@ -68,22 +72,20 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
         of the above two methods for more information.
 
         Args:
-            text (`str`, `List[str]`, `List[List[str]]`):
+            images (`ImageInput`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`):
                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. Both channels-first and channels-last formats are supported.
             return_tensors (`str` or [`~utils.TensorType`], *optional*):
                 If set, will return tensors of a particular framework. Acceptable values are:
 
-                    - `'tf'`: Return TensorFlow `tf.constant` objects.
-                    - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                    - `'np'`: Return NumPy `np.ndarray` objects.
-                    - `'jax'`: Return JAX `jnp.ndarray` objects.
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return NumPy `np.ndarray` objects.
+                - `'jax'`: Return JAX `jnp.ndarray` objects.
 
         Returns:
             [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
@@ -95,13 +97,24 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
         """
 
         if text is None and images is None:
-            raise ValueError("You have to specify either text or images. Both cannot be none.")
+            raise ValueError("You must specify either text or images.")
+        output_kwargs = self._merge_kwargs(
+            AltClipProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
 
         if text is not None:
-            encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
+            encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
         if images is not None:
-            image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
+            image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
+
+        # BC for explicit return_tensors
+        if "return_tensors" in output_kwargs["common_kwargs"]:
+            return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)
 
         if text is not None and images is not None:
             encoding["pixel_values"] = image_features.pixel_values
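
The `@deprecate_kwarg` decorator replaces the hand-rolled `warnings.warn` block: the old `feature_extractor` argument keeps working until v5.0.0, but warns and is remapped to `image_processor`. A hedged sketch of both spellings, assuming `BAAI/AltCLIP` hosts both components:

```python
# Hedged sketch of the deprecation path; assumes BAAI/AltCLIP hosts both
# the image processor and the tokenizer.
from transformers import AltCLIPProcessor, AutoImageProcessor, AutoTokenizer

image_processor = AutoImageProcessor.from_pretrained("BAAI/AltCLIP")
tokenizer = AutoTokenizer.from_pretrained("BAAI/AltCLIP")

# Deprecated spelling: still accepted, but warns and maps to `image_processor`.
processor = AltCLIPProcessor(feature_extractor=image_processor, tokenizer=tokenizer)

# Preferred spelling, warning-free.
processor = AltCLIPProcessor(image_processor=image_processor, tokenizer=tokenizer)
```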
@@ -231,6 +231,7 @@ def preprocess(
                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
         """
+
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
         size = get_size_dict(size, default_to_square=False)
43 changes: 31 additions & 12 deletions src/transformers/models/chinese_clip/processing_chinese_clip.py
@@ -17,9 +17,15 @@
 """
 
 import warnings
+from typing import List, Union
 
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
+
+
+class ChineseClipProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {}
 
 
 class ChineseCLIPProcessor(ProcessorMixin):
@@ -60,7 +66,14 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
         super().__init__(image_processor, tokenizer)
         self.current_processor = self.image_processor
 
-    def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
+    def __call__(
+        self,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        images: ImageInput = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[ChineseClipProcessorKwargs],
+    ) -> BatchEncoding:
         """
         Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
         and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode
@@ -79,12 +92,10 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
             return_tensors (`str` or [`~utils.TensorType`], *optional*):
                 If set, will return tensors of a particular framework. Acceptable values are:
 
-                    - `'tf'`: Return TensorFlow `tf.constant` objects.
-                    - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                    - `'np'`: Return NumPy `np.ndarray` objects.
-                    - `'jax'`: Return JAX `jnp.ndarray` objects.
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return NumPy `np.ndarray` objects.
+                - `'jax'`: Return JAX `jnp.ndarray` objects.
 
         Returns:
             [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
@@ -97,12 +108,20 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
 
         if text is None and images is None:
             raise ValueError("You have to specify either text or images. Both cannot be none.")
+        output_kwargs = self._merge_kwargs(
+            ChineseClipProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
 
         if text is not None:
-            encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
-
+            encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
         if images is not None:
-            image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
+            image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
+
+        # BC for explicit return_tensors
+        if "return_tensors" in output_kwargs["common_kwargs"]:
+            return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)
 
         if text is not None and images is not None:
             encoding["pixel_values"] = image_features.pixel_values
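
The `# BC for explicit return_tensors` branch keeps the old behavior: a flat `return_tensors` lands in `common_kwargs`, is popped there, and is applied when the combined `BatchEncoding` is built. A short sketch, assuming the public `OFA-Sys/chinese-clip-vit-base-patch16` checkpoint:

```python
# Sketch of the preserved behavior: a flat `return_tensors` applies to
# both modalities' outputs. Assumes the OFA-Sys checkpoint below.
import numpy as np
from PIL import Image
from transformers import ChineseCLIPProcessor

processor = ChineseCLIPProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))  # dummy image

batch = processor(text=["一张猫的照片"], images=image, return_tensors="pt")
print(type(batch.input_ids), type(batch.pixel_values))  # both torch.Tensor
```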
14 changes: 13 additions & 1 deletion src/transformers/processing_utils.py
@@ -20,11 +20,14 @@
 import inspect
 import json
 import os
+import sys
+import typing
 import warnings
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
 
 import numpy as np
+import typing_extensions
 
 from .dynamic_module_utils import custom_object_save
 from .image_utils import ChannelDimension, is_valid_image, is_vision_available
@@ -67,6 +70,11 @@
     "AutoImageProcessor": "ImageProcessingMixin",
 }
 
+if sys.version_info >= (3, 11):
+    Unpack = typing.Unpack
+else:
+    Unpack = typing_extensions.Unpack
+
 
 class TextKwargs(TypedDict, total=False):
     """
@@ -151,6 +159,8 @@ class methods and docstrings.
             Standard deviation to use if normalizing the image.
         do_pad (`bool`, *optional*):
             Whether to pad the image to the `(max_height, max_width)` of the images in the batch.
+        pad_size (`Dict[str, int]`, *optional*):
+            The size `{"height": int, "width": int}` to pad the images to.
         do_center_crop (`bool`, *optional*):
             Whether to center crop the image.
         data_format (`ChannelDimension` or `str`, *optional*):
@@ -170,6 +180,7 @@
     image_mean: Optional[Union[float, List[float]]]
     image_std: Optional[Union[float, List[float]]]
     do_pad: Optional[bool]
+    pad_size: Optional[Dict[str, int]]
     do_center_crop: Optional[bool]
     data_format: Optional[ChannelDimension]
     input_data_format: Optional[Union[str, ChannelDimension]]
@@ -814,7 +825,8 @@ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwargs
             # check if this key was passed as a flat kwarg.
             if kwarg_value != "__empty__" and modality_key in non_modality_kwargs:
                 raise ValueError(
-                    f"Keyword argument {modality_key} was passed two times: in a dictionary for {modality} and as a **kwarg."
+                    f"Keyword argument {modality_key} was passed two times:\n"
+                    f"in a dictionary for {modality} and as a **kwarg."
                 )
             elif modality_key in kwargs:
                 kwarg_value = kwargs.pop(modality_key, "__empty__")
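
The reworded `ValueError` is the guard from the "Raise if kwargs are specified in two places" commit: a key passed both flat and inside a modality dict is rejected rather than silently resolved. A sketch of the failure mode, again assuming the `BAAI/AltCLIP` checkpoint:

```python
# Sketch of the collision guard: the same key passed flat and nested now
# raises instead of silently picking one. Assumes BAAI/AltCLIP.
from transformers import AltCLIPProcessor

processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP")

try:
    processor(
        text=["a photo of a cat"],
        padding=True,                           # flat kwarg...
        text_kwargs={"padding": "max_length"},  # ...and the same key, nested
    )
except ValueError as err:
    print(err)  # Keyword argument padding was passed two times: ...
```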