# Modify default for max_new_tokens in python client (#1336)
# What does this PR do?
Since [#1097](#1097), clients no longer need to specify `max_new_tokens`.
However, the Python client in this repo had not yet been adapted to this
change. This PR makes it possible to use the Python client without
providing `max_new_tokens`.
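
For illustration, a minimal sketch of the new usage (the endpoint URL is an assumption; any running text-generation-inference server works):

```python
from text_generation import Client

# Hypothetical endpoint; substitute your own TGI server URL.
client = Client("http://127.0.0.1:8080")

# max_new_tokens is no longer required; when omitted, the
# server-side default from #1097 applies.
response = client.generate("Why is the sky blue?")
print(response.generated_text)
```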

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [x] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
freitng authored Jan 29, 2024
1 parent a9ea606 commit 2d56f10
Showing 3 changed files with 21 additions and 5 deletions.
16 changes: 16 additions & 0 deletions clients/python/tests/test_client.py
```diff
@@ -21,6 +21,22 @@ def test_generate(flan_t5_xxl_url, hf_headers):
     assert not response.details.tokens[0].special
 
 
+def test_generate_max_new_tokens_not_set(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
+    response = client.generate("test", decoder_input_details=True)
+
+    assert response.generated_text != ""
+    assert response.details.finish_reason == FinishReason.EndOfSequenceToken
+    assert response.details.generated_tokens > 1
+    assert response.details.seed is None
+    assert len(response.details.prefill) == 1
+    assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
+    assert len(response.details.tokens) > 1
+    assert response.details.tokens[0].id == 3
+    assert response.details.tokens[0].text == " "
+    assert not response.details.tokens[0].special
+
+
 def test_generate_best_of(flan_t5_xxl_url, hf_headers):
     client = Client(flan_t5_xxl_url, hf_headers)
     response = client.generate(
```
8 changes: 4 additions & 4 deletions clients/python/text_generation/client.py
```diff
@@ -62,7 +62,7 @@ def generate(
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: int = 20,
+        max_new_tokens: Optional[int] = None,
         best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
@@ -157,7 +157,7 @@ def generate_stream(
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: int = 20,
+        max_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
@@ -312,7 +312,7 @@ async def generate(
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: int = 20,
+        max_new_tokens: Optional[int] = None,
         best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
@@ -405,7 +405,7 @@ async def generate_stream(
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: int = 20,
+        max_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
```
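
As the diff shows, the async client methods get the same new default. A minimal sketch of the async usage (endpoint URL is an assumption, as above):

```python
import asyncio
from text_generation import AsyncClient

# Hypothetical endpoint; substitute your own TGI server URL.
client = AsyncClient("http://127.0.0.1:8080")

async def main():
    # As with the sync client, max_new_tokens may now be omitted.
    response = await client.generate("Why is the sky blue?")
    print(response.generated_text)

asyncio.run(main())
```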
2 changes: 1 addition & 1 deletion clients/python/text_generation/types.py
```diff
@@ -9,7 +9,7 @@ class Parameters(BaseModel):
     # Activate logits sampling
     do_sample: bool = False
     # Maximum number of generated tokens
-    max_new_tokens: int = 20
+    max_new_tokens: Optional[int] = None
     # The parameter for repetition penalty. 1.0 means no penalty.
     # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
     repetition_penalty: Optional[float] = None
```
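
To see the effect of the `types.py` change, a quick sketch (assuming the other `Parameters` fields keep their defaults and its validators accept an empty construction):

```python
from text_generation.types import Parameters

# With the new default, an unset max_new_tokens stays None and is left
# to the server, rather than being pinned to 20 client-side.
params = Parameters()
assert params.max_new_tokens is None
```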
