Add table question answering to inference client (#1612)
* Add table question answering to inference client

* Change in line with review points on other PRs

* typo

* make style

---------

Co-authored-by: Lucain Pouget <lucainp@gmail.com>
martinbrose and Wauplin authored Sep 6, 2023
1 parent e98ada3 commit 0fbbf6b
Showing 6 changed files with 189 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/guides/inference.md
@@ -138,7 +138,7 @@ has a simple API that supports the most common tasks. Here is a list of the curr
| | [Question Answering](https://huggingface.co/tasks/question-answering) | ✅ | [`~InferenceClient.question_answering`] |
| | [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | ✅ | [`~InferenceClient.sentence_similarity`] |
| | [Summarization](https://huggingface.co/tasks/summarization) | ✅ | [`~InferenceClient.summarization`] |
| | [Table Question Answering](https://huggingface.co/tasks/table-question-answering) | | |
| | [Table Question Answering](https://huggingface.co/tasks/table-question-answering) | ✅ | [`~InferenceClient.table_question_answering`] |
| | [Text Classification](https://huggingface.co/tasks/text-classification) | ✅ | [`~InferenceClient.text_classification`] |
| | [Text Generation](https://huggingface.co/tasks/text-generation) | ✅ | [`~InferenceClient.text_generation`] |
| | [Token Classification](https://huggingface.co/tasks/token-classification) | ✅ | [`~InferenceClient.token_classification`] |
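The new table row maps to a single client call. As a quick illustration (the model ID below is the one exercised by this PR's test; any hosted table-question-answering checkpoint should work):

```py
from huggingface_hub import InferenceClient

# Ask a question about a table; the table is a dict of column-name -> cell values.
client = InferenceClient()
table = {
    "Repository": ["Transformers", "Datasets", "Tokenizers"],
    "Stars": ["36542", "4512", "3934"],
}
output = client.table_question_answering(
    table,
    query="How many stars does the transformers repository have?",
    model="google/tapas-base-finetuned-wtq",
)
print(output["answer"])  # e.g. 'AVERAGE > 36542'
```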
46 changes: 46 additions & 0 deletions src/huggingface_hub/inference/_client.py
@@ -83,6 +83,7 @@
    ImageSegmentationOutput,
    ObjectDetectionOutput,
    QuestionAnsweringOutput,
    TableQuestionAnsweringOutput,
    TokenClassificationOutput,
)
from huggingface_hub.utils import (
@@ -808,6 +809,51 @@ def summarization(
        response = self.post(json=payload, model=model, task="summarization")
        return _bytes_to_dict(response)[0]["summary_text"]

    def table_question_answering(
        self, table: Dict[str, Any], query: str, *, model: Optional[str] = None
    ) -> TableQuestionAnsweringOutput:
        """
        Retrieve the answer to a question from information given in a table.

        Args:
            table (`Dict[str, Any]`):
                A table of data represented as a dict of lists, where each key is a column header and each value is
                the list of cell values for that column. All lists must have the same length.
            query (`str`):
                The query, in plain text, that you want to ask the table.
            model (`str`, *optional*):
                The model to use for the table-question-answering task. Can be a model ID hosted on the Hugging Face
                Hub or a URL to a deployed Inference Endpoint. Defaults to None.

        Returns:
            `Dict`: a dictionary of table question answering output containing the answer, coordinates, cells, and
            the aggregator used.

        Raises:
            [`InferenceTimeoutError`]:
                If the model is unavailable or the request times out.
            `HTTPError`:
                If the request fails with an HTTP error status code other than HTTP 503.

        Example:
        ```py
        >>> from huggingface_hub import InferenceClient
        >>> client = InferenceClient()
        >>> query = "How many stars does the transformers repository have?"
        >>> table = {"Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542", "4512", "3934"]}
        >>> client.table_question_answering(table, query, model="google/tapas-base-finetuned-wtq")
        {'answer': 'AVERAGE > 36542', 'coordinates': [[0, 1]], 'cells': ['36542'], 'aggregator': 'AVERAGE'}
        ```
        """
        response = self.post(
            json={
                "query": query,
                "table": table,
            },
            model=model,
            task="table-question-answering",
        )
        return _bytes_to_dict(response)  # type: ignore

    def text_classification(self, text: str, *, model: Optional[str] = None) -> List[ClassificationOutput]:
        """
        Perform sentiment-analysis on the given text.
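The docstring example above returns `'AVERAGE > 36542'`: TAPAS-style models report an aggregator plus the matched cells rather than a final scalar. Below is a minimal caller-side sketch, not part of this PR, that reduces such an output to a number, assuming the cells parse as floats; `resolve_answer` is a hypothetical helper name:

```py
from typing import Union

# Hypothetical caller-side helper: apply the reported aggregator to the cells.
# TAPAS-style models use the aggregators NONE, COUNT, SUM, and AVERAGE.
def resolve_answer(output: dict) -> Union[float, str]:
    cells = output["cells"]
    aggregator = output.get("aggregator", "NONE")
    if aggregator == "COUNT":
        return float(len(cells))
    try:
        values = [float(cell) for cell in cells]
    except ValueError:
        return output["answer"]  # non-numeric cells: fall back to the plain answer
    if aggregator == "SUM":
        return sum(values)
    if aggregator == "AVERAGE":
        return sum(values) / len(values)
    return values[0] if len(values) == 1 else output["answer"]

# resolve_answer({'answer': 'AVERAGE > 36542', 'coordinates': [[0, 1]],
#                 'cells': ['36542'], 'aggregator': 'AVERAGE'}) -> 36542.0
```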
47 changes: 47 additions & 0 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -67,6 +67,7 @@
    ImageSegmentationOutput,
    ObjectDetectionOutput,
    QuestionAnsweringOutput,
    TableQuestionAnsweringOutput,
    TokenClassificationOutput,
)
from huggingface_hub.utils import (
@@ -816,6 +817,52 @@ async def summarization(
        response = await self.post(json=payload, model=model, task="summarization")
        return _bytes_to_dict(response)[0]["summary_text"]

    async def table_question_answering(
        self, table: Dict[str, Any], query: str, *, model: Optional[str] = None
    ) -> TableQuestionAnsweringOutput:
        """
        Retrieve the answer to a question from information given in a table.

        Args:
            table (`Dict[str, Any]`):
                A table of data represented as a dict of lists, where each key is a column header and each value is
                the list of cell values for that column. All lists must have the same length.
            query (`str`):
                The query, in plain text, that you want to ask the table.
            model (`str`, *optional*):
                The model to use for the table-question-answering task. Can be a model ID hosted on the Hugging Face
                Hub or a URL to a deployed Inference Endpoint. Defaults to None.

        Returns:
            `Dict`: a dictionary of table question answering output containing the answer, coordinates, cells, and
            the aggregator used.

        Raises:
            [`InferenceTimeoutError`]:
                If the model is unavailable or the request times out.
            `aiohttp.ClientResponseError`:
                If the request fails with an HTTP error status code other than HTTP 503.

        Example:
        ```py
        # Must be run in an async context
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient()
        >>> query = "How many stars does the transformers repository have?"
        >>> table = {"Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542", "4512", "3934"]}
        >>> await client.table_question_answering(table, query, model="google/tapas-base-finetuned-wtq")
        {'answer': 'AVERAGE > 36542', 'coordinates': [[0, 1]], 'cells': ['36542'], 'aggregator': 'AVERAGE'}
        ```
        """
        response = await self.post(
            json={
                "query": query,
                "table": table,
            },
            model=model,
            task="table-question-answering",
        )
        return _bytes_to_dict(response)  # type: ignore

    async def text_classification(self, text: str, *, model: Optional[str] = None) -> List[ClassificationOutput]:
        """
        Perform sentiment-analysis on the given text.
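Since the async variant must be awaited, here is a minimal sketch of driving the new method from a plain script with `asyncio.run` (same table, query, and model as the docstring example):

```py
import asyncio

from huggingface_hub import AsyncInferenceClient

async def main() -> None:
    client = AsyncInferenceClient()
    table = {
        "Repository": ["Transformers", "Datasets", "Tokenizers"],
        "Stars": ["36542", "4512", "3934"],
    }
    output = await client.table_question_answering(
        table,
        query="How many stars does the transformers repository have?",
        model="google/tapas-base-finetuned-wtq",
    )
    print(output)

asyncio.run(main())
```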
20 changes: 20 additions & 0 deletions src/huggingface_hub/inference/_types.py
@@ -120,6 +120,26 @@ class QuestionAnsweringOutput(TypedDict):
    answer: str


class TableQuestionAnsweringOutput(TypedDict):
    """Dictionary containing information about a [`~InferenceClient.table_question_answering`] task.

    Args:
        answer (`str`):
            The plaintext answer.
        coordinates (`List[List[int]]`):
            A list of coordinates of the cells referenced in the answer.
        cells (`List[str]`):
            A list of the contents of the cells referenced in the answer.
        aggregator (`str`):
            The aggregator used to get the answer.
    """

    answer: str
    coordinates: List[List[int]]
    cells: List[str]
    aggregator: str


class TokenClassificationOutput(TypedDict):
"""Dictionary containing the output of a [`~InferenceClient.token_classification`] task.
Expand Down
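`TableQuestionAnsweringOutput` is a `TypedDict`, so it is a plain `dict` at runtime and only informs static type checkers. A small sketch of what that gives a caller:

```py
from huggingface_hub.inference._types import TableQuestionAnsweringOutput

# A TypedDict is an ordinary dict at runtime; annotating with it lets
# mypy/pyright check key names and value types with zero runtime overhead.
output: TableQuestionAnsweringOutput = {
    "answer": "AVERAGE > 36542",
    "coordinates": [[0, 1]],
    "cells": ["36542"],
    "aggregator": "AVERAGE",
}
row, col = output["coordinates"][0]  # checkers know this is List[List[int]]
```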
@@ -0,0 +1,56 @@
interactions:
- request:
    body: '{"query": "How many stars does the transformers repository have?", "table":
      {"Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542",
      "4512", "3934"]}}'
    headers:
      Accept:
      - '*/*'
      Accept-Encoding:
      - gzip, deflate, br
      Connection:
      - keep-alive
      Content-Length:
      - '171'
      Content-Type:
      - application/json
      X-Amzn-Trace-Id:
      - 98fa539e-9908-4f3f-864c-94e56277e54b
      user-agent:
      - unknown/None; hf_hub/0.17.0.dev0; python/3.10.12; torch/2.0.0.post101
    method: POST
    uri: https://api-inference.huggingface.co/models/google/tapas-base-finetuned-wtq
  response:
    body:
      string: '{"answer":"AVERAGE > 36542","coordinates":[[0,1]],"cells":["36542"],"aggregator":"AVERAGE"}'
    headers:
      Connection:
      - keep-alive
      Content-Type:
      - application/json
      Date:
      - Tue, 05 Sep 2023 22:14:00 GMT
      Transfer-Encoding:
      - chunked
      access-control-allow-credentials:
      - 'true'
      access-control-expose-headers:
      - x-compute-type, x-compute-time
      server:
      - uvicorn
      vary:
      - Origin, Access-Control-Request-Method, Access-Control-Request-Headers
      x-compute-characters:
      - '96'
      x-compute-time:
      - '0.056'
      x-compute-type:
      - cpu
      x-request-id:
      - qQfAM1RLCDKDAJYuavdEW
      x-sha:
      - e3dde1905dea877b0df1a5c057533e48327dee77
    status:
      code: 200
      message: OK
version: 1
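This cassette is a recorded HTTP exchange that the test suite replays instead of calling the live API. A rough sketch of replaying it directly with vcrpy; the cassette path is a guess, since the diff header above does not show the new file's name:

```py
import vcr

from huggingface_hub import InferenceClient

# Replays the recorded request/response instead of hitting the live endpoint.
# The cassette path below is hypothetical.
with vcr.use_cassette("tests/cassettes/test_table_question_answering.yaml"):
    client = InferenceClient()
    output = client.table_question_answering(
        {"Repository": ["Transformers", "Datasets", "Tokenizers"],
         "Stars": ["36542", "4512", "3934"]},
        query="How many stars does the transformers repository have?",
        model="google/tapas-base-finetuned-wtq",
    )
    assert output["aggregator"] == "AVERAGE"
```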
24 changes: 19 additions & 5 deletions tests/test_inference_client.py
@@ -39,9 +39,12 @@
    "object-detection": "facebook/detr-resnet-50",
    "sentence-similarity": "sentence-transformers/all-MiniLM-L6-v2",
    "summarization": "sshleifer/distilbart-cnn-12-6",
    "table-question-answering": "google/tapas-base-finetuned-wtq",
    "text-classification": "distilbert-base-uncased-finetuned-sst-2-english",
    "text-to-image": "CompVis/stable-diffusion-v1-4",
    "text-to-speech": "espnet/kan-bayashi_ljspeech_vits",
    "token-classification": "dbmdz/bert-large-cased-finetuned-conll03-english",
    "translation": "t5-small",
    "zero-shot-image-classification": "openai/clip-vit-base-patch32",
}

@@ -210,6 +213,20 @@ def test_summarization(self) -> None:
            " surpassed the Washington Monument to become the tallest man-made structure in the world.",
        )

    def test_table_question_answering(self) -> None:
        table = {
            "Repository": ["Transformers", "Datasets", "Tokenizers"],
            "Stars": ["36542", "4512", "3934"],
        }
        query = "How many stars does the transformers repository have?"
        output = self.client.table_question_answering(query=query, table=table)
        self.assertEqual(type(output), dict)
        self.assertEqual(len(output), 4)
        self.assertEqual(
            set(output.keys()),
            {"aggregator", "answer", "cells", "coordinates"},
        )

    def test_text_classification(self) -> None:
        output = self.client.text_classification("I like you")
        self.assertIsInstance(output, list)
@@ -238,14 +255,11 @@ def test_text_to_speech(self) -> None:
        self.assertIsInstance(audio, bytes)

    def test_translation(self) -> None:
        output = self.client.translation("Hello world", model="t5-small")
        output = self.client.translation("Hello world")
        self.assertEqual(output, "Hallo Welt")

    def test_token_classification(self) -> None:
        model = "dbmdz/bert-large-cased-finetuned-conll03-english"
        output = self.client.token_classification(
            "My name is Sarah Jessica Parker but you can call me Jessica", model=model
        )
        output = self.client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica")
        self.assertIsInstance(output, list)
        self.assertGreater(len(output), 0)
        for item in output:
