diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py
index b4270971ec..3a76f5ad15 100644
--- a/src/huggingface_hub/inference/_client.py
+++ b/src/huggingface_hub/inference/_client.py
@@ -1700,12 +1700,17 @@ def token_classification(self, text: str, *, model: Optional[str] = None) -> Lis
         )
         return _bytes_to_list(response)
 
-    def translation(self, text: str, *, model: Optional[str] = None) -> str:
+    def translation(
+        self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None
+    ) -> str:
         """
         Convert text from one language to another.
 
         Check out https://huggingface.co/tasks/translation for more information on how to choose the best model for
-        your specific use case. Source and target languages usually depends on the model.
+        your specific use case. Source and target languages usually depend on the model.
+        However, certain models support explicit source and target languages. If you are working with one of these models,
+        you can pass them with the `src_lang` and `tgt_lang` arguments, using the language codes listed in
+        the model card.
 
         Args:
             text (`str`):
@@ -1714,6 +1719,10 @@ def translation(self, text: str, *, model: Optional[str] = None) -> str:
                 The model to use for the translation task. Can be a model ID hosted on the Hugging Face Hub or a URL to
                 a deployed Inference Endpoint. If not provided, the default recommended translation model will be used.
                 Defaults to None.
+            src_lang (`str`, *optional*):
+                Source language of the translation task, i.e. input language. Cannot be passed without `tgt_lang`.
+            tgt_lang (`str`, *optional*):
+                Target language of the translation task, i.e. output language. Cannot be passed without `src_lang`.
 
         Returns:
             `str`: The generated translated text.
@@ -1723,6 +1732,8 @@ def translation(self, text: str, *, model: Optional[str] = None) -> str:
                 If the model is unavailable or the request times out.
             `HTTPError`:
                 If the request fails with an HTTP error status code other than HTTP 503.
+            `ValueError`:
+                If only one of the `src_lang` and `tgt_lang` arguments is provided.
 
         Example:
         ```py
@@ -1733,8 +1744,25 @@ def translation(self, text: str, *, model: Optional[str] = None) -> str:
         >>> client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr")
         "Je m'appelle Wolfgang et je vis à Berlin."
         ```
+
+        Specifying languages:
+        ```py
+        >>> client.translation("My name is Sarah Jessica Parker but you can call me Jessica", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX")
+        "Mon nom est Sarah Jessica Parker mais vous pouvez m'appeler Jessica"
+        ```
         """
-        response = self.post(json={"inputs": text}, model=model, task="translation")
+        # Raise an error if only one of `src_lang` and `tgt_lang` is given
+        if src_lang is not None and tgt_lang is None:
+            raise ValueError("You cannot specify `src_lang` without specifying `tgt_lang`.")
+
+        if src_lang is None and tgt_lang is not None:
+            raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")
+
+        # If both `src_lang` and `tgt_lang` are given, pass them to the request body
+        payload: Dict = {"inputs": text}
+        if src_lang and tgt_lang:
+            payload["parameters"] = {"src_lang": src_lang, "tgt_lang": tgt_lang}
+        response = self.post(json=payload, model=model, task="translation")
         return _bytes_to_dict(response)[0]["translation_text"]
 
     def zero_shot_classification(
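For reference, a minimal sketch of how the new synchronous API is meant to be called once this patch is applied. The model ID and the `en_XX`/`fr_XX` language codes are taken from the docstring example above; an anonymous client is assumed (pass `token=...` for gated or rate-limited access):

```py
from huggingface_hub import InferenceClient

client = InferenceClient()

# Both languages must be passed together; the codes are model-specific
# (mBART-50 expects codes such as "en_XX" and "fr_XX", per its model card).
translated = client.translation(
    "My name is Wolfgang and I live in Berlin",
    model="facebook/mbart-large-50-many-to-many-mmt",
    src_lang="en_XX",
    tgt_lang="fr_XX",
)
print(translated)
```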
diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py
index c5b5daec11..67a6128b81 100644
--- a/src/huggingface_hub/inference/_generated/_async_client.py
+++ b/src/huggingface_hub/inference/_generated/_async_client.py
@@ -1725,12 +1725,17 @@ async def token_classification(self, text: str, *, model: Optional[str] = None)
         )
         return _bytes_to_list(response)
 
-    async def translation(self, text: str, *, model: Optional[str] = None) -> str:
+    async def translation(
+        self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None
+    ) -> str:
         """
         Convert text from one language to another.
 
         Check out https://huggingface.co/tasks/translation for more information on how to choose the best model for
-        your specific use case. Source and target languages usually depends on the model.
+        your specific use case. Source and target languages usually depend on the model.
+        However, certain models support explicit source and target languages. If you are working with one of these models,
+        you can pass them with the `src_lang` and `tgt_lang` arguments, using the language codes listed in
+        the model card.
 
         Args:
             text (`str`):
@@ -1739,6 +1744,10 @@ async def translation(self, text: str, *, model: Optional[str] = None) -> str:
                 The model to use for the translation task. Can be a model ID hosted on the Hugging Face Hub or a URL to
                 a deployed Inference Endpoint. If not provided, the default recommended translation model will be used.
                 Defaults to None.
+            src_lang (`str`, *optional*):
+                Source language of the translation task, i.e. input language. Cannot be passed without `tgt_lang`.
+            tgt_lang (`str`, *optional*):
+                Target language of the translation task, i.e. output language. Cannot be passed without `src_lang`.
 
         Returns:
             `str`: The generated translated text.
@@ -1748,6 +1757,8 @@ async def translation(self, text: str, *, model: Optional[str] = None) -> str:
                 If the model is unavailable or the request times out.
             `aiohttp.ClientResponseError`:
                 If the request fails with an HTTP error status code other than HTTP 503.
+            `ValueError`:
+                If only one of the `src_lang` and `tgt_lang` arguments is provided.
 
         Example:
         ```py
@@ -1759,8 +1770,25 @@ async def translation(self, text: str, *, model: Optional[str] = None) -> str:
         >>> await client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr")
         "Je m'appelle Wolfgang et je vis à Berlin."
         ```
+
+        Specifying languages:
+        ```py
+        >>> await client.translation("My name is Sarah Jessica Parker but you can call me Jessica", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX")
+        "Mon nom est Sarah Jessica Parker mais vous pouvez m'appeler Jessica"
+        ```
         """
-        response = await self.post(json={"inputs": text}, model=model, task="translation")
+        # Raise an error if only one of `src_lang` and `tgt_lang` is given
+        if src_lang is not None and tgt_lang is None:
+            raise ValueError("You cannot specify `src_lang` without specifying `tgt_lang`.")
+
+        if src_lang is None and tgt_lang is not None:
+            raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")
+
+        # If both `src_lang` and `tgt_lang` are given, pass them to the request body
+        payload: Dict = {"inputs": text}
+        if src_lang and tgt_lang:
+            payload["parameters"] = {"src_lang": src_lang, "tgt_lang": tgt_lang}
+        response = await self.post(json=payload, model=model, task="translation")
         return _bytes_to_dict(response)[0]["translation_text"]
 
     async def zero_shot_classification(
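The generated async client mirrors the same signature, so the call only differs by `await`. A minimal sketch under the same assumptions as above:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    # Same validation as the sync client: src_lang and tgt_lang must be passed together.
    translated = await client.translation(
        "My name is Wolfgang and I live in Berlin",
        model="facebook/mbart-large-50-many-to-many-mmt",
        src_lang="en_XX",
        tgt_lang="fr_XX",
    )
    print(translated)


asyncio.run(main())
```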
diff --git a/tests/cassettes/InferenceClientVCRTest.test_translation_with_source_and_target_language.yaml b/tests/cassettes/InferenceClientVCRTest.test_translation_with_source_and_target_language.yaml
new file mode 100644
index 0000000000..99d531b506
--- /dev/null
+++ b/tests/cassettes/InferenceClientVCRTest.test_translation_with_source_and_target_language.yaml
@@ -0,0 +1,55 @@
+interactions:
+- request:
+    body: '{"inputs": "Hello world", "parameters": {"src_lang": "en_XX", "tgt_lang":
+      "fr_XX"}}'
+    headers:
+      Accept:
+      - '*/*'
+      Accept-Encoding:
+      - gzip, deflate
+      Connection:
+      - keep-alive
+      Content-Length:
+      - '83'
+      Content-Type:
+      - application/json
+      X-Amzn-Trace-Id:
+      - 6cb959e3-b061-4949-8b49-5ed5069382ed
+      user-agent:
+      - unknown/None; hf_hub/0.20.0.dev0; python/3.10.12
+    method: POST
+    uri: https://api-inference.huggingface.co/models/facebook/mbart-large-50-many-to-many-mmt
+  response:
+    body:
+      string: '[{"translation_text":"Hello world"}]'
+    headers:
+      Connection:
+      - keep-alive
+      Content-Type:
+      - application/json
+      Date:
+      - Tue, 28 Nov 2023 12:23:24 GMT
+      Transfer-Encoding:
+      - chunked
+      access-control-allow-credentials:
+      - 'true'
+      access-control-expose-headers:
+      - x-compute-type, x-compute-time
+      server:
+      - uvicorn
+      vary:
+      - Origin, Access-Control-Request-Method, Access-Control-Request-Headers
+      x-compute-characters:
+      - '11'
+      x-compute-time:
+      - '0.739'
+      x-compute-type:
+      - cpu
+      x-request-id:
+      - 7qdna_s8iNTqOlR4PvwI-
+      x-sha:
+      - e30b6cb8eb0d43a0b73cab73c7676b9863223a30
+    status:
+      code: 200
+      message: OK
+version: 1
diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py
index d919fc4fe2..53a4a4bdad 100644
--- a/tests/test_inference_client.py
+++ b/tests/test_inference_client.py
@@ -324,6 +324,18 @@ def test_translation(self) -> None:
         output = self.client.translation("Hello world")
         self.assertEqual(output, "Hallo Welt")
 
+    def test_translation_with_source_and_target_language(self) -> None:
+        output_with_langs = self.client.translation(
+            "Hello world", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX"
+        )
+        self.assertIsInstance(output_with_langs, str)
+
+        with self.assertRaises(ValueError):
+            self.client.translation("Hello world", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX")
+
+        with self.assertRaises(ValueError):
+            self.client.translation("Hello world", model="facebook/mbart-large-50-many-to-many-mmt", tgt_lang="en_XX")
+
     def test_token_classification(self) -> None:
         output = self.client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica")
         self.assertIsInstance(output, list)
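As the new test asserts, passing only one of the two arguments raises a `ValueError` locally, before any request is sent. A quick sketch of that failure mode, using the same model ID as the tests:

```py
from huggingface_hub import InferenceClient

client = InferenceClient()
try:
    # tgt_lang is missing, so the client rejects the call without hitting the API.
    client.translation(
        "Hello world",
        model="facebook/mbart-large-50-many-to-many-mmt",
        src_lang="en_XX",
    )
except ValueError as err:
    print(err)  # "You cannot specify `src_lang` without specifying `tgt_lang`."
```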