Skip to content

Commit

Permalink
Support src_lang/tgt_lang in InferenceClient.translation(), solves hu…
Browse files Browse the repository at this point in the history
  • Loading branch information
ceferisbarov committed Nov 27, 2023
1 parent 5213acc commit a32b952
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
33 changes: 30 additions & 3 deletions src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1700,12 +1700,15 @@ def token_classification(self, text: str, *, model: Optional[str] = None) -> Lis
)
return _bytes_to_list(response)

def translation(self, text: str, *, model: Optional[str] = None) -> str:
def translation(self, text: str, *, model: Optional[str] = None,
src_lang: Optional[str] = None, tgt_lang: Optional[str] = None) -> str:
"""
Convert text from one language to another.
Check out https://huggingface.co/tasks/translation for more information on how to choose the best model for
your specific use case. Source and target languages usually depends on the model.
your specific use case. Source and target languages usually depend on the model.
However, it is possible to specify source and target languages for certain models. If you are working with one of these models,
you can use `src_lang` and `tgt_lang` arguments to pass the relevant information.
Args:
text (`str`):
Expand All @@ -1714,6 +1717,10 @@ def translation(self, text: str, *, model: Optional[str] = None) -> str:
The model to use for the translation task. Can be a model ID hosted on the Hugging Face Hub or a URL to
a deployed Inference Endpoint. If not provided, the default recommended translation model will be used.
Defaults to None.
src_lang (`str`, *optional*):
Source language of the translation task, i.e. input language. Cannot be passed without `tgt_lang`.
tgt_lang (`str`, *optional*):
Target language of the translation task, i.e. output language. Cannot be passed without `src_lang`.
Returns:
`str`: The generated translated text.
Expand All @@ -1723,6 +1730,8 @@ def translation(self, text: str, *, model: Optional[str] = None) -> str:
If the model is unavailable or the request times out.
`HTTPError`:
If the request fails with an HTTP error status code other than HTTP 503.
`ValueError`:
If only one of the `src_lang` and `tgt_lang` arguments is provided.
Example:
```py
Expand All @@ -1733,8 +1742,26 @@ def translation(self, text: str, *, model: Optional[str] = None) -> str:
>>> client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr")
"Je m'appelle Wolfgang et je vis à Berlin."
```
Specifying languages:
```py
>>> client.translation("My name is Sarah Jessica Parker but you can call me Jessica", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX")
"Mon nom est Sarah Jessica Parker mais vous pouvez m'appeler Jessica"
```
"""
response = self.post(json={"inputs": text}, model=model, task="translation")
# Throw error if only one of `src_lang` and `tgt_lang` was given
if src_lang is not None and tgt_lang is None:
raise ValueError("You cannot specify `src_lang` without specifying `tgt_lang`.")

if src_lang is None and tgt_lang is not None:
raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")

# If both `src_lang` and `tgt_lang` are given, pass them to the request body
if src_lang and tgt_lang:
response = self.post(json={"inputs": text, "parameters": {"src_lang": src_lang, "tgt_lang": tgt_lang}},
model=model, task="translation")
else:
response = self.post(json={"inputs": text}, model=model, task="translation")
return _bytes_to_dict(response)[0]["translation_text"]

def zero_shot_classification(
Expand Down
7 changes: 7 additions & 0 deletions tests/test_inference_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,13 @@ def test_translation(self) -> None:
output = self.client.translation("Hello world")
self.assertEqual(output, "Hallo Welt")

output_with_langs = self.client.translation("Hello world", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX")
self.assertIsInstance(output_with_langs, str)

with self.assertRaises(ValueError):
self.client.translation("Hello world", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX")
self.client.translation("Hello world", model="facebook/mbart-large-50-many-to-many-mmt", tgt_lang="en_XX")

def test_token_classification(self) -> None:
output = self.client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica")
self.assertIsInstance(output, list)
Expand Down

0 comments on commit a32b952

Please sign in to comment.