Extend pipelines for automodel tuples #12025

Merged
102 changes: 42 additions & 60 deletions src/transformers/pipelines/__init__.py
@@ -37,7 +37,7 @@
PipelineDataFormat,
PipelineException,
get_default_model,
infer_framework_from_model,
infer_framework_load_model,
)
from .conversational import Conversation, ConversationalPipeline
from .feature_extraction import FeatureExtractionPipeline
@@ -110,14 +110,14 @@
SUPPORTED_TASKS = {
"feature-extraction": {
"impl": FeatureExtractionPipeline,
"tf": TFAutoModel if is_tf_available() else None,
"pt": AutoModel if is_torch_available() else None,
"tf": (TFAutoModel,) if is_tf_available() else (),
"pt": (AutoModel,) if is_torch_available() else (),
"default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
},
"text-classification": {
"impl": TextClassificationPipeline,
"tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
"pt": AutoModelForSequenceClassification if is_torch_available() else None,
"tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
"pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
"default": {
"model": {
"pt": "distilbert-base-uncased-finetuned-sst-2-english",
@@ -127,8 +127,8 @@
},
"token-classification": {
"impl": TokenClassificationPipeline,
"tf": TFAutoModelForTokenClassification if is_tf_available() else None,
"pt": AutoModelForTokenClassification if is_torch_available() else None,
"tf": (TFAutoModelForTokenClassification,) if is_tf_available() else (),
"pt": (AutoModelForTokenClassification,) if is_torch_available() else (),
"default": {
"model": {
"pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
@@ -138,16 +138,16 @@
},
"question-answering": {
"impl": QuestionAnsweringPipeline,
"tf": TFAutoModelForQuestionAnswering if is_tf_available() else None,
"pt": AutoModelForQuestionAnswering if is_torch_available() else None,
"tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (),
"pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (),
"default": {
"model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
},
},
"table-question-answering": {
"impl": TableQuestionAnsweringPipeline,
"pt": AutoModelForTableQuestionAnswering if is_torch_available() else None,
"tf": None,
"pt": (AutoModelForTableQuestionAnswering,) if is_torch_available() else (),
"tf": (),
"default": {
"model": {
"pt": "google/tapas-base-finetuned-wtq",
@@ -158,21 +158,21 @@
},
"fill-mask": {
"impl": FillMaskPipeline,
"tf": TFAutoModelForMaskedLM if is_tf_available() else None,
"pt": AutoModelForMaskedLM if is_torch_available() else None,
"tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
"pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
"default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
},
"summarization": {
"impl": SummarizationPipeline,
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
"tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
"pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
"default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
},
# This task is a special case as it's parametrized by SRC, TGT languages.
"translation": {
"impl": TranslationPipeline,
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
"tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
"pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
"default": {
("en", "fr"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
@@ -181,20 +181,20 @@
},
"text2text-generation": {
"impl": Text2TextGenerationPipeline,
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
"tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
"pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
},
"text-generation": {
"impl": TextGenerationPipeline,
"tf": TFAutoModelForCausalLM if is_tf_available() else None,
"pt": AutoModelForCausalLM if is_torch_available() else None,
"tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
"pt": (AutoModelForCausalLM,) if is_torch_available() else (),
"default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
},
"zero-shot-classification": {
"impl": ZeroShotClassificationPipeline,
"tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
"pt": AutoModelForSequenceClassification if is_torch_available() else None,
"tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
"pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
"default": {
"model": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
"config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
@@ -203,14 +203,14 @@
},
"conversational": {
"impl": ConversationalPipeline,
"tf": TFAutoModelForCausalLM if is_tf_available() else None,
"pt": AutoModelForCausalLM if is_torch_available() else None,
"tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
"pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
"default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
},
"image-classification": {
"impl": ImageClassificationPipeline,
"tf": None,
"pt": AutoModelForImageClassification if is_torch_available() else None,
"tf": (),
"pt": (AutoModelForImageClassification,) if is_torch_available() else (),
"default": {"model": {"pt": "google/vit-base-patch16-224"}},
},
}
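
The core of this PR is visible in the registry above: each task's "pt" and "tf" entries change from a single auto-class (or None) to a tuple of candidate classes, so a task such as "conversational" can declare both seq2seq and causal-LM heads. Below is a minimal sketch of how such a tuple can be consumed; the load_first_match helper is hypothetical, but its fallback loop mirrors the one this PR adds to base.py further down.

```python
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

# e.g. SUPPORTED_TASKS["conversational"]["pt"] after this PR
candidates = (AutoModelForSeq2SeqLM, AutoModelForCausalLM)

def load_first_match(checkpoint, candidates):
    # Hypothetical helper: try each candidate class in order and keep the
    # first one that loads the checkpoint successfully.
    for model_class in candidates:
        try:
            return model_class.from_pretrained(checkpoint)
        except (OSError, ValueError):
            continue  # wrong head for this checkpoint, try the next class
    raise ValueError(f"Could not load {checkpoint} with any of {candidates}")

model = load_first_match("microsoft/DialoGPT-medium", candidates)
```
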
@@ -379,53 +379,35 @@ def pipeline(
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
>>> pipeline('ner', model=model, tokenizer=tokenizer)
"""

# Retrieve the task
targeted_task, task_options = check_task(task)
task_class = targeted_task["impl"]

# Use default model/config/tokenizer for the task if no model is provided
if model is None:
# At that point framework might still be undetermined
model = get_default_model(targeted_task, framework, task_options)

model_name = model if isinstance(model, str) else None

# Infer the framework form the model
if framework is None:
framework, model = infer_framework_from_model(model, targeted_task, revision=revision, task=task)

task_class, model_class = targeted_task["impl"], targeted_task[framework]

# Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)

# Config is the primordial information item.
# Instantiate config if needed
if isinstance(config, str):
config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)
elif config is None and isinstance(model, str):
config = AutoConfig.from_pretrained(model, revision=revision, _from_pipeline=task, **model_kwargs)

# Instantiate model if needed
if isinstance(model, str):
# Handle transparent TF/PT model conversion
if framework == "pt" and model.endswith(".h5"):
model_kwargs["from_tf"] = True
logger.warning(
"Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
"Trying to load the model with PyTorch."
)
elif framework == "tf" and model.endswith(".bin"):
model_kwargs["from_pt"] = True
logger.warning(
"Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
"Trying to load the model with Tensorflow."
)
model_name = model if isinstance(model, str) else None

if model_class is None:
raise ValueError(
f"Pipeline using {framework} framework, but this framework is not supported by this pipeline."
)
# Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)

model = model_class.from_pretrained(
model, config=config, revision=revision, _from_pipeline=task, **model_kwargs
)
# Infer the framework from the model
# Forced if framework already defined, inferred if it's None
# Will load the correct model if possible
model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
framework, model = infer_framework_load_model(
model, model_classes=model_classes, config=config, framework=framework, revision=revision, task=task
)

model_config = model.config
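
With this change, pipeline() no longer picks one model class per framework by hand; it hands the whole {"tf": ..., "pt": ...} mapping to infer_framework_load_model, which settles the framework and loads the weights in one step. The practical effect, presumably the motivating case for this PR, is that a single task can serve checkpoints with different head architectures:

```python
from transformers import pipeline

# Both checkpoints resolve under the same "conversational" task: DialoGPT
# has a causal-LM head, BlenderBot a seq2seq head. Previously only the one
# registered auto-class could be loaded.
chat_causal = pipeline("conversational", model="microsoft/DialoGPT-medium")
chat_seq2seq = pipeline("conversational", model="facebook/blenderbot-400M-distill")
```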

117 changes: 101 additions & 16 deletions src/transformers/pipelines/base.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import importlib
import json
import os
import pickle
@@ -21,11 +22,12 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
from ..modelcard import ModelCard
from ..models.auto.configuration_auto import AutoConfig
from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
from ..utils import logging

@@ -48,8 +50,13 @@
logger = logging.get_logger(__name__)


def infer_framework_from_model(
model, model_classes: Optional[Dict[str, type]] = None, task: Optional[str] = None, **model_kwargs
def infer_framework_load_model(
model,
config: AutoConfig,
model_classes: Optional[Dict[str, Tuple[type]]] = None,
task: Optional[str] = None,
framework: Optional[str] = None,
**model_kwargs
):
"""
Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).
@@ -64,6 +71,8 @@ def infer_framework_from_model(
model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
The model to infer the framework from. If :obj:`str`, a checkpoint name.
config (:class:`~transformers.AutoConfig`):
The config associated with the model, used to select the correct model classes.
model_classes (dictionary :obj:`str` to :obj:`type`, `optional`):
A mapping from framework name ("pt" or "tf") to a tuple of model classes.
task (:obj:`str`):
@@ -83,24 +92,100 @@
)
if isinstance(model, str):
model_kwargs["_from_pipeline"] = task
-        if is_torch_available() and not is_tf_available():
-            model_class = model_classes.get("pt", AutoModel)
-            model = model_class.from_pretrained(model, **model_kwargs)
-        elif is_tf_available() and not is_torch_available():
-            model_class = model_classes.get("tf", TFAutoModel)
-            model = model_class.from_pretrained(model, **model_kwargs)
-        else:
+        class_tuple = ()
+        look_pt = is_torch_available() and framework in {"pt", None}
+        look_tf = is_tf_available() and framework in {"tf", None}
+        if model_classes:
+            if look_pt:
+                class_tuple = class_tuple + model_classes.get("pt", (AutoModel,))
+            if look_tf:
+                class_tuple = class_tuple + model_classes.get("tf", (TFAutoModel,))
+        if config.architectures:
+            classes = []
+            for architecture in config.architectures:
+                transformers_module = importlib.import_module("transformers")
+                if look_pt:
+                    _class = getattr(transformers_module, architecture, None)
+                    if _class is not None:
+                        classes.append(_class)
+                if look_tf:
+                    _class = getattr(transformers_module, f"TF{architecture}", None)
+                    if _class is not None:
+                        classes.append(_class)
+            class_tuple = class_tuple + tuple(classes)
+
+        if len(class_tuple) == 0:
+            raise ValueError(f"Pipeline cannot infer suitable model classes from {model}")
+
+        for model_class in class_tuple:
+            kwargs = model_kwargs.copy()
+            if framework == "pt" and model.endswith(".h5"):
+                kwargs["from_tf"] = True
+                logger.warning(
+                    "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
+                    "Trying to load the model with PyTorch."
+                )
+            elif framework == "tf" and model.endswith(".bin"):
+                kwargs["from_pt"] = True
+                logger.warning(
+                    "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
+                    "Trying to load the model with Tensorflow."
+                )
+
            try:
-                model_class = model_classes.get("pt", AutoModel)
-                model = model_class.from_pretrained(model, **model_kwargs)
-            except OSError:
-                model_class = model_classes.get("tf", TFAutoModel)
-                model = model_class.from_pretrained(model, **model_kwargs)
+                model = model_class.from_pretrained(model, **kwargs)
+                # Stop loading on the first successful load.
+                break
+            except (OSError, ValueError):
+                continue

+    if isinstance(model, str):
+        raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")

framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
return framework, model
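
A usage sketch for the new entry point, assuming a transformers build that includes this PR; note that config is now a required positional argument and that the model_classes values are tuples:

```python
from transformers import AutoConfig, AutoModelForSequenceClassification
from transformers.pipelines.base import infer_framework_load_model

checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
config = AutoConfig.from_pretrained(checkpoint)

# model_classes maps framework name to a tuple of candidate classes; the
# function also widens the pool with the classes named in config.architectures.
framework, model = infer_framework_load_model(
    checkpoint,
    config,
    model_classes={"pt": (AutoModelForSequenceClassification,)},
)
print(framework)  # "pt" when PyTorch is installed
```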


def infer_framework_from_model(
model,
model_classes: Optional[Dict[str, Tuple[type]]] = None,
task: Optional[str] = None,
framework: Optional[str] = None,
**model_kwargs
):
"""
Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).

If :obj:`model` is instantiated, this function will just infer the framework from the model class. Otherwise
:obj:`model` is actually a checkpoint name and this method will try to instantiate it using :obj:`model_classes`.
Since we don't want to instantiate the model twice, this model is returned for use by the pipeline.

If both frameworks are installed and available for :obj:`model`, PyTorch is selected.

Args:
model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
The model to infer the framework from. If :obj:`str`, a checkpoint name.
model_classes (dictionary :obj:`str` to :obj:`type`, `optional`):
A mapping from framework name ("pt" or "tf") to a tuple of model classes.
task (:obj:`str`):
The task defining which pipeline will be returned.
model_kwargs:
Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
**model_kwargs)` function.

Returns:
:obj:`Tuple`: A tuple (framework, model).
"""
if isinstance(model, str):
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **model_kwargs)
else:
config = model.config
return infer_framework_load_model(
model, config, model_classes=model_classes, _from_pipeline=task, task=task, framework=framework, **model_kwargs
)
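
infer_framework_from_model survives as a thin backward-compatibility wrapper: it only resolves the config (from the checkpoint name, or from model.config) before delegating. A small sketch of that equivalence, under the same assumption that this PR's code is installed:

```python
from transformers import AutoConfig
from transformers.pipelines.base import (
    infer_framework_from_model,
    infer_framework_load_model,
)

# The wrapper resolves the config itself...
fw1, model1 = infer_framework_from_model("distilroberta-base")

# ...which is equivalent to building the config and calling the new API.
config = AutoConfig.from_pretrained("distilroberta-base")
fw2, model2 = infer_framework_load_model("distilroberta-base", config)
assert fw1 == fw2
```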


def get_framework(model, revision: Optional[str] = None):
"""
Select framework (TensorFlow or PyTorch) to use.
@@ -534,7 +619,7 @@ def __init__(
):

if framework is None:
framework, model = infer_framework_from_model(model)
framework, model = infer_framework_load_model(model, config=model.config)

self.task = task
self.model = model