From 9d39fd53c0bcaa2567b68ebd0428c37d9e2be96f Mon Sep 17 00:00:00 2001
From: guogeer <1500065870@qq.com>
Date: Thu, 24 Oct 2024 21:51:36 +0800
Subject: [PATCH] openai compatiable api usage and id (#9800)

Co-authored-by: jinqi.guo
---
 .../model_runtime/entities/llm_entities.py    |  1 +
 .../openai_api_compatible/llm/llm.py          | 33 +++++++++++++++----
 2 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py
index 52b590f66a3873..88531d8ae00037 100644
--- a/api/core/model_runtime/entities/llm_entities.py
+++ b/api/core/model_runtime/entities/llm_entities.py
@@ -105,6 +105,7 @@ class LLMResult(BaseModel):
     Model class for llm result.
     """
 
+    id: Optional[str] = None
     model: str
     prompt_messages: list[PromptMessage]
     message: AssistantPromptMessage
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
index 356ac56b1e4b2f..e1342fe985ac13 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
@@ -397,16 +397,21 @@ def _handle_generate_stream_response(
         chunk_index = 0
 
         def create_final_llm_result_chunk(
-            index: int, message: AssistantPromptMessage, finish_reason: str
+            id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
         ) -> LLMResultChunk:
             # calculate num tokens
-            prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
-            completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
+            prompt_tokens = usage and usage.get("prompt_tokens")
+            if prompt_tokens is None:
+                prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
+            completion_tokens = usage and usage.get("completion_tokens")
+            if completion_tokens is None:
+                completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
 
             # transform usage
             usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 
             return LLMResultChunk(
+                id=id,
                 model=model,
                 prompt_messages=prompt_messages,
                 delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
@@ -450,7 +455,7 @@ def get_tool_call(tool_call_id: str):
                     tool_call.function.arguments += new_tool_call.function.arguments
 
         finish_reason = None  # The default value of finish_reason is None
-
+        message_id, usage = None, None
         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             chunk = chunk.strip()
             if chunk:
@@ -462,20 +467,26 @@ def get_tool_call(tool_call_id: str):
                     continue
 
                 try:
-                    chunk_json = json.loads(decoded_chunk)
+                    chunk_json: dict = json.loads(decoded_chunk)
                 # stream ended
                 except json.JSONDecodeError as e:
                     yield create_final_llm_result_chunk(
+                        id=message_id,
                         index=chunk_index + 1,
                         message=AssistantPromptMessage(content=""),
                         finish_reason="Non-JSON encountered.",
+                        usage=usage,
                     )
                     break
+                if chunk_json:
+                    if u := chunk_json.get("usage"):
+                        usage = u
                 if not chunk_json or len(chunk_json["choices"]) == 0:
                     continue
 
                 choice = chunk_json["choices"][0]
                 finish_reason = chunk_json["choices"][0].get("finish_reason")
+                message_id = chunk_json.get("id")
                 chunk_index += 1
 
                 if "delta" in choice:
@@ -524,6 +535,7 @@ def get_tool_call(tool_call_id: str):
                         continue
 
                 yield LLMResultChunk(
+                    id=message_id,
                     model=model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
@@ -536,6 +548,7 @@ def get_tool_call(tool_call_id: str):
 
         if tools_calls:
             yield LLMResultChunk(
+                id=message_id,
                 model=model,
                 prompt_messages=prompt_messages,
                 delta=LLMResultChunkDelta(
@@ -545,17 +558,22 @@ def get_tool_call(tool_call_id: str):
             )
 
         yield create_final_llm_result_chunk(
-            index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason
+            id=message_id,
+            index=chunk_index,
+            message=AssistantPromptMessage(content=""),
+            finish_reason=finish_reason,
+            usage=usage,
         )
 
     def _handle_generate_response(
         self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
     ) -> LLMResult:
-        response_json = response.json()
+        response_json: dict = response.json()
         completion_type = LLMMode.value_of(credentials["mode"])
 
         output = response_json["choices"][0]
+        message_id = response_json.get("id")
 
         response_content = ""
         tool_calls = None
@@ -593,6 +611,7 @@ def _handle_generate_response(
 
         # transform response
         result = LLMResult(
+            id=message_id,
             model=response_json["model"],
             prompt_messages=prompt_messages,
             message=assistant_message,
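
Note for reviewers: the core behavioral change is that token usage reported by the provider now takes precedence over local estimation, and the upstream response "id" is carried through to LLMResult / LLMResultChunk. Many OpenAI-compatible servers (and the upstream OpenAI API, when stream_options.include_usage is enabled) end a stream with a chunk whose "choices" array is empty but whose "usage" object carries the authoritative token counts, alongside the response "id". The sketch below is illustrative only and simply mirrors the fallback logic the hunks above add to create_final_llm_result_chunk; the helper names estimate_tokens and resolve_usage are hypothetical and do not exist in this codebase.

from typing import Optional


def estimate_tokens(text: str) -> int:
    # Hypothetical stand-in for the local _num_tokens_from_string fallback.
    return max(1, len(text) // 4)


def resolve_usage(usage: Optional[dict], prompt_text: str, completion_text: str) -> tuple[int, int]:
    # Prefer provider-reported counts; fall back to local estimation only
    # when the provider never sent a usage payload or omitted a field.
    prompt_tokens = usage and usage.get("prompt_tokens")
    if prompt_tokens is None:
        prompt_tokens = estimate_tokens(prompt_text)
    completion_tokens = usage and usage.get("completion_tokens")
    if completion_tokens is None:
        completion_tokens = estimate_tokens(completion_text)
    return prompt_tokens, completion_tokens


# Example final stream chunk: empty "choices", authoritative "usage", plus the message "id".
final_chunk = {
    "id": "chatcmpl-abc123",
    "choices": [],
    "usage": {"prompt_tokens": 12, "completion_tokens": 34},
}
print(resolve_usage(final_chunk.get("usage"), "Hello", "Hi there"))  # (12, 34) -- reported values win
print(resolve_usage(None, "Hello", "Hi there"))                      # falls back to the local estimates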