add initial design for uniform processors + align model (huggingface#31197)

* add initial design for uniform processors + align model

* fix mutable default 👀

* add configuration test

* handle structured kwargs w defaults + add test

* protect torch-specific test

* fix style

* fix

* fix assertEqual

* move kwargs merging to processing common

* rework kwargs for type hinting

* just get Unpack from extensions

* run-slow[align]

* handle kwargs passed as nested dict

* add from_pretrained test for nested kwargs handling

* [run-slow]align

* update documentation + imports

* update audio inputs

* protect audio types, silly

* try removing imports

* make things simpler

* simplerer

* move out kwargs test to common mixin

* [run-slow]align

* skip tests for old processors

* [run-slow]align, clip

* !$#@!! protect imports, darn it

* [run-slow]align, clip

* [run-slow]align, clip

* update doc

* improve documentation for default values

* add model_max_length testing

This parameter depends on the tokenizers received.

* Raise if kwargs are specified in two places

* fix

* expand VideoInput

* fix

* fix style

* remove defaults values

* add comment to indicate documentation on adding kwargs

* protect imports

* [run-slow]align

* fix

* remove set() that breaks ordering

* test more

* removed unused func

* [run-slow]align
molbap authored and zucchini-nlp committed Jun 14, 2024
1 parent 3b52c27 commit aaa7b56
Showing 6 changed files with 676 additions and 28 deletions.
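Taken together, the commit messages describe a processor `__call__` that accepts kwargs either flat (kept for backward compatibility) or grouped per modality, with class-level defaults and a shared merge step in processing_utils. Below is a minimal usage sketch of both conventions; the checkpoint id comes from the docstring example in the diff, and the claim that conflicting kwargs raise is inferred from the "Raise if kwargs are specified in two places" message, not verified against the final API.

```python
# A sketch of the two calling conventions, inferred from the commit messages
# and the docstring example below; not an authoritative API reference.
from PIL import Image

from transformers import AlignProcessor

processor = AlignProcessor.from_pretrained("kakaobrain/align-base")
image = Image.new("RGB", (224, 224))

# 1) Flat kwargs, as before:
inputs = processor(text=["What is that?"], images=image, padding="longest", return_tensors="pt")

# 2) Structured kwargs, one dict per modality:
inputs = processor(
    text=["What is that?"],
    images=image,
    text_kwargs={"padding": "longest"},
    images_kwargs={"crop_size": {"height": 224, "width": 224}},
    common_kwargs={"return_tensors": "pt"},
)

# Passing the same key both flat and inside a modality dict is expected to raise.
```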
11 changes: 10 additions & 1 deletion src/transformers/image_utils.py
@@ -81,7 +81,16 @@
] # noqa


-VideoInput = Union[np.ndarray, "torch.Tensor", List[np.ndarray], List["torch.Tensor"]]  # noqa
+VideoInput = Union[
+    List["PIL.Image.Image"],
+    "np.ndarray",
+    "torch.Tensor",
+    List["np.ndarray"],
+    List["torch.Tensor"],
+    List[List["PIL.Image.Image"]],
+    List[List["np.ndarray"]],
+    List[List["torch.Tensor"]],
+]  # noqa


class ChannelDimension(ExplicitEnum):
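For illustration, a few values the widened `VideoInput` union now admits; a minimal sketch assuming uint8 frames in height-width-channels layout:

```python
import numpy as np
from PIL import Image

frame = np.zeros((224, 224, 3), dtype=np.uint8)  # one H x W x C frame

pil_video = [Image.fromarray(frame) for _ in range(8)]    # List["PIL.Image.Image"]
array_video = np.zeros((8, 224, 224, 3), dtype=np.uint8)  # "np.ndarray" (stacked frames)
video_batch = [pil_video, pil_video]                      # List[List["PIL.Image.Image"]]
```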
90 changes: 67 additions & 23 deletions src/transformers/models/align/processing_align.py
@@ -16,8 +16,30 @@
Image/Text processor class for ALIGN
"""

-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding
+from typing import List, Union
+
+
+try:
+    from typing import Unpack
+except ImportError:
+    from typing_extensions import Unpack
+
+from ...image_utils import ImageInput
+from ...processing_utils import (
+    ProcessingKwargs,
+    ProcessorMixin,
+)
+from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
+
+
+class AlignProcessorKwargs(ProcessingKwargs, total=False):
+    # see processing_utils.ProcessingKwargs documentation for usage.
+    _defaults = {
+        "text_kwargs": {
+            "padding": "max_length",
+            "max_length": 64,
+        },
+    }


class AlignProcessor(ProcessorMixin):
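How these `_defaults` get resolved is not visible in this diff: they are consumed by `self._merge_kwargs` later in `__call__` (see below), and the merge logic itself lives in `processing_utils.py`, which is not shown in this view. The standalone snippet below models the apparent precedence, lowest to highest: class defaults, then tokenizer init kwargs, then call-time kwargs. An illustration, not the library's implementation.

```python
# Simplified model of the kwargs merge; the function name is illustrative only.
def merge_text_kwargs(class_defaults, tokenizer_init_kwargs, call_kwargs):
    merged = dict(class_defaults)         # e.g. AlignProcessorKwargs._defaults["text_kwargs"]
    merged.update(tokenizer_init_kwargs)  # values carried by the tokenizer, e.g. model_max_length
    merged.update(call_kwargs)            # explicit call-time kwargs win
    return merged


print(merge_text_kwargs(
    {"padding": "max_length", "max_length": 64},
    {"model_max_length": 512},
    {"padding": "do_not_pad"},
))
# {'padding': 'do_not_pad', 'max_length': 64, 'model_max_length': 512}
```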
@@ -26,12 +48,28 @@ class AlignProcessor(ProcessorMixin):
    [`BertTokenizer`]/[`BertTokenizerFast`] into a single processor that inherits both the image processor and
    tokenizer functionalities. See the [`~AlignProcessor.__call__`] and [`~AlignProcessor.decode`] for more
    information.
+    The preferred way of passing kwargs is as a dictionary per modality, see usage example below.
+    ```python
+    from transformers import AlignProcessor
+    from PIL import Image
+    model_id = "kakaobrain/align-base"
+    processor = AlignProcessor.from_pretrained(model_id)
+    processor(
+        images=your_pil_image,
+        text=["What is that?"],
+        images_kwargs={"crop_size": {"height": 224, "width": 224}},
+        text_kwargs={"padding": "do_not_pad"},
+        common_kwargs={"return_tensors": "pt"},
+    )
+    ```
    Args:
        image_processor ([`EfficientNetImageProcessor`]):
            The image processor is a required input.
        tokenizer ([`BertTokenizer`, `BertTokenizerFast`]):
            The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
@@ -41,11 +79,18 @@ class AlignProcessor(ProcessorMixin):
    def __init__(self, image_processor, tokenizer):
        super().__init__(image_processor, tokenizer)

-    def __call__(self, text=None, images=None, padding="max_length", max_length=64, return_tensors=None, **kwargs):
+    def __call__(
+        self,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        images: ImageInput = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[AlignProcessorKwargs],
+    ) -> BatchEncoding:
"""
Main method to prepare text(s) and image(s) to be fed as input to the model. This method forwards the `text`
-        and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode
-        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+        arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the image(s), this method forwards the `images` arguments to
        EfficientNetImageProcessor's [`~EfficientNetImageProcessor.__call__`] if `images` is not `None`. Please refer
        to the docstring of the above two methods for more information.
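The `**kwargs: Unpack[AlignProcessorKwargs]` annotation in the new signature is what makes the kwargs type-checkable: `AlignProcessorKwargs` is a `TypedDict` subclass (note the `total=False` above), and PEP 692's `Unpack` lets static checkers validate keyword names and value types. A standalone sketch of the pattern, independent of transformers:

```python
from typing import TypedDict

try:
    from typing import Unpack  # in the stdlib since Python 3.11
except ImportError:
    from typing_extensions import Unpack  # fallback on older interpreters


class TextKwargs(TypedDict, total=False):
    padding: str
    max_length: int


def encode(text: str, **kwargs: Unpack[TextKwargs]) -> None:
    # A type checker flags unknown keys, e.g. encode("hi", paddng="longest").
    print(text, kwargs)


encode("What is that?", padding="max_length", max_length=64)
```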
@@ -57,20 +102,12 @@ def __call__(self, text=None, images=None, padding="max_length", max_length=64,
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
-            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `max_length`):
-                Activates and controls padding for tokenization of input text. Choose between [`True` or `'longest'`,
-                `'max_length'`, `False` or `'do_not_pad'`]
-            max_length (`int`, *optional*, defaults to `max_length`):
-                Maximum padding value to use to pad the input text during tokenization.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
-                    - `'tf'`: Return TensorFlow `tf.constant` objects.
-                    - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                    - `'np'`: Return NumPy `np.ndarray` objects.
-                    - `'jax'`: Return JAX `jnp.ndarray` objects.
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return NumPy `np.ndarray` objects.
+                - `'jax'`: Return JAX `jnp.ndarray` objects.
        Returns:
            [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
@@ -81,15 +118,22 @@ def __call__(self, text=None, images=None, padding="max_length", max_length=64,
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        """
        if text is None and images is None:
-            raise ValueError("You have to specify either text or images. Both cannot be none.")
-
+            raise ValueError("You must specify either text or images.")
+        output_kwargs = self._merge_kwargs(
+            AlignProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+        # then, we can pass correct kwargs to each processor
        if text is not None:
-            encoding = self.tokenizer(
-                text, padding=padding, max_length=max_length, return_tensors=return_tensors, **kwargs
-            )
+            encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])

        if images is not None:
-            image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
+            image_features = self.image_processor(images, **output_kwargs["images_kwargs"])

+        # BC for explicit return_tensors
+        if "return_tensors" in output_kwargs["common_kwargs"]:
+            return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)

        if text is not None and images is not None:
            encoding["pixel_values"] = image_features.pixel_values