Skip to content

Commit

Permalink
Adding pipeline task aliases. (huggingface#11247)
Browse files Browse the repository at this point in the history
* Adding task aliases and adding `token-classification` and
`text-classification` tasks.

* Cleaning docstring.
  • Loading branch information
Narsil authored and Iwontbecreative committed Jul 15, 2021
1 parent 58b76e2 commit 4a40372
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 12 deletions.
6 changes: 4 additions & 2 deletions src/transformers/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from argparse import ArgumentParser

from ..pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, PipelineDataFormat, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand

Expand Down Expand Up @@ -63,7 +63,9 @@ def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
@staticmethod
def register_subcommand(parser: ArgumentParser):
run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
run_parser.add_argument("--task", choices=SUPPORTED_TASKS.keys(), help="Task to run")
run_parser.add_argument(
"--task", choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()), help="Task to run"
)
run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
Expand Down
7 changes: 5 additions & 2 deletions src/transformers/commands/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional

from ..pipelines import SUPPORTED_TASKS, Pipeline, pipeline
from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand

Expand Down Expand Up @@ -102,7 +102,10 @@ def register_subcommand(parser: ArgumentParser):
"serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
)
serve_parser.add_argument(
"--task", type=str, choices=SUPPORTED_TASKS.keys(), help="The task to run the pipeline on"
"--task",
type=str,
choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()),
help="The task to run the pipeline on",
)
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
Expand Down
24 changes: 18 additions & 6 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,18 @@


# Register all the supported tasks here
# Maps legacy/alias task names to their canonical name in SUPPORTED_TASKS.
# Aliases are resolved first (see `check_task`: `if task in TASK_ALIASES: task = TASK_ALIASES[task]`),
# so both the alias and the canonical name are accepted wherever a task string is taken.
TASK_ALIASES = {
"sentiment-analysis": "text-classification",
"ner": "token-classification",
}
SUPPORTED_TASKS = {
"feature-extraction": {
"impl": FeatureExtractionPipeline,
"tf": TFAutoModel if is_tf_available() else None,
"pt": AutoModel if is_torch_available() else None,
"default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
},
"sentiment-analysis": {
"text-classification": {
"impl": TextClassificationPipeline,
"tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
"pt": AutoModelForSequenceClassification if is_torch_available() else None,
Expand All @@ -111,7 +115,7 @@
},
},
},
"ner": {
"token-classification": {
"impl": TokenClassificationPipeline,
"tf": TFAutoModelForTokenClassification if is_tf_available() else None,
"pt": AutoModelForTokenClassification if is_torch_available() else None,
Expand Down Expand Up @@ -206,8 +210,10 @@ def check_task(task: str) -> Tuple[Dict, Any]:
The task defining which pipeline will be returned. Currently accepted tasks are:
- :obj:`"feature-extraction"`
- :obj:`"sentiment-analysis"`
- :obj:`"ner"`
- :obj:`"text-classification"`
- :obj:`"sentiment-analysis"` (alias of :obj:`"text-classification"`)
- :obj:`"token-classification"`
- :obj:`"ner"` (alias of :obj:`"token-classification"`)
- :obj:`"question-answering"`
- :obj:`"fill-mask"`
- :obj:`"summarization"`
Expand All @@ -222,6 +228,8 @@ def check_task(task: str) -> Tuple[Dict, Any]:
"""
if task in TASK_ALIASES:
task = TASK_ALIASES[task]
if task in SUPPORTED_TASKS:
targeted_task = SUPPORTED_TASKS[task]
return targeted_task, None
Expand Down Expand Up @@ -264,8 +272,12 @@ def pipeline(
The task defining which pipeline will be returned. Currently accepted tasks are:
- :obj:`"feature-extraction"`: will return a :class:`~transformers.FeatureExtractionPipeline`.
- :obj:`"sentiment-analysis"`: will return a :class:`~transformers.TextClassificationPipeline`.
- :obj:`"ner"`: will return a :class:`~transformers.TokenClassificationPipeline`.
- :obj:`"text-classification"`: will return a :class:`~transformers.TextClassificationPipeline`.
- :obj:`"sentiment-analysis"` (alias of :obj:`"text-classification"`): will return a
:class:`~transformers.TextClassificationPipeline`.
- :obj:`"token-classification"`: will return a :class:`~transformers.TokenClassificationPipeline`.
- :obj:`"ner"` (alias of :obj:`"token-classification"`): will return a
:class:`~transformers.TokenClassificationPipeline`.
- :obj:`"question-answering"`: will return a :class:`~transformers.QuestionAnsweringPipeline`.
- :obj:`"fill-mask"`: will return a :class:`~transformers.FillMaskPipeline`.
- :obj:`"summarization"`: will return a :class:`~transformers.SummarizationPipeline`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .test_pipelines_common import MonoInputPipelineCommonMixin


class SentimentAnalysisPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
class TextClassificationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "sentiment-analysis"
small_models = [
"sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]


class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "ner"
small_models = [
"sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
Expand Down

0 comments on commit 4a40372

Please sign in to comment.