diff --git a/src/transformers/commands/run.py b/src/transformers/commands/run.py index 856ac6d12dd082..563a086a7d8727 100644 --- a/src/transformers/commands/run.py +++ b/src/transformers/commands/run.py @@ -14,7 +14,7 @@ from argparse import ArgumentParser -from ..pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline +from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, PipelineDataFormat, pipeline from ..utils import logging from . import BaseTransformersCLICommand @@ -63,7 +63,9 @@ def __init__(self, nlp: Pipeline, reader: PipelineDataFormat): @staticmethod def register_subcommand(parser: ArgumentParser): run_parser = parser.add_parser("run", help="Run a pipeline through the CLI") - run_parser.add_argument("--task", choices=SUPPORTED_TASKS.keys(), help="Task to run") + run_parser.add_argument( + "--task", choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()), help="Task to run" + ) run_parser.add_argument("--input", type=str, help="Path to the file to use for inference") run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.") run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.") diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py index cb4a3fe6c1f155..dd2aec1f3aba3a 100644 --- a/src/transformers/commands/serving.py +++ b/src/transformers/commands/serving.py @@ -15,7 +15,7 @@ from argparse import ArgumentParser, Namespace from typing import Any, List, Optional -from ..pipelines import SUPPORTED_TASKS, Pipeline, pipeline +from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, pipeline from ..utils import logging from . import BaseTransformersCLICommand @@ -102,7 +102,10 @@ def register_subcommand(parser: ArgumentParser): "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints." 
) serve_parser.add_argument( - "--task", type=str, choices=SUPPORTED_TASKS.keys(), help="The task to run the pipeline on" + "--task", + type=str, + choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()), + help="The task to run the pipeline on", ) serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.") serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.") diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index fb1b959d4686da..9e55c3f93c3624 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -93,6 +93,10 @@ # Register all the supported tasks here +TASK_ALIASES = { + "sentiment-analysis": "text-classification", + "ner": "token-classification", +} SUPPORTED_TASKS = { "feature-extraction": { "impl": FeatureExtractionPipeline, @@ -100,7 +104,7 @@ "pt": AutoModel if is_torch_available() else None, "default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}}, }, - "sentiment-analysis": { + "text-classification": { "impl": TextClassificationPipeline, "tf": TFAutoModelForSequenceClassification if is_tf_available() else None, "pt": AutoModelForSequenceClassification if is_torch_available() else None, @@ -111,7 +115,7 @@ }, }, }, - "ner": { + "token-classification": { "impl": TokenClassificationPipeline, "tf": TFAutoModelForTokenClassification if is_tf_available() else None, "pt": AutoModelForTokenClassification if is_torch_available() else None, @@ -206,8 +210,10 @@ def check_task(task: str) -> Tuple[Dict, Any]: The task defining which pipeline will be returned. 
Currently accepted tasks are: - :obj:`"feature-extraction"` - - :obj:`"sentiment-analysis"` - - :obj:`"ner"` + - :obj:`"text-classification"` + - :obj:`"sentiment-analysis"` (alias of :obj:`"text-classification"`) + - :obj:`"token-classification"` + - :obj:`"ner"` (alias of :obj:`"token-classification"`) - :obj:`"question-answering"` - :obj:`"fill-mask"` - :obj:`"summarization"` @@ -222,6 +228,8 @@ def check_task(task: str) -> Tuple[Dict, Any]: """ + if task in TASK_ALIASES: + task = TASK_ALIASES[task] if task in SUPPORTED_TASKS: targeted_task = SUPPORTED_TASKS[task] return targeted_task, None @@ -264,8 +272,12 @@ def pipeline( The task defining which pipeline will be returned. Currently accepted tasks are: - :obj:`"feature-extraction"`: will return a :class:`~transformers.FeatureExtractionPipeline`. - - :obj:`"sentiment-analysis"`: will return a :class:`~transformers.TextClassificationPipeline`. - - :obj:`"ner"`: will return a :class:`~transformers.TokenClassificationPipeline`. + - :obj:`"text-classification"`: will return a :class:`~transformers.TextClassificationPipeline`. + - :obj:`"sentiment-analysis"` (alias of :obj:`"text-classification"`): will return a + :class:`~transformers.TextClassificationPipeline`. + - :obj:`"token-classification"`: will return a :class:`~transformers.TokenClassificationPipeline`. + - :obj:`"ner"` (alias of :obj:`"token-classification"`): will return a + :class:`~transformers.TokenClassificationPipeline`. - :obj:`"question-answering"`: will return a :class:`~transformers.QuestionAnsweringPipeline`. - :obj:`"fill-mask"`: will return a :class:`~transformers.FillMaskPipeline`. - :obj:`"summarization"`: will return a :class:`~transformers.SummarizationPipeline`. 
diff --git a/tests/test_pipelines_sentiment_analysis.py b/tests/test_pipelines_text_classification.py similarity index 92% rename from tests/test_pipelines_sentiment_analysis.py rename to tests/test_pipelines_text_classification.py index 7f5dbfa7e8cb6f..7db8a24116c5ed 100644 --- a/tests/test_pipelines_sentiment_analysis.py +++ b/tests/test_pipelines_text_classification.py @@ -17,7 +17,7 @@ from .test_pipelines_common import MonoInputPipelineCommonMixin -class SentimentAnalysisPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): +class TextClassificationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): pipeline_task = "sentiment-analysis" small_models = [ "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english" diff --git a/tests/test_pipelines_ner.py b/tests/test_pipelines_token_classification.py similarity index 99% rename from tests/test_pipelines_ner.py rename to tests/test_pipelines_token_classification.py index c7b8171ef2578b..756ccbf52dd526 100644 --- a/tests/test_pipelines_ner.py +++ b/tests/test_pipelines_token_classification.py @@ -27,7 +27,7 @@ VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]] -class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): +class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): pipeline_task = "ner" small_models = [ "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"