
Support text-generation in InferenceClient #1513

Merged: 23 commits into main from support-tgi-in-inference-client-2 on Jun 27, 2023

Conversation

@Wauplin (Contributor) commented Jun 16, 2023

Original implementation taken from the text-generation-inference Python client (see client library and repo) from @OlivierDehaene. The vast majority of the code comes from there, so kudos go to him 🙏.

Changes compared to the original implementation:

  • use pydantic.dataclasses instead of BaseModel
  • default to Python's dataclasses if pydantic is not installed (meaning pydantic/arg validation is recommended but optional; see the sketch after this list)
    • added default values for optional parameters (not needed with BaseModel, but required with plain dataclasses)
  • integrated in huggingface_hub.InferenceClient
  • added stream: bool and details: bool in the text_generation method instead of having different methods for each use case
  • NO asyncio support yet => TODO in a follow-up PR
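
A rough sketch of the pydantic fallback described above, assuming hypothetical class and field names (not the actual ones in the PR):

try:
    from pydantic.dataclasses import dataclass  # validated dataclasses when pydantic is installed
except ImportError:
    from dataclasses import dataclass  # plain dataclasses otherwise, without argument validation

from dataclasses import field
from typing import List, Optional

@dataclass
class ExampleParameters:
    # explicit defaults are needed here: plain dataclasses require them, unlike pydantic's BaseModel
    max_new_tokens: int = 20
    return_full_text: bool = False
    stop: List[str] = field(default_factory=list)
    seed: Optional[int] = None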

If the model is not served with a TGI backend (example: "gpt2"-like models), some parameters are ignored. The client always assumes TGI is enabled but falls back to a normal call if that's not the case. A warning is triggered for the user, and details=True is not possible.
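
Roughly, the fallback could look like this (a minimal sketch with made-up helper names, not the actual private methods of InferenceClient):

import warnings

class UnsupportedTGIError(Exception):
    """Hypothetical error meaning the endpoint is not a TGI server."""

def _tgi_request(prompt, **kwargs):
    raise UnsupportedTGIError  # placeholder: a real call would hit a TGI endpoint

def _plain_inference_request(prompt, **kwargs):
    return prompt + " ..."  # placeholder: a real call would hit the regular Inference API

def text_generation(prompt, details=False, stream=False, **kwargs):
    try:
        # always try the TGI route first
        return _tgi_request(prompt, details=details, stream=stream, **kwargs)
    except UnsupportedTGIError:
        if details:
            raise ValueError("`details=True` is only available for models served with TGI")
        warnings.warn("Model is not served with text-generation-inference; TGI-only parameters are ignored.")
        return _plain_inference_request(prompt, **kwargs)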

Integration is now functional and locally tested.

Docs: text_generation and dataclasses descriptions.
TODO:

  • handle models not served with text-generation-inference backend
  • add unit tests
  • add documentation / reference

Normal use (with details)

from huggingface_hub import InferenceClient

client = InferenceClient()
output = client.text_generation("\ndef hello_world(name: str)", model="bigcode/starcoder", stream=False, details=True)
print(output)
TextGenerationResponse(generated_text=' -> str:\n    return f"Hello {name}"\n\n\ndef test_hello_world():', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=20, seed=None, prefill=[], tokens=[Token(id=967, text=' ->', logprob=-0.0009317398, special=False), Token(id=596, text=' str', logprob=-0.25805664, special=False), Token(id=44, text=':', logprob=-0.0005950928, special=False), Token(id=284, text='\n   ', logprob=-0.07080078, special=False), Token(id=442, text=' return', logprob=-0.6455078, special=False), Token(id=296, text=' f', logprob=-0.40722656, special=False), Token(id=20, text='"', logprob=-0.37060547, special=False), Token(id=8279, text='Hello', logprob=-0.11077881, special=False), Token(id=301, text=' {', logprob=-0.67626953, special=False), Token(id=426, text='name', logprob=-0.006881714, special=False), Token(id=3845, text='}"', logprob=-0.72021484, special=False), Token(id=478, text='\n\n', logprob=-0.56689453, special=False), Token(id=203, text='\n', logprob=-0.009246826, special=False), Token(id=589, text='def', logprob=-1.2724609, special=False), Token(id=894, text=' test', logprob=-1.5761719, special=False), Token(id=81, text='_', logprob=-0.012329102, special=False), Token(id=7656, text='hello', logprob=-1.8945312, special=False), Token(id=81, text='_', logprob=-0.071777344, special=False), Token(id=5860, text='world', logprob=-0.10595703, special=False), Token(id=2262, text='():', logprob=-0.7866211, special=False)], best_of_sequences=None))

Streaming (no details)

for output in client.text_generation(
    "Number between 1 and 1000: 1, 2, ", max_new_tokens=100, model="bigcode/starcoder", stream=True, details=False
):
    print(output, sep="", end="", flush=True)
3, 4, 5, 6, 7, 8, 9, (...)

@HuggingFaceDocBuilderDev commented Jun 16, 2023

The documentation is not available anymore as the PR was closed or merged.

@codecov bot commented Jun 16, 2023

Codecov Report

Patch coverage: 93.42% and project coverage change: +4.69% 🎉

Comparison is base (fd1494a) 77.97% compared to head (b4e30cf) 82.67%.

❗ Current head b4e30cf differs from pull request most recent head 678784f. Consider uploading reports for the commit 678784f to get more accurate results

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1513      +/-   ##
==========================================
+ Coverage   77.97%   82.67%   +4.69%     
==========================================
  Files          55       58       +3     
  Lines        5835     6332     +497     
==========================================
+ Hits         4550     5235     +685     
+ Misses       1285     1097     -188     
Impacted Files Coverage Δ
src/huggingface_hub/utils/__init__.py 100.00% <ø> (ø)
src/huggingface_hub/utils/_runtime.py 55.86% <60.00%> (+3.71%) ⬆️
src/huggingface_hub/inference/_client.py 81.96% <88.88%> (+3.61%) ⬆️
src/huggingface_hub/inference/_text_generation.py 96.25% <96.25%> (ø)

... and 27 files with indirect coverage changes


@Wauplin changed the title from "[WIP] Support text-generation in InferenceClient" to "Support text-generation in InferenceClient" on Jun 16, 2023
@OlivierDehaene (Member) left a comment:

The backend forces `details = True` if `decoder_input_details = True`, so maybe it should be the same here.
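
Client-side, that coercion could be as small as this (a sketch, not the actual code):

def _coerce_details(details: bool, decoder_input_details: bool) -> bool:
    # mirror the server behavior: requesting decoder input details implies details=True
    return details or decoder_input_details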

(Several review comments on src/huggingface_hub/inference/_client.py were resolved.)

Comment on the following lines:
# Whether to prepend the prompt to the generated text
return_full_text: bool = False
# Stop generating tokens if a member of `stop_sequences` is generated
stop: List[str] = field(default_factory=lambda: [])


@StephenHodgson commented: I noticed we are not validating the stop sequences. I believe the max is 4.

@Wauplin (Contributor, Author) commented Jun 20, 2023:

Thanks for the heads up @StephenHodgson!

@OlivierDehaene could you confirm this 4-item limit? I found the same in the openapi.json specs but I'd prefer to cross-check with you.

@OlivierDehaene (Member) replied:

max_best_of and max_stop_sequences are both parameters that can be modified in TGI. The defaults are 2 and 4 but they can also be turned off.

@Wauplin (Contributor, Author) replied:

OK, so let's not validate them client-side. Thanks for confirming!

@Wauplin marked this pull request as ready for review June 20, 2023 16:57
@Wauplin requested a review from LysandreJik June 21, 2023 08:10
@LysandreJik (Member) left a comment:

Impressive conditional type response decided by the values of details/stream! Is this so that pydantic behaves correctly?

I didn't find a guide/example showcasing how you'd use it in practice. I think it'd be worth adding to the docs as a docstring example.

For example, four small examples showcasing the difference between the details and stream modes, with self-explanatory results about their content. I'm sure a motivated user could go and search for the docs of TextGenerationResponse and Token, but something like this would get that info straight away on the doc page they'd be looking at:


This method has four different return possibilities, according to the value of the details and stream parameters passed to the method.

With details and stream as False:

>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient()
>>> client.text_generation("What's up")

" with the weather?\nI'm sorry, I am an AI language model and do not have"

If details is True and stream is False:

>>> client.text_generation("What's up", details=True)

TextGenerationResponse(
    generated_text=" with the weather?\nI'm sorry, I am an AI language model and do not have", 
    details=Details(
        finish_reason=<FinishReason.Length: 'length'>, 
        generated_tokens=20, 
        seed=None, 
        prefill=[InputToken(id=1562, text='What', logprob=None), InputToken(id=18, text="'", logprob=-2.5390625), InputToken(id=94, text='s', logprob=-0.13061523), InputToken(id=510, text=' up', logprob=-4.2382812)], 
        tokens=[Token(id=335, text=' with', logprob=-1.4267578, special=False), Token(id=248, text=' the', logprob=-1.4677734, special=False), Token(id=5015, text=' weather', logprob=-3.15625, special=False), Token(id=42, text='?', logprob=-0.4638672, special=False), Token(id=193, text='\n', logprob=-0.06262207, special=False), Token(id=52, text='I', logprob=-0.22729492, special=False), Token(id=18, text="'", logprob=-0.0769043, special=False), Token(id=88, text='m', logprob=-0.0022010803, special=False), Token(id=6893, text=' sorry', logprob=-0.027435303, special=False), Token(id=23, text=',', logprob=-0.033081055, special=False), Token(id=295, text=' I', logprob=-0.7036133, special=False), Token(id=653, text=' am', logprob=-0.9658203, special=False), Token(id=267, text=' an', logprob=-0.3400879, special=False), Token(id=8317, text=' AI', logprob=-0.052856445, special=False), Token(id=3599, text=' language', logprob=-0.0051193237, special=False), Token(id=2308, text=' model', logprob=-0.0007414818, special=False), Token(id=273, text=' and', logprob=-0.019866943, special=False), Token(id=441, text=' do', logprob=-0.40722656, special=False), Token(id=416, text=' not', logprob=-0.00076293945, special=False), Token(id=413, text=' have', logprob=-0.011177063, special=False)], 
        best_of_sequences=None
    )
)

...


Overall, it works very well. The tests look good. The cassettes could be offloaded to a dataset repo if you don't want to weigh down the repo, but there's no strong need to move them out as it's code only.

Comment on lines 876 to 880
`Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]`: generated response.
Format depends on the input. If `details=False` (the default), the generated text is returned as a string. If
`details=False` and `stream=True`, an `Iterable[str]` is returned. If `details=True`, a [`~huggingface_hub.inference._text_generation.TextGenerationResponse`]
object is returned, containing details about the generated text. Finally, if `details=True` and `stream=True`
are passed, an iterable of [`~huggingface_hub.inference._text_generation.TextGenerationStreamResponse`] is passed.
@LysandreJik (Member) commented:

This is a bit hard to read when converted to the docs. Would it make sense to have it be a list with the different possibilities?

@Wauplin (Contributor, Author) replied:

I updated the docs. Much cleaner this way! 😄

(Screenshot of the updated return-value documentation, 2023-06-26.)
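
For reference, the list-style return section reads roughly like this (a reconstruction from the docstring quoted above, not a verbatim copy of the merged version):

Returns:
    `Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]`: generated text:
    - if `details=False` and `stream=False` (the default), the generated text as a `str`
    - if `details=False` and `stream=True`, an `Iterable[str]` of generated text chunks
    - if `details=True` and `stream=False`, a `TextGenerationResponse` with full generation details
    - if `details=True` and `stream=True`, an iterable of `TextGenerationStreamResponse`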

@Wauplin (Contributor, Author) commented Jun 26, 2023

Thanks @LysandreJik for the review and suggesting some improvements in the docs! ❤️ I've made the requested changes (mainly the return type + adding some examples). Check it out in the docs.

Impressive conditional type response decided by the values of details/stream! Is this so that pydantic behaves correctly?

About the return type with the @overload type annotation, the main goal is to help with IDE autocompletion / type checks in downstream libraries. It tells any tool what the returned object will be depending on the details and stream input values. In theory, Pydantic is quite good at validating things with Union[...], but I'm not using it here for that.
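
In practice this relies on typing.overload with Literal flags, roughly like the following simplified signatures (not the exact ones from _client.py, which take many more parameters):

from typing import Iterable, Literal, overload

@overload
def text_generation(prompt: str, *, details: Literal[False] = False, stream: Literal[False] = False) -> str: ...
@overload
def text_generation(prompt: str, *, details: Literal[False] = False, stream: Literal[True]) -> Iterable[str]: ...
@overload
def text_generation(prompt: str, *, details: Literal[True], stream: Literal[False] = False) -> "TextGenerationResponse": ...
@overload
def text_generation(prompt: str, *, details: Literal[True], stream: Literal[True]) -> Iterable["TextGenerationStreamResponse"]: ...

def text_generation(prompt: str, *, details: bool = False, stream: bool = False):
    ...  # single implementation; type checkers pick the matching overload from the literal flag values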

The cassettes could be offloaded to a dataset repo if you don't want to weigh down the repo, but no strong need to move it out as it's code only.

Yes, that could be a possibility, but for now I'd prefer not to, just for the sake of keeping it simple. The main advantage of having them directly in the repo is that updating a cassette is as simple as running pytest with --vcr-record=all. No need to re-upload the fixed cassettes, which makes the process simpler. I might do it in a future PR if it becomes too heavy.

@Wauplin merged commit 909fbd6 into main Jun 27, 2023
@Wauplin deleted the support-tgi-in-inference-client-2 branch June 27, 2023 07:41
@Wauplin mentioned this pull request Jun 29, 2023