From adb970d1dfe739d79b570a52a3b3c5d632a5f0d9 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Tue, 4 Jul 2023 15:38:47 +0200 Subject: [PATCH] test async methods signatures --- tests/test_inference_async_client.py | 50 +++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/tests/test_inference_async_client.py b/tests/test_inference_async_client.py index 47e0a19eaf..1c98528ea6 100644 --- a/tests/test_inference_async_client.py +++ b/tests/test_inference_async_client.py @@ -22,10 +22,12 @@ For completeness we also run a test on a simple task (`test_async_sentence_similarity`) and assume all other tasks work as well. """ +import inspect + import pytest import huggingface_hub.inference._common -from huggingface_hub import AsyncInferenceClient +from huggingface_hub import AsyncInferenceClient, InferenceClient from huggingface_hub.inference._common import _is_tgi_server from huggingface_hub.inference._text_generation import FinishReason, InputToken from huggingface_hub.inference._text_generation import ValidationError as TextGenerationValidationError @@ -150,3 +152,49 @@ async def test_async_sentence_similarity() -> None: ], ) assert scores == [0.7785726189613342, 0.4587625563144684, 0.2906219959259033] + + +def test_sync_vs_async_signatures() -> None: + client = InferenceClient() + async_client = AsyncInferenceClient() + + # Some methods have to be tested separately. + special_methods = ["post", "text_generation"] + + # Post: this is not automatically tested. No need to test its signature separately. + + # Text-generation: return type changes from Iterable[...] to AsyncIterable[...] but input parameters are the same + sync_method = getattr(client, "text_generation") + assert not inspect.iscoroutinefunction(sync_method) + async_method = getattr(async_client, "text_generation") + assert inspect.iscoroutinefunction(async_method) + + sync_sig = inspect.signature(sync_method) + async_sig = inspect.signature(async_method) + assert sync_sig.parameters == async_sig.parameters + assert sync_sig.return_annotation != async_sig.return_annotation + + [name for name in dir(client) if (not name.startswith("_")) and inspect.ismethod(getattr(client, name))] + + # Check that all methods are consistent between InferenceClient and AsyncInferenceClient + for name in dir(client): + if not inspect.ismethod(getattr(client, name)): # not a method + continue + if name.startswith("_"): # not public method + continue + if name in special_methods: # tested separately + continue + + # Check that the sync method is not async + sync_method = getattr(client, name) + assert not inspect.iscoroutinefunction(sync_method) + + # Check that the async method is async + async_method = getattr(async_client, name) + assert inspect.iscoroutinefunction(async_method) + + # Check that expected inputs and outputs are the same + sync_sig = inspect.signature(sync_method) + async_sig = inspect.signature(async_method) + assert sync_sig.parameters == async_sig.parameters + assert sync_sig.return_annotation == async_sig.return_annotation