Add image text to text pipeline #34170

Merged
Commits (45)
de37f38
Standardize image-text-to-text-models-output
yonigozlan Aug 6, 2024
a7e56fa
nit var name post_process_image_text_to_text udop
yonigozlan Oct 11, 2024
9161a52
nit fix deprecation warnings
yonigozlan Oct 11, 2024
04d4f77
Add image-text-to-text pipeline
yonigozlan Oct 11, 2024
0aabcb2
add support for image url in chat template for pipeline
yonigozlan Oct 14, 2024
5189f16
Reformat to be fully compatible with chat templates
yonigozlan Oct 14, 2024
a0c9075
Add tests chat template
yonigozlan Oct 15, 2024
726933f
Fix imports and tests
yonigozlan Oct 15, 2024
b1d7a34
Add pipeline tag
yonigozlan Oct 15, 2024
f92628c
change logic handling of single prompt and multiple images
yonigozlan Oct 15, 2024
fa25411
add pipeline mapping to models
yonigozlan Oct 15, 2024
9fe26c7
fix batched inference
yonigozlan Oct 15, 2024
316cf7d
fix tests
yonigozlan Oct 15, 2024
c633b4b
Add manual batching for preprocessing
yonigozlan Oct 16, 2024
d6598da
Fix outputs with nested images
yonigozlan Oct 16, 2024
c8e5802
Add support for all common processing kwargs
yonigozlan Oct 17, 2024
20cdd5a
Add default padding when multiple text inputs (batch size>1)
yonigozlan Oct 17, 2024
5bc43be
nit change version deprecation warning
yonigozlan Oct 17, 2024
6cccf5f
Add support for text only inference
yonigozlan Oct 17, 2024
ba8f85f
add chat_template warnings
yonigozlan Oct 21, 2024
00174e8
Add pipeline tests and add copied from post process function
yonigozlan Oct 24, 2024
8a65ea4
Fix batched pipeline tests
yonigozlan Oct 24, 2024
da05987
nit
yonigozlan Oct 24, 2024
1f2dafb
Fix pipeline tests blip2
yonigozlan Oct 24, 2024
d66e523
remove unnecessary max_new_tokens
yonigozlan Oct 24, 2024
5056aa5
revert processing kosmos2 and remove unnecessary max_new_tokens
yonigozlan Oct 24, 2024
fe7e75d
fix pipeline tests idefics
yonigozlan Oct 24, 2024
b866c27
Force try loading processor if pipeline supports it
yonigozlan Oct 25, 2024
3118dac
revert load_processor change
yonigozlan Oct 25, 2024
065542a
hardcode loading only processor
yonigozlan Oct 25, 2024
7f583df
remove unnecessary try except
yonigozlan Oct 25, 2024
aad9ad4
skip imagetexttotext tests for kosmos2 as tiny model causes problems
yonigozlan Oct 25, 2024
e227b83
Make code clearer
yonigozlan Oct 25, 2024
8f370f4
Address review comments
yonigozlan Oct 25, 2024
c82fe29
remove preprocessing logic from pipeline
yonigozlan Oct 28, 2024
7e1fb07
fix fuyu
yonigozlan Oct 28, 2024
f581eaa
add BC resize fuyu
yonigozlan Oct 28, 2024
4eda963
Move post_process_image_text_to_text to ProcessorMixin
yonigozlan Oct 28, 2024
0263221
add guard in post_process
yonigozlan Oct 28, 2024
45c1706
fix zero shot object detection pipeline
yonigozlan Oct 29, 2024
2e69b97
add support for generator input in pipeline
yonigozlan Oct 29, 2024
58a6fb8
nit
yonigozlan Oct 29, 2024
66c017c
change default image-text-to-text model to llava onevision
yonigozlan Oct 29, 2024
5772312
fix owlv2 size dict
yonigozlan Oct 30, 2024
61cc576
Change legacy deprecation warning to only show when True
yonigozlan Oct 31, 2024
6 changes: 6 additions & 0 deletions docs/source/en/main_classes/pipelines.md
@@ -478,6 +478,12 @@ Pipelines available for multimodal tasks include the following.
- __call__
- all

### ImageTextToTextPipeline

[[autodoc]] ImageTextToTextPipeline
- __call__
- all

### MaskGenerationPipeline

[[autodoc]] MaskGenerationPipeline
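For context, a minimal usage sketch of the pipeline these docs describe. The checkpoint is illustrative (the commit log only says the default was switched to a LLaVA-OneVision model), and the chat-message format follows the chat-template support added in this PR:

```python
from transformers import pipeline

# Illustrative checkpoint; any image-text-to-text model with a processor should work.
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
            {"type": "text", "text": "Describe this image in one sentence."},
        ],
    }
]
outputs = pipe(text=messages, max_new_tokens=30)
print(outputs[0]["generated_text"])
```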
6 changes: 6 additions & 0 deletions docs/source/ja/main_classes/pipelines.md
@@ -481,6 +481,12 @@ my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline)
- __call__
- all

### ImageTextToTextPipeline

[[autodoc]] ImageTextToTextPipeline
- __call__
- all

### VisualQuestionAnsweringPipeline

[[autodoc]] VisualQuestionAnsweringPipeline
6 changes: 6 additions & 0 deletions docs/source/zh/main_classes/pipelines.md
@@ -455,6 +455,12 @@ See [`TokenClassificationPipeline`] for all details.
- __call__
- all

### ImageTextToTextPipeline

[[autodoc]] ImageTextToTextPipeline
- __call__
- all

### MaskGenerationPipeline

[[autodoc]] MaskGenerationPipeline
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -868,6 +868,7 @@
"ImageClassificationPipeline",
"ImageFeatureExtractionPipeline",
"ImageSegmentationPipeline",
"ImageTextToTextPipeline",
"ImageToImagePipeline",
"ImageToTextPipeline",
"JsonPipelineDataFormat",
@@ -5794,6 +5795,7 @@
ImageClassificationPipeline,
ImageFeatureExtractionPipeline,
ImageSegmentationPipeline,
ImageTextToTextPipeline,
ImageToImagePipeline,
ImageToTextPipeline,
JsonPipelineDataFormat,
21 changes: 21 additions & 0 deletions src/transformers/image_utils.py
@@ -385,6 +385,27 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
return image


def load_images(
images: Union[List, Tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None
) -> Union["PIL.Image.Image", List["PIL.Image.Image"], List[List["PIL.Image.Image"]]]:
"""Loads images, handling different levels of nesting.

Args:
images: A single image, a list of images, or a list of lists of images to load.
timeout: Timeout for loading images.

Returns:
A single image, a list of images, a list of lists of images.
"""
if isinstance(images, (list, tuple)):
if len(images) and isinstance(images[0], (list, tuple)):
return [[load_image(image, timeout=timeout) for image in image_group] for image_group in images]
else:
return [load_image(image, timeout=timeout) for image in images]
else:
return load_image(images, timeout=timeout)


def validate_preprocess_arguments(
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
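A short sketch of how the nesting rules in `load_images` play out; the file names are placeholders:

```python
from transformers.image_utils import load_images

# Single input -> a single PIL.Image.Image
image = load_images("cat.png")

# Flat list -> a list of images
images = load_images(["cat.png", "dog.png"])

# Nested list (e.g. one group of images per prompt) -> a list of lists of images
image_groups = load_images([["cat.png", "dog.png"], ["bird.png"]])
```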
1 change: 1 addition & 0 deletions src/transformers/models/auto/image_processing_auto.py
@@ -114,6 +114,7 @@
("oneformer", ("OneFormerImageProcessor",)),
("owlv2", ("Owlv2ImageProcessor",)),
("owlvit", ("OwlViTImageProcessor",)),
("paligemma", ("SiglipImageProcessor",)),
("perceiver", ("PerceiverImageProcessor",)),
("pix2struct", ("Pix2StructImageProcessor",)),
("pixtral", ("PixtralImageProcessor",)),
16 changes: 16 additions & 0 deletions src/transformers/models/donut/processing_donut.py
@@ -24,12 +24,16 @@
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging


class DonutProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {}


logger = logging.get_logger(__name__)


class DonutProcessor(ProcessorMixin):
r"""
Constructs a Donut processor which wraps a Donut image processor and an XLMRoBERTa tokenizer into a single
@@ -85,6 +89,16 @@ def __call__(
[`~DonutTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
"""
# For backward compatibility
legacy = kwargs.pop("legacy", True)
if legacy:
# With `add_special_tokens=True`, the performance of Donut is degraded when working with both images and text.
logger.warning_once(
"Legacy behavior is being used. The current behavior will be deprecated in version 5.0.0. "
"In the new behavior, if both images and text are provided, the default value of `add_special_tokens` "
"will be changed to `False` when calling the tokenizer if `add_special_tokens` is unset. "
"To test the new behavior, set `legacy=False`as a processor call argument."
)

if self._in_target_context_manager:
return self.current_processor(images, text, **kwargs)

@@ -100,6 +114,8 @@ def __call__(
if images is not None:
inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
if text is not None:
if not legacy and images is not None:
output_kwargs["text_kwargs"].setdefault("add_special_tokens", False)
encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])

if text is None:
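A hedged sketch of the Donut `legacy` switch; the checkpoint, task prompt, and dummy image are illustrative:

```python
import numpy as np
from PIL import Image
from transformers import DonutProcessor

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")  # illustrative checkpoint
image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))  # dummy image

# Default today (legacy=True): the tokenizer keeps add_special_tokens=True and warns.
legacy_inputs = processor(images=image, text="<s_cord-v2>")

# New behavior: with both images and text, add_special_tokens defaults to False.
new_inputs = processor(images=image, text="<s_cord-v2>", legacy=False)
```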
3 changes: 2 additions & 1 deletion src/transformers/models/fuyu/image_processing_fuyu.py
@@ -19,7 +19,7 @@

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
pad,
resize,
@@ -475,6 +475,7 @@ def preprocess(
input_data_format = infer_channel_dimension_format(batch_images[0][0])

original_image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
size = get_size_dict(size) # for BC

if do_resize:
batch_images = [
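The `get_size_dict` call exists so older callers that pass an int or a `(height, width)` tuple keep working; a sketch of the conversion, based on the helper's documented defaults:

```python
from transformers.image_processing_utils import get_size_dict

get_size_dict({"height": 1080, "width": 1920})  # already a dict: validated and returned as-is
get_size_dict((1080, 1920))  # (height, width) tuple -> {"height": 1080, "width": 1920}
get_size_dict(512)           # int, default_to_square=True -> {"height": 512, "width": 512}
```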
30 changes: 28 additions & 2 deletions src/transformers/models/fuyu/processing_fuyu.py
@@ -264,10 +264,10 @@ def _tokenize_prompts_with_image_and_batch(
bos_token = tokenizer.vocab["|ENDOFTEXT|"]
prompts_tokens = [[[bos_token] + x for x in prompt_seq] for prompt_seq in prompts_tokens]
if add_beginning_of_answer_token:
boa = tokenizer.vocab[BEGINNING_OF_ANSWER_STRING]
beginning_of_answer = tokenizer.vocab[BEGINNING_OF_ANSWER_STRING]
# Only add bbox open token to the last subsequence since that is what will be completed
for token_seq in prompts_tokens:
token_seq[-1].append(boa)
token_seq[-1].append(beginning_of_answer)

# Now we have a list of list of tokens which each list has a different
# size. We want to extend this list to:
@@ -682,6 +682,32 @@

return results

def post_process_image_text_to_text(self, generated_outputs):
"""
Post-processes the output of `FuyuForConditionalGeneration` to only return the text output.

Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
containing the token ids of the generated sequences.

Returns:
`List[str]`: The decoded text output.
"""
beginning_of_answer = self.tokenizer.convert_tokens_to_ids(BEGINNING_OF_ANSWER_STRING)
# get boa index for each outputted sequence tensor
# start all generated sequences from the beginning of the answer token, pad to have consistent length
unpadded_output_sequences = [
seq[(seq == beginning_of_answer).nonzero(as_tuple=True)[0] + 1 :] for seq in generated_outputs
]
max_len = max(len(seq) for seq in unpadded_output_sequences)
# convert to torch and pad sequences
padded_output_sequences = torch.full((len(unpadded_output_sequences), max_len), self.pad_token_id)
for i, seq in enumerate(unpadded_output_sequences):
padded_output_sequences[i, : len(seq)] = torch.tensor(seq)

return self.batch_decode(padded_output_sequences, skip_special_tokens=True)

def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
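A toy restatement of the slicing in `post_process_image_text_to_text`, with stand-in token ids, showing how each sequence is cut at the beginning-of-answer token and re-padded:

```python
import torch

boa, pad = 5, 0  # stand-in ids for the beginning-of-answer and pad tokens
generated = torch.tensor([[9, 9, boa, 11, 12, 13],
                          [9, boa, 21, 22, pad, pad]])

# Keep everything after the (single) boa occurrence in each sequence.
unpadded = [seq[(seq == boa).nonzero(as_tuple=True)[0] + 1 :] for seq in generated]

max_len = max(len(seq) for seq in unpadded)
padded = torch.full((len(unpadded), max_len), pad)
for i, seq in enumerate(unpadded):
    padded[i, : len(seq)] = seq
# padded -> tensor([[11, 12, 13,  0],
#                   [21, 22,  0,  0]])
```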
17 changes: 17 additions & 0 deletions src/transformers/models/git/processing_git.py
@@ -22,12 +22,16 @@
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging


class GitProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {}


logger = logging.get_logger(__name__)


class GitProcessor(ProcessorMixin):
r"""
Constructs a GIT processor which wraps a CLIP image processor and a BERT tokenizer into a single processor.
@@ -91,6 +95,15 @@ def __call__(
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
legacy = kwargs.pop("legacy", True)
if legacy:
logger.warning_once(
"Legacy behavior is being used. The current behavior will be deprecated in version 5.0.0. "
"In the new behavior, if both images and text are provided, the last token (EOS token) "
"of the input_ids and attention_mask tensors will be removed. "
"To test the new behavior, set `legacy=False`as a processor call argument."
)

if text is None and images is None:
raise ValueError("You have to specify either text or images. Both cannot be none.")

@@ -110,6 +123,10 @@ def __call__(
if images is not None:
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
data.update(image_features)
if not legacy:
data["input_ids"] = data["input_ids"][:, :-1]
data["attention_mask"] = data["attention_mask"][:, :-1]

return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors"))

def batch_decode(self, *args, **kwargs):
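A hedged sketch of what `legacy=False` changes for GIT, assuming the tokenizer appends a trailing [SEP]/EOS token; the checkpoint and dummy image are illustrative:

```python
import numpy as np
from PIL import Image
from transformers import GitProcessor

processor = GitProcessor.from_pretrained("microsoft/git-base")  # illustrative checkpoint
image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))

legacy_inputs = processor(images=image, text="a photo of", return_tensors="pt")
new_inputs = processor(images=image, text="a photo of", legacy=False, return_tensors="pt")

# With legacy=False the trailing EOS is dropped, so generation continues the prompt.
assert new_inputs["input_ids"].shape[-1] == legacy_inputs["input_ids"].shape[-1] - 1
```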
15 changes: 15 additions & 0 deletions src/transformers/models/kosmos2/processing_kosmos2.py
@@ -428,6 +428,21 @@ def post_process_generation(self, text, cleanup_and_extract=True):
return clean_text_and_extract_entities_with_bboxes(caption)
return caption

def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.

Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.

Returns:
`List[str]`: The decoded text.
"""
generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=True)
return [self.post_process_generation(text, cleanup_and_extract=False) for text in generated_texts]

@property
# Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
def model_input_names(self):
16 changes: 16 additions & 0 deletions src/transformers/models/mllama/processing_mllama.py
@@ -342,6 +342,22 @@ def decode(self, *args, **kwargs):
"""
return self.tokenizer.decode(*args, **kwargs)

def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.

Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.

Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
)

@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
3 changes: 2 additions & 1 deletion src/transformers/models/owlv2/image_processing_owlv2.py
@@ -19,7 +19,7 @@

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
center_to_corners_format,
pad,
@@ -399,6 +399,7 @@ def preprocess(
image_std = image_std if image_std is not None else self.image_std

size = size if size is not None else self.size
size = get_size_dict(size) # for BC

images = make_list_of_images(images)

20 changes: 20 additions & 0 deletions src/transformers/models/pix2struct/processing_pix2struct.py
@@ -21,6 +21,7 @@
from ...feature_extraction_utils import BatchFeature
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
from ...utils import logging


class Pix2StructImagesKwargs(ImagesKwargs, total=False):
@@ -48,6 +49,9 @@ class Pix2StructProcessorKwargs(ProcessingKwargs, total=False):
}


logger = logging.get_logger(__name__)


class Pix2StructProcessor(ProcessorMixin):
r"""
Constructs a PIX2STRUCT processor which wraps a BERT tokenizer and PIX2STRUCT image processor into a single
@@ -85,6 +89,15 @@ def __call__(

Please refer to the docstring of the above two methods for more information.
"""
legacy = kwargs.pop("legacy", True)
if legacy:
logger.warning_once(
"Legacy behavior is being used. The current behavior will be deprecated in version 5.0.0. "
"In the new behavior, If both images and text are provided, image_processor is not a VQA processor, and `add_special_tokens` is unset, "
"the default value of `add_special_tokens` will be changed to `False` when calling the tokenizer. "
"To test the new behavior, set `legacy=False`as a processor call argument."
)

if images is None and text is None:
raise ValueError("You have to specify either images or text.")

@@ -93,8 +106,12 @@
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
add_special_tokens = output_kwargs["text_kwargs"].pop("add_special_tokens", None)
# Get only text
if images is None and not self.image_processor.is_vqa:
output_kwargs["text_kwargs"]["add_special_tokens"] = (
add_special_tokens if add_special_tokens is not None else True
)
self.current_processor = self.tokenizer
text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
return text_encoding
@@ -108,6 +125,9 @@ def __call__(
encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])

if text is not None and not self.image_processor.is_vqa:
output_kwargs["text_kwargs"]["add_special_tokens"] = (
add_special_tokens if add_special_tokens is not None else legacy
)
text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])

if "attention_mask" in text_encoding:
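The two hunks above implement a single resolution rule for `add_special_tokens` on Pix2Struct's non-VQA paths; a hedged restatement in plain Python (the helper name is illustrative):

```python
def resolve_add_special_tokens(user_value, text_only, legacy):
    """How the non-VQA branches above pick add_special_tokens (illustrative helper)."""
    if user_value is not None:
        return user_value  # an explicit caller setting always wins
    if text_only:
        return True        # text-only path keeps the old default
    return legacy          # images + text: True in legacy mode, False in the new behavior
```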
16 changes: 16 additions & 0 deletions src/transformers/models/qwen2_vl/processing_qwen2_vl.py
@@ -168,6 +168,22 @@ def decode(self, *args, **kwargs):
"""
return self.tokenizer.decode(*args, **kwargs)

def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.

Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.

Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
)

@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names