add language support to translation client, solves #1763 (#1869)
* add language support to translation client, solves #1763

* Update tests/test_inference_client.py

Co-authored-by: Lucain <lucainp@gmail.com>

* Update tests/test_inference_client.py

Co-authored-by: Lucain <lucainp@gmail.com>

* Update src/huggingface_hub/inference/_client.py

Co-authored-by: Lucain <lucainp@gmail.com>

* update the async client to match

* add cassette for translation tests

* Update src/huggingface_hub/inference/_client.py

* Apply suggestions from code review

* Update src/huggingface_hub/inference/_generated/_async_client.py

---------

Co-authored-by: Lucain <lucainp@gmail.com>
ceferisbarov and Wauplin authored Nov 28, 2023
1 parent 8aec94b commit 32b8d56
Showing 4 changed files with 129 additions and 6 deletions.
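
For orientation, here is a minimal usage sketch of the parameters this commit adds. The model ID and language codes come from the diff and tests below; the printed output is illustrative rather than a recorded result.

```py
from huggingface_hub import InferenceClient

client = InferenceClient()

# `src_lang` and `tgt_lang` must be passed together; providing only one raises a ValueError.
translated = client.translation(
    "My name is Sarah Jessica Parker but you can call me Jessica",
    model="facebook/mbart-large-50-many-to-many-mmt",
    src_lang="en_XX",
    tgt_lang="fr_XX",
)
print(translated)  # e.g. "Mon nom est Sarah Jessica Parker mais vous pouvez m'appeler Jessica"
```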
34 changes: 31 additions & 3 deletions src/huggingface_hub/inference/_client.py
@@ -1700,12 +1700,17 @@ def token_classification(self, text: str, *, model: Optional[str] = None) -> Lis
)
return _bytes_to_list(response)

def translation(self, text: str, *, model: Optional[str] = None) -> str:
def translation(
self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None
) -> str:
"""
Convert text from one language to another.
Check out https://huggingface.co/tasks/translation for more information on how to choose the best model for
your specific use case. Source and target languages usually depends on the model.
your specific use case. Source and target languages usually depend on the model.
However, it is possible to specify source and target languages for certain models. If you are working with one of these models,
you can use the `src_lang` and `tgt_lang` arguments to pass the relevant information.
You can find this information in the model card.
Args:
text (`str`):
@@ -1714,6 +1719,10 @@ def translation(self, text: str, *, model: Optional[str] = None) -> str:
The model to use for the translation task. Can be a model ID hosted on the Hugging Face Hub or a URL to
a deployed Inference Endpoint. If not provided, the default recommended translation model will be used.
Defaults to None.
src_lang (`str`, *optional*):
Source language of the translation task, i.e. input language. Cannot be passed without `tgt_lang`.
tgt_lang (`str`, *optional*):
Target language of the translation task, i.e. output language. Cannot be passed without `src_lang`.
Returns:
`str`: The generated translated text.
@@ -1723,6 +1732,8 @@ def translation(self, text: str, *, model: Optional[str] = None) -> str:
If the model is unavailable or the request times out.
`HTTPError`:
If the request fails with an HTTP error status code other than HTTP 503.
`ValueError`:
If only one of the `src_lang` and `tgt_lang` arguments is provided.
Example:
```py
@@ -1733,8 +1744,25 @@ def translation(self, text: str, *, model: Optional[str] = None) -> str:
>>> client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr")
"Je m'appelle Wolfgang et je vis à Berlin."
```
Specifying languages:
```py
>>> client.translation("My name is Sarah Jessica Parker but you can call me Jessica", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX")
"Mon nom est Sarah Jessica Parker mais vous pouvez m\'appeler Jessica"
```
"""
response = self.post(json={"inputs": text}, model=model, task="translation")
# Raise an error if only one of `src_lang` and `tgt_lang` was given
if src_lang is not None and tgt_lang is None:
raise ValueError("You cannot specify `src_lang` without specifying `tgt_lang`.")

if src_lang is None and tgt_lang is not None:
raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")

# If both `src_lang` and `tgt_lang` are given, pass them to the request body
payload: Dict = {"inputs": text}
if src_lang and tgt_lang:
payload["parameters"] = {"src_lang": src_lang, "tgt_lang": tgt_lang}
response = self.post(json=payload, model=model, task="translation")
return _bytes_to_dict(response)[0]["translation_text"]

def zero_shot_classification(
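
When both languages are given, the hunk above forwards them to the Inference API as a `parameters` object next to `inputs`. A rough sketch of the resulting JSON body, matching the recorded cassette further down:

```py
import json

# Payload shape built by `translation()` when both languages are set (a sketch of the request body, not the client internals verbatim).
payload = {
    "inputs": "Hello world",
    "parameters": {"src_lang": "en_XX", "tgt_lang": "fr_XX"},
}
print(json.dumps(payload))
# {"inputs": "Hello world", "parameters": {"src_lang": "en_XX", "tgt_lang": "fr_XX"}}
```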
34 changes: 31 additions & 3 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -1725,12 +1725,17 @@ async def token_classification(self, text: str, *, model: Optional[str] = None)
)
return _bytes_to_list(response)

async def translation(self, text: str, *, model: Optional[str] = None) -> str:
async def translation(
self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None
) -> str:
"""
Convert text from one language to another.
Check out https://huggingface.co/tasks/translation for more information on how to choose the best model for
your specific use case. Source and target languages usually depends on the model.
your specific use case. Source and target languages usually depend on the model.
However, it is possible to specify source and target languages for certain models. If you are working with one of these models,
you can use the `src_lang` and `tgt_lang` arguments to pass the relevant information.
You can find this information in the model card.
Args:
text (`str`):
@@ -1739,6 +1744,10 @@ async def translation(self, text: str, *, model: Optional[str] = None) -> str:
The model to use for the translation task. Can be a model ID hosted on the Hugging Face Hub or a URL to
a deployed Inference Endpoint. If not provided, the default recommended translation model will be used.
Defaults to None.
src_lang (`str`, *optional*):
Source language of the translation task, i.e. input language. Cannot be passed without `tgt_lang`.
tgt_lang (`str`, *optional*):
Target language of the translation task, i.e. output language. Cannot be passed without `src_lang`.
Returns:
`str`: The generated translated text.
@@ -1748,6 +1757,8 @@ async def translation(self, text: str, *, model: Optional[str] = None) -> str:
If the model is unavailable or the request times out.
`aiohttp.ClientResponseError`:
If the request fails with an HTTP error status code other than HTTP 503.
`ValueError`:
If only one of the `src_lang` and `tgt_lang` arguments is provided.
Example:
```py
@@ -1759,8 +1770,25 @@ async def translation(self, text: str, *, model: Optional[str] = None) -> str:
>>> await client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr")
"Je m'appelle Wolfgang et je vis à Berlin."
```
Specifying languages:
```py
>>> await client.translation("My name is Sarah Jessica Parker but you can call me Jessica", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX")
"Mon nom est Sarah Jessica Parker mais vous pouvez m\'appeler Jessica"
```
"""
response = await self.post(json={"inputs": text}, model=model, task="translation")
# Raise an error if only one of `src_lang` and `tgt_lang` was given
if src_lang is not None and tgt_lang is None:
raise ValueError("You cannot specify `src_lang` without specifying `tgt_lang`.")

if src_lang is None and tgt_lang is not None:
raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")

# If both `src_lang` and `tgt_lang` are given, pass them to the request body
payload: Dict = {"inputs": text}
if src_lang and tgt_lang:
payload["parameters"] = {"src_lang": src_lang, "tgt_lang": tgt_lang}
response = await self.post(json=payload, model=model, task="translation")
return _bytes_to_dict(response)[0]["translation_text"]

async def zero_shot_classification(
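
The async client mirrors the synchronous signature. A minimal sketch of calling it, assuming the same model and language codes as above:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    # Same `src_lang`/`tgt_lang` pair and the same both-or-neither validation as the sync client.
    translated = await client.translation(
        "My name is Sarah Jessica Parker but you can call me Jessica",
        model="facebook/mbart-large-50-many-to-many-mmt",
        src_lang="en_XX",
        tgt_lang="fr_XX",
    )
    print(translated)


asyncio.run(main())
```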
@@ -0,0 +1,55 @@
interactions:
- request:
body: '{"inputs": "Hello world", "parameters": {"src_lang": "en_XX", "tgt_lang":
"fr_XX"}}'
headers:
Accept:
- '*/*'
Accept-Encoding:
- gzip, deflate
Connection:
- keep-alive
Content-Length:
- '83'
Content-Type:
- application/json
X-Amzn-Trace-Id:
- 6cb959e3-b061-4949-8b49-5ed5069382ed
user-agent:
- unknown/None; hf_hub/0.20.0.dev0; python/3.10.12
method: POST
uri: https://api-inference.huggingface.co/models/facebook/mbart-large-50-many-to-many-mmt
response:
body:
string: '[{"translation_text":"Hello world"}]'
headers:
Connection:
- keep-alive
Content-Type:
- application/json
Date:
- Tue, 28 Nov 2023 12:23:24 GMT
Transfer-Encoding:
- chunked
access-control-allow-credentials:
- 'true'
access-control-expose-headers:
- x-compute-type, x-compute-time
server:
- uvicorn
vary:
- Origin, Access-Control-Request-Method, Access-Control-Request-Headers
x-compute-characters:
- '11'
x-compute-time:
- '0.739'
x-compute-type:
- cpu
x-request-id:
- 7qdna_s8iNTqOlR4PvwI-
x-sha:
- e30b6cb8eb0d43a0b73cab73c7676b9863223a30
status:
code: 200
message: OK
version: 1
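
The file above is a VCR-style cassette recorded for the new translation test. Below is a hedged sketch of replaying such a recording with `vcrpy` so the test does not hit the live Inference API; the cassette path is a placeholder, not the exact file added by this commit.

```py
import vcr

from huggingface_hub import InferenceClient

# Replay the recorded HTTP interaction instead of calling the live endpoint.
# The path below is hypothetical; the real cassette added by this commit lives in the tests' cassette directory.
with vcr.use_cassette("tests/cassettes/translation_with_languages.yaml"):
    client = InferenceClient()
    output = client.translation(
        "Hello world",
        model="facebook/mbart-large-50-many-to-many-mmt",
        src_lang="en_XX",
        tgt_lang="fr_XX",
    )
    print(output)  # the recorded response body returns "Hello world"
```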
12 changes: 12 additions & 0 deletions tests/test_inference_client.py
@@ -324,6 +324,18 @@ def test_translation(self) -> None:
output = self.client.translation("Hello world")
self.assertEqual(output, "Hallo Welt")

def test_translation_with_source_and_target_language(self) -> None:
output_with_langs = self.client.translation(
"Hello world", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX"
)
self.assertIsInstance(output_with_langs, str)

with self.assertRaises(ValueError):
self.client.translation("Hello world", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX")

with self.assertRaises(ValueError):
self.client.translation("Hello world", model="facebook/mbart-large-50-many-to-many-mmt", tgt_lang="en_XX")

def test_token_classification(self) -> None:
output = self.client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica")
self.assertIsInstance(output, list)
