Skip to content

Commit

Permalink
Add translation to inference client (#1608)
Browse files Browse the repository at this point in the history
* Add translation to inference client

* Corrected typo

* Refactor in line with review points on related PRs

* Change in line with review points on other PRs

---------

Co-authored-by: Lucain Pouget <lucainp@gmail.com>
  • Loading branch information
martinbrose and Wauplin authored Sep 6, 2023
1 parent c716714 commit e6f6760
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/source/en/guides/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ has a simple API that supports the most common tasks. Here is a list of the curr
| | [Text Classification](https://huggingface.co/tasks/text-classification) || [`~InferenceClient.text_classification`] |
| | [Text Generation](https://huggingface.co/tasks/text-generation) || [`~InferenceClient.text_generation`] |
| | [Token Classification](https://huggingface.co/tasks/token-classification) || [`~InferenceClient.token_classification`] |
| | [Translation](https://huggingface.co/tasks/translation) | | |
| | [Translation](https://huggingface.co/tasks/translation) | | [`~InferenceClient.translation`] |
| | [Zero Shot Classification](https://huggingface.co/tasks/zero-shot-classification) | | |
| Tabular | [Tabular Classification](https://huggingface.co/tasks/tabular-classification) | | |
| | [Tabular Regression](https://huggingface.co/tasks/tabular-regression) | | |
Expand Down
37 changes: 37 additions & 0 deletions src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,43 @@ def token_classification(self, text: str, *, model: Optional[str] = None) -> Lis
)
return _bytes_to_list(response)

def translation(self, text: str, *, model: Optional[str] = None) -> str:
    """
    Translate the given text from one language to another.

    Check out https://huggingface.co/tasks/translation for more information on how to choose the best
    model for your specific use case. The source and target languages usually depend on the model.

    Args:
        text (`str`):
            The input text to translate.
        model (`str`, *optional*):
            The model to use for the translation task. Can be a model ID hosted on the Hugging Face Hub
            or a URL to a deployed Inference Endpoint. If not provided, the default recommended
            translation model will be used. Defaults to None.

    Returns:
        `str`: The generated translated text.

    Raises:
        [`InferenceTimeoutError`]:
            If the model is unavailable or the request times out.
        `HTTPError`:
            If the request fails with an HTTP error status code other than HTTP 503.

    Example:
    ```py
    >>> from huggingface_hub import InferenceClient
    >>> client = InferenceClient()
    >>> client.translation("My name is Wolfgang and I live in Berlin")
    'Mein Name ist Wolfgang und ich lebe in Berlin.'
    >>> client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr")
    "Je m'appelle Wolfgang et je vis à Berlin."
    ```
    """
    # The inference API returns a JSON list with a single entry holding the translation.
    payload = {"inputs": text}
    raw = self.post(json=payload, model=model, task="translation")
    parsed = _bytes_to_dict(raw)
    return parsed[0]["translation_text"]

def zero_shot_image_classification(
self, image: ContentT, labels: List[str], *, model: Optional[str] = None
) -> List[ClassificationOutput]:
Expand Down
38 changes: 38 additions & 0 deletions src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1333,6 +1333,44 @@ async def token_classification(self, text: str, *, model: Optional[str] = None)
)
return _bytes_to_list(response)

async def translation(self, text: str, *, model: Optional[str] = None) -> str:
    """
    Translate the given text from one language to another.

    Check out https://huggingface.co/tasks/translation for more information on how to choose the best
    model for your specific use case. The source and target languages usually depend on the model.

    Args:
        text (`str`):
            The input text to translate.
        model (`str`, *optional*):
            The model to use for the translation task. Can be a model ID hosted on the Hugging Face Hub
            or a URL to a deployed Inference Endpoint. If not provided, the default recommended
            translation model will be used. Defaults to None.

    Returns:
        `str`: The generated translated text.

    Raises:
        [`InferenceTimeoutError`]:
            If the model is unavailable or the request times out.
        `aiohttp.ClientResponseError`:
            If the request fails with an HTTP error status code other than HTTP 503.

    Example:
    ```py
    # Must be run in an async context
    >>> from huggingface_hub import AsyncInferenceClient
    >>> client = AsyncInferenceClient()
    >>> await client.translation("My name is Wolfgang and I live in Berlin")
    'Mein Name ist Wolfgang und ich lebe in Berlin.'
    >>> await client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr")
    "Je m'appelle Wolfgang et je vis à Berlin."
    ```
    """
    # The inference API returns a JSON list with a single entry holding the translation.
    payload = {"inputs": text}
    raw = await self.post(json=payload, model=model, task="translation")
    parsed = _bytes_to_dict(raw)
    return parsed[0]["translation_text"]

async def zero_shot_image_classification(
self, image: ContentT, labels: List[str], *, model: Optional[str] = None
) -> List[ClassificationOutput]:
Expand Down
48 changes: 48 additions & 0 deletions tests/cassettes/InferenceClientVCRTest.test_translation.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
interactions:
- request:
body: '{"inputs": "Hello world"}'
headers:
Accept:
- '*/*'
Accept-Encoding:
- gzip, deflate, br
Connection:
- keep-alive
Content-Length:
- '25'
Content-Type:
- application/json
X-Amzn-Trace-Id:
- 4a43b00f-51f2-4e1a-86ba-92e4606509be
user-agent:
- unknown/None; hf_hub/0.17.0.dev0; python/3.10.12
method: POST
uri: https://api-inference.huggingface.co/models/t5-small
response:
body:
string: '[{"translation_text":"Hallo Welt"}]'
headers:
Connection:
- keep-alive
Content-Length:
- '35'
Content-Type:
- application/json
Date:
- Sun, 20 Aug 2023 14:32:31 GMT
access-control-allow-credentials:
- 'true'
vary:
- Origin, Access-Control-Request-Method, Access-Control-Request-Headers
x-compute-time:
- '0.084'
x-compute-type:
- cache
x-request-id:
- FcgAtoJb-BUKbk49I4_Km
x-sha:
- df1b051c49625cf57a3d0d8d3863ed4d13564fe4
status:
code: 200
message: OK
version: 1
4 changes: 4 additions & 0 deletions tests/test_inference_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ def test_text_to_speech(self) -> None:
audio = self.client.text_to_speech("Hello world")
self.assertIsInstance(audio, bytes)

def test_translation(self) -> None:
    # Replayed from the recorded VCR cassette (t5-small, English -> German).
    result = self.client.translation("Hello world", model="t5-small")
    self.assertEqual(result, "Hallo Welt")

def test_token_classification(self) -> None:
model = "dbmdz/bert-large-cased-finetuned-conll03-english"
output = self.client.token_classification(
Expand Down

0 comments on commit e6f6760

Please sign in to comment.