Skip to content

Commit

Permalink
Make pipeline able to load processor (#32514)
Browse files Browse the repository at this point in the history
* Refactor get_test_pipeline

* Fixup

* Fixing tests

* Add processor loading in tests

* Restructure processors loading

* Add processor to the pipeline

* Move model loading on top of the test

* Update `get_test_pipeline`

* Fixup

* Add class-based flags for loading processors

* Change `is_pipeline_test_to_skip` signature

* Skip t5 failing test for slow tokenizer

* Fixup

* Fix copies for T5

* Fix typo

* Add try/except for tokenizer loading (kosmos-2 case)

* Fixup

* Llama does not fail for long generation

* Revert processor pass in text-generation test

* Fix docs

* Switch back to json file for image processors and feature extractors

* Add processor type check

* Remove except for tokenizers

* Fix docstring

* Fix empty lists for tests

* Fixup

* Fix load check

* Ensure we have non-empty test cases

* Update src/transformers/pipelines/__init__.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Update src/transformers/pipelines/base.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Rework comment

* Better docs, add note about pipeline components

* Change warning to error raise

* Fixup

* Refine pipeline docs

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
  • Loading branch information
qubvel and LysandreJik authored Oct 9, 2024
1 parent 4fb2870 commit 48461c0
Show file tree
Hide file tree
Showing 91 changed files with 1,311 additions and 240 deletions.
80 changes: 71 additions & 9 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
from ..models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING, AutoImageProcessor
from ..models.auto.modeling_auto import AutoModelForDepthEstimation, AutoModelForImageToImage
from ..models.auto.processing_auto import PROCESSOR_MAPPING, AutoProcessor
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from ..processing_utils import ProcessorMixin
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import (
CONFIG_NAME,
Expand Down Expand Up @@ -556,6 +558,7 @@ def pipeline(
tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None,
feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
image_processor: Optional[Union[str, BaseImageProcessor]] = None,
processor: Optional[Union[str, ProcessorMixin]] = None,
framework: Optional[str] = None,
revision: Optional[str] = None,
use_fast: bool = True,
Expand All @@ -571,11 +574,19 @@ def pipeline(
"""
Utility factory method to build a [`Pipeline`].
Pipelines are made of:
A pipeline consists of:
- A [tokenizer](tokenizer) in charge of mapping raw textual input to token.
- A [model](model) to make predictions from the inputs.
- Some (optional) post processing for enhancing model's output.
- One or more components for pre-processing model inputs, such as a [tokenizer](tokenizer),
[image_processor](image_processor), [feature_extractor](feature_extractor), or [processor](processors).
- A [model](model) that generates predictions from the inputs.
- Optional post-processing steps to refine the model's output, which can also be handled by processors.
<Tip>
While there are optional arguments such as `tokenizer`, `feature_extractor`, `image_processor`, and `processor`,
they shouldn't all be specified at once. If these components are not provided, `pipeline` will try to load the
required ones automatically. If you want to provide these components explicitly, please refer to the
specific pipeline for more details on which components are required.
</Tip>
Args:
task (`str`):
Expand Down Expand Up @@ -644,6 +655,25 @@ def pipeline(
`model` is not specified or not a string, then the default feature extractor for `config` is loaded (if it
is a string). However, if `config` is also not given or not a string, then the default feature extractor
for the given `task` will be loaded.
image_processor (`str` or [`BaseImageProcessor`], *optional*):
The image processor that will be used by the pipeline to preprocess images for the model. This can be a
model identifier or an actual image processor inheriting from [`BaseImageProcessor`].
Image processors are used for Vision models and multi-modal models that require image inputs. Multi-modal
models will also require a tokenizer to be passed.
If not provided, the default image processor for the given `model` will be loaded (if it is a string). If
`model` is not specified or not a string, then the default image processor for `config` is loaded (if it is
a string).
processor (`str` or [`ProcessorMixin`], *optional*):
The processor that will be used by the pipeline to preprocess data for the model. This can be a model
identifier or an actual processor inheriting from [`ProcessorMixin`].
Processors are used for multi-modal models that require multi-modal inputs, for example, a model that
requires both text and image inputs.
If not provided, the default processor for the given `model` will be loaded (if it is a string). If `model`
is not specified or not a string, then the default processor for `config` is loaded (if it is a string).
framework (`str`, *optional*):
The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be
installed.
Expand Down Expand Up @@ -905,13 +935,17 @@ def pipeline(

model_config = model.config
hub_kwargs["_commit_hash"] = model.config._commit_hash
load_tokenizer = (
type(model_config) in TOKENIZER_MAPPING
or model_config.tokenizer_class is not None
or isinstance(tokenizer, str)
)

load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
load_image_processor = type(model_config) in IMAGE_PROCESSOR_MAPPING or image_processor is not None
load_processor = type(model_config) in PROCESSOR_MAPPING or processor is not None

# Check that the pipeline class requires loading each component
load_tokenizer = load_tokenizer and pipeline_class._load_tokenizer
load_feature_extractor = load_feature_extractor and pipeline_class._load_feature_extractor
load_image_processor = load_image_processor and pipeline_class._load_image_processor
load_processor = load_processor and pipeline_class._load_processor

# If `model` (instance of `PretrainedModel` instead of `str`) is passed (and/or same for config), while
# `image_processor` or `feature_extractor` is `None`, the loading will fail. This happens particularly for some
Expand Down Expand Up @@ -1074,6 +1108,31 @@ def pipeline(
if not is_pyctcdecode_available():
logger.warning("Try to install `pyctcdecode`: `pip install pyctcdecode")

if load_processor:
# Try to infer processor from model or config name (if provided as str)
if processor is None:
if isinstance(model_name, str):
processor = model_name
elif isinstance(config, str):
processor = config
else:
# Impossible to guess what is the right processor here
raise Exception(
"Impossible to guess which processor to use. "
"Please provide a processor instance or a path/identifier "
"to a processor."
)

# Instantiate processor if needed
if isinstance(processor, (str, tuple)):
processor = AutoProcessor.from_pretrained(processor, _from_pipeline=task, **hub_kwargs, **model_kwargs)
if not isinstance(processor, ProcessorMixin):
raise TypeError(
"Processor was loaded, but it is not an instance of `ProcessorMixin`. "
f"Got type `{type(processor)}` instead. Please check that you specified "
"correct pipeline task for the model and model has processor implemented and saved."
)

if task == "translation" and model.config.task_specific_params:
for key in model.config.task_specific_params:
if key.startswith("translation"):
Expand All @@ -1099,4 +1158,7 @@ def pipeline(
if device is not None:
kwargs["device"] = device

if processor is not None:
kwargs["processor"] = processor

return pipeline_class(model=model, framework=framework, task=task, **kwargs)
38 changes: 36 additions & 2 deletions src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ..image_processing_utils import BaseImageProcessor
from ..modelcard import ModelCard
from ..models.auto.configuration_auto import AutoConfig
from ..processing_utils import ProcessorMixin
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import (
ModelOutput,
Expand Down Expand Up @@ -716,6 +717,7 @@ def build_pipeline_init_args(
has_tokenizer: bool = False,
has_feature_extractor: bool = False,
has_image_processor: bool = False,
has_processor: bool = False,
supports_binary_output: bool = True,
) -> str:
docstring = r"""
Expand All @@ -738,6 +740,12 @@ def build_pipeline_init_args(
image_processor ([`BaseImageProcessor`]):
The image processor that will be used by the pipeline to encode data for the model. This object inherits from
[`BaseImageProcessor`]."""
if has_processor:
docstring += r"""
processor ([`ProcessorMixin`]):
The processor that will be used by the pipeline to encode data for the model. This object inherits from
[`ProcessorMixin`]. Processor is a composite object that might contain `tokenizer`, `feature_extractor`, and
`image_processor`."""
docstring += r"""
modelcard (`str` or [`ModelCard`], *optional*):
Model card attributed to the model for this pipeline.
Expand Down Expand Up @@ -774,7 +782,11 @@ def build_pipeline_init_args(


PIPELINE_INIT_ARGS = build_pipeline_init_args(
has_tokenizer=True, has_feature_extractor=True, has_image_processor=True, supports_binary_output=True
has_tokenizer=True,
has_feature_extractor=True,
has_image_processor=True,
has_processor=True,
supports_binary_output=True,
)


Expand All @@ -787,7 +799,11 @@ def build_pipeline_init_args(
)


@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True, has_feature_extractor=True, has_image_processor=True))
@add_end_docstrings(
build_pipeline_init_args(
has_tokenizer=True, has_feature_extractor=True, has_image_processor=True, has_processor=True
)
)
class Pipeline(_ScikitCompat, PushToHubMixin):
"""
The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
Expand All @@ -805,6 +821,22 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
constructor argument. If set to `True`, the output will be stored in the pickle format.
"""

# Historically we have pipelines working with `tokenizer`, `feature_extractor`, and `image_processor`
# as separate processing components. While we have `processor` class that combines them, some pipelines
# might still operate with these components separately.
# With the addition of `processor` to `pipeline`, we want to avoid:
# - loading `processor` for pipelines that still work with `image_processor` and `tokenizer` separately;
# - loading `image_processor`/`tokenizer` as a separate component while we operate only with `processor`,
# because `processor` will load required sub-components by itself.
# The flags below allow granular control over loading components and are set to be backward compatible with the
# current pipelines logic. You may override these flags when creating your pipeline. For example, for the
# `zero-shot-object-detection` pipeline, which operates with `processor`, you should set `_load_processor=True`
# and all the other flags to `False` to avoid unnecessary loading of the components.
_load_processor = False
_load_image_processor = True
_load_feature_extractor = True
_load_tokenizer = True

default_input_names = None

def __init__(
Expand All @@ -813,6 +845,7 @@ def __init__(
tokenizer: Optional[PreTrainedTokenizer] = None,
feature_extractor: Optional[PreTrainedFeatureExtractor] = None,
image_processor: Optional[BaseImageProcessor] = None,
processor: Optional[ProcessorMixin] = None,
modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None,
task: str = "",
Expand All @@ -830,6 +863,7 @@ def __init__(
self.tokenizer = tokenizer
self.feature_extractor = feature_extractor
self.image_processor = image_processor
self.processor = processor
self.modelcard = modelcard
self.framework = framework

Expand Down
11 changes: 9 additions & 2 deletions tests/models/altclip/test_modeling_altclip.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,9 +436,16 @@ class AltCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)

# TODO: Fix the failed tests when this model gets more usage
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "FeatureExtractionPipelineTests":
if pipeline_test_case_name == "FeatureExtractionPipelineTests":
return True

return False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,16 @@ class ASTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):

# TODO: Fix the failed tests when this model gets more usage
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "AudioClassificationPipelineTests":
if pipeline_test_case_name == "AudioClassificationPipelineTests":
return True

return False
Expand Down
11 changes: 9 additions & 2 deletions tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,16 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT

# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
return True

return False
Expand Down
11 changes: 9 additions & 2 deletions tests/models/blenderbot_small/test_modeling_blenderbot_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,16 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline

# TODO: Fix the failed tests when this model gets more usage
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return pipeline_test_casse_name == "TextGenerationPipelineTests"
return pipeline_test_case_name == "TextGenerationPipelineTests"

def setUp(self):
self.model_tester = BlenderbotSmallModelTester(self)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,16 @@ class FlaxBlenderbotSmallModelTest(FlaxModelTesterMixin, unittest.TestCase, Flax
all_generative_model_classes = (FlaxBlenderbotSmallForConditionalGeneration,) if is_flax_available() else ()

def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return pipeline_test_casse_name == "TextGenerationPipelineTests"
return pipeline_test_case_name == "TextGenerationPipelineTests"

def setUp(self):
self.model_tester = FlaxBlenderbotSmallModelTester(self)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,16 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, PipelineTesterMixin, unitte
test_onnx = False

def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return pipeline_test_casse_name == "TextGenerationPipelineTests"
return pipeline_test_case_name == "TextGenerationPipelineTests"

def setUp(self):
self.model_tester = TFBlenderbotSmallModelTester(self)
Expand Down
9 changes: 8 additions & 1 deletion tests/models/bros/test_modeling_bros.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,14 @@ class BrosModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
# BROS requires `bbox` in the inputs which doesn't fit into the above 2 pipelines' input formats.
# see https://github.com/huggingface/transformers/pull/26294
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

Expand Down
9 changes: 8 additions & 1 deletion tests/models/cpm/test_tokenization_cpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,14 @@
class CpmTokenizationTest(unittest.TestCase):
# There is no `CpmModel`
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

Expand Down
11 changes: 9 additions & 2 deletions tests/models/ctrl/test_modeling_ctrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,16 @@ class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin

# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "ZeroShotClassificationPipelineTests":
if pipeline_test_case_name == "ZeroShotClassificationPipelineTests":
# Get `tokenizer does not have a padding token` error for both fast/slow tokenizers.
# `CTRLConfig` was never used in pipeline tests, either because of a missing checkpoint or because a tiny
# config could not be created.
Expand Down
Loading

0 comments on commit 48461c0

Please sign in to comment.