Skip to content

Commit

Permalink
test async methods signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
Wauplin committed Jul 4, 2023
1 parent 153e2d9 commit adb970d
Showing 1 changed file with 49 additions and 1 deletion.
50 changes: 49 additions & 1 deletion tests/test_inference_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
For completeness we also run a test on a simple task (`test_async_sentence_similarity`) and assume all other tasks
work as well.
"""
import inspect

import pytest

import huggingface_hub.inference._common
from huggingface_hub import AsyncInferenceClient
from huggingface_hub import AsyncInferenceClient, InferenceClient
from huggingface_hub.inference._common import _is_tgi_server
from huggingface_hub.inference._text_generation import FinishReason, InputToken
from huggingface_hub.inference._text_generation import ValidationError as TextGenerationValidationError
Expand Down Expand Up @@ -150,3 +152,49 @@ async def test_async_sentence_similarity() -> None:
],
)
assert scores == [0.7785726189613342, 0.4587625563144684, 0.2906219959259033]


def test_sync_vs_async_signatures() -> None:
client = InferenceClient()
async_client = AsyncInferenceClient()

# Some methods have to be tested separately.
special_methods = ["post", "text_generation"]

# Post: this is not automatically tested. No need to test its signature separately.

# Text-generation: return type changes from Iterable[...] to AsyncIterable[...] but input parameters are the same
sync_method = getattr(client, "text_generation")
assert not inspect.iscoroutinefunction(sync_method)
async_method = getattr(async_client, "text_generation")
assert inspect.iscoroutinefunction(async_method)

sync_sig = inspect.signature(sync_method)
async_sig = inspect.signature(async_method)
assert sync_sig.parameters == async_sig.parameters
assert sync_sig.return_annotation != async_sig.return_annotation

[name for name in dir(client) if (not name.startswith("_")) and inspect.ismethod(getattr(client, name))]

# Check that all methods are consistent between InferenceClient and AsyncInferenceClient
for name in dir(client):
if not inspect.ismethod(getattr(client, name)): # not a method
continue
if name.startswith("_"): # not public method
continue
if name in special_methods: # tested separately
continue

# Check that the sync method is not async
sync_method = getattr(client, name)
assert not inspect.iscoroutinefunction(sync_method)

# Check that the async method is async
async_method = getattr(async_client, name)
assert inspect.iscoroutinefunction(async_method)

# Check that expected inputs and outputs are the same
sync_sig = inspect.signature(sync_method)
async_sig = inspect.signature(async_method)
assert sync_sig.parameters == async_sig.parameters
assert sync_sig.return_annotation == async_sig.return_annotation

0 comments on commit adb970d

Please sign in to comment.