From 716669f6bcb1ffa8f2a37870d70163514f073de4 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Tue, 5 Sep 2023 19:54:52 -0700
Subject: [PATCH] Revert "add best_of and use_beam_search for completions interface (#2348)"

This reverts commit f99663cc565c9db1aab20e34ce7f719765a16519.
---
 fastchat/protocol/api_protocol.py        |  2 +-
 fastchat/protocol/openai_api_protocol.py |  4 +-
 fastchat/serve/openai_api_server.py      | 29 +---------
 fastchat/serve/vllm_worker.py            | 70 +++++++-----------------
 4 files changed, 26 insertions(+), 79 deletions(-)

diff --git a/fastchat/protocol/api_protocol.py b/fastchat/protocol/api_protocol.py
index 1091f5e5a..7dc8fe1c3 100644
--- a/fastchat/protocol/api_protocol.py
+++ b/fastchat/protocol/api_protocol.py
@@ -150,7 +150,7 @@ class CompletionResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseChoice]
-    usage: Union[UsageInfo, List[UsageInfo]]
+    usage: UsageInfo
 
 
 class CompletionResponseStreamChoice(BaseModel):
diff --git a/fastchat/protocol/openai_api_protocol.py b/fastchat/protocol/openai_api_protocol.py
index fc3c91ebd..6232e8b9b 100644
--- a/fastchat/protocol/openai_api_protocol.py
+++ b/fastchat/protocol/openai_api_protocol.py
@@ -151,13 +151,11 @@ class CompletionRequest(BaseModel):
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     user: Optional[str] = None
-    use_beam_search: Optional[bool] = False
-    best_of: Optional[int] = None
 
 
 class CompletionResponseChoice(BaseModel):
     index: int
-    text: Union[str, List[str]]
+    text: str
     logprobs: Optional[int] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
 
diff --git a/fastchat/serve/openai_api_server.py b/fastchat/serve/openai_api_server.py
index 1344ec46f..02e8481f4 100644
--- a/fastchat/serve/openai_api_server.py
+++ b/fastchat/serve/openai_api_server.py
@@ -238,12 +238,9 @@ async def get_gen_params(
     *,
     temperature: float,
     top_p: float,
-    best_of: Optional[int],
     max_tokens: Optional[int],
-    n: Optional[int],
     echo: Optional[bool],
     stop: Optional[Union[str, List[str]]],
-    use_beam_search: Optional[bool],
 ) -> Dict[str, Any]:
     conv = await get_conv(model_name, worker_addr)
     conv = Conversation(
@@ -290,11 +287,6 @@ async def get_gen_params(
         "stop_token_ids": conv.stop_token_ids,
     }
 
-    if best_of:
-        gen_params.update({"n": n, "best_of": best_of})
-    if use_beam_search is not None:
-        gen_params.update({"use_beam_search": use_beam_search})
-
     new_stop = set()
     _add_to_set(stop, new_stop)
     _add_to_set(conv.stop_str, new_stop)
@@ -499,21 +491,15 @@ async def create_completion(request: CompletionRequest):
             text,
             temperature=request.temperature,
             top_p=request.top_p,
-            best_of=request.best_of,
             max_tokens=request.max_tokens,
-            n=request.n,
             echo=request.echo,
             stop=request.stop,
-            use_beam_search=request.use_beam_search,
         )
         for i in range(request.n):
             content = asyncio.create_task(
                 generate_completion(gen_params, worker_addr)
             )
             text_completions.append(content)
-            # when use with best_of, only need send one request
-            if request.best_of:
-                break
 
     try:
         all_tasks = await asyncio.gather(*text_completions)
@@ -533,18 +519,9 @@ async def create_completion(request: CompletionRequest):
                 finish_reason=content.get("finish_reason", "stop"),
             )
         )
-        idx = 0
-        while True:
-            info = content["usage"]
-            if isinstance(info, list):
-                info = info[idx]
-
-            task_usage = UsageInfo.parse_obj(info)
-
-            for usage_key, usage_value in task_usage.dict().items():
-                setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
-            idx += 1
-            break
+        task_usage = UsageInfo.parse_obj(content["usage"])
+        for usage_key, usage_value in task_usage.dict().items():
+            setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
 
     return CompletionResponse(
         model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage)
diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py
index 71a30f890..8e255b79c 100644
--- a/fastchat/serve/vllm_worker.py
+++ b/fastchat/serve/vllm_worker.py
@@ -18,7 +18,6 @@
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
 
-from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
 from fastchat.serve.model_worker import (
     BaseModelWorker,
     logger,
@@ -75,9 +74,6 @@ async def generate_stream(self, params):
         if self.tokenizer.eos_token_id is not None:
             stop_token_ids.append(self.tokenizer.eos_token_id)
         echo = params.get("echo", True)
-        use_beam_search = params.get("use_beam_search", False)
-        best_of = params.get("best_of", None)
-        n = params.get("n", 1)
 
         # Handle stop_str
         stop = set()
@@ -94,51 +90,27 @@ async def generate_stream(self, params):
         top_p = max(top_p, 1e-5)
         if temperature <= 1e-5:
             top_p = 1.0
-        try:
-            sampling_params = SamplingParams(
-                n=n,
-                temperature=temperature,
-                top_p=top_p,
-                use_beam_search=use_beam_search,
-                stop=list(stop),
-                max_tokens=max_new_tokens,
-                best_of=best_of,
-            )
-
-            results_generator = engine.generate(context, sampling_params, request_id)
-
-            async for request_output in results_generator:
-                prompt = request_output.prompt
-                prompt_tokens = len(request_output.prompt_token_ids)
-                output_usage = []
-                for out in request_output.outputs:
-                    completion_tokens = len(out.token_ids)
-                    total_tokens = prompt_tokens + completion_tokens
-                    output_usage.append(
-                        {
-                            "prompt_tokens": prompt_tokens,
-                            "completion_tokens": completion_tokens,
-                            "total_tokens": total_tokens,
-                        }
-                    )
-
-                if echo:
-                    text_outputs = [
-                        prompt + output.text for output in request_output.outputs
-                    ]
-                else:
-                    text_outputs = [output.text for output in request_output.outputs]
-
-                if sampling_params.best_of is None:
-                    text_outputs = [" ".join(text_outputs)]
-                ret = {"text": text_outputs, "error_code": 0, "usage": output_usage}
-                yield (json.dumps(ret) + "\0").encode()
-        except (ValueError, RuntimeError) as e:
-            ret = {
-                "text": f"{e}",
-                "error_code": ErrorCode.PARAM_OUT_OF_RANGE,
-                "usage": {},
-            }
+        sampling_params = SamplingParams(
+            n=1,
+            temperature=temperature,
+            top_p=top_p,
+            use_beam_search=False,
+            stop=list(stop),
+            max_tokens=max_new_tokens,
+        )
+        results_generator = engine.generate(context, sampling_params, request_id)
+
+        async for request_output in results_generator:
+            prompt = request_output.prompt
+            if echo:
+                text_outputs = [
+                    prompt + output.text for output in request_output.outputs
+                ]
+            else:
+                text_outputs = [output.text for output in request_output.outputs]
+            text_outputs = " ".join(text_outputs)
+            # Note: usage is not supported yet
+            ret = {"text": text_outputs, "error_code": 0, "usage": {}}
             yield (json.dumps(ret) + "\0").encode()
 
     async def generate(self, params):
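
For reference, the usage accounting restored by this revert in create_completion can be sketched as follows. This is an illustrative, self-contained approximation rather than FastChat code: UsageInfo below is a simplified dataclass stand-in for the pydantic model in fastchat/protocol/openai_api_protocol.py, and aggregate_usage is a hypothetical helper that mirrors the restored loop, where each of the n worker responses carries a single "usage" dict whose fields are summed into one total.

# Illustrative sketch only -- simplified stand-ins, not the FastChat modules.
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class UsageInfo:
    # Stand-in for fastchat.protocol.openai_api_protocol.UsageInfo (pydantic).
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


def aggregate_usage(contents: List[Dict]) -> UsageInfo:
    # Hypothetical helper mirroring the restored loop in create_completion:
    # each completion result carries one "usage" dict; sum its fields per key.
    usage = UsageInfo()
    for content in contents:
        task_usage = UsageInfo(**content["usage"])
        for usage_key, usage_value in vars(task_usage).items():
            setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
    return usage


if __name__ == "__main__":
    results = [
        {"text": "a", "usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12}},
        {"text": "b", "usage": {"prompt_tokens": 5, "completion_tokens": 9, "total_tokens": 14}},
    ]
    print(aggregate_usage(results))
    # UsageInfo(prompt_tokens=10, completion_tokens=16, total_tokens=26)

After the revert, per-choice usage lists (one entry per best_of candidate) are gone: the API server again issues one request per n and reports a single aggregated UsageInfo, matching the OpenAI-style response shape.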