♻️ Refactor and modularize stream handling functions
Refactor stream handling logic by introducing new helper functions. This change improves readability and maintainability of the code by encapsulating repetitive logic in dedicated functions like `process_chunk`, `handle_context_window_error`, and `process_finish_reason`.
redadmiral committed Sep 20, 2024
1 parent 079a34f commit 8f61810
Showing 1 changed file with 65 additions and 51 deletions.
src/llm.py
@@ -8,10 +8,28 @@
 LOOP_THRESHOLD = 10
 
 
-def handle_stream(stream, all_json=False, json_pp: bool = False):
+def process_chunk(chunk, all_json, indent):
     """
-    This function takes the stream and formats in an appropriate way.
+    Processes a data chunk and formats it as JSON if all_json is True.
+    :param chunk: The data chunk to be processed
+    :param all_json: Flag indicating if the output should be formatted as JSON
+    :param indent: The indentation level for JSON output
+    :returns: Processed chunk, optionally formatted as JSON
+    """
+    if all_json:
+        if chunk.startswith("Calling"):
+            resp = {"type": "status", "content": chunk}
+        else:
+            resp = {"type": "message", "content": chunk}
+        return json.dumps(resp, ensure_ascii=False, indent=indent) + "\n"
+    return chunk
+
+
+def handle_stream(stream, all_json=False, json_pp: bool = False):
+    """
+    Handles and processes the data stream.
     Returns a json object in the format
@@ -24,92 +42,88 @@ def handle_stream(stream, all_json=False, json_pp: bool = False):
     status are notifications from tool calls. message is a message with content for the user, history is the full
     message history. The history needs to be passed in with your next call to allow the assistant to refer to earlier
-    messages.
-    :param stream: The input stream.
-    :param all_json: If true, return every chunk as JSON object. Returns only the content if false.
-    :param json_pp: If true pretty prints json. Mainly for debugging purposes, since it breaks the newline separation
-    on clients.
-    :return: The input stream, enriched with
+    messages.
+    :param stream: data stream to be processed
+    :param all_json: whether to process all data as JSON
+    :param json_pp: whether to pretty-print the JSON output
     """
-    if json_pp:
-        indent = 4
-    else:
-        indent = None
+    indent = 4 if json_pp else None

     for chunk in stream:
         if chunk is None:
             continue
 
         if type(chunk) is str:
-            if all_json:
-                if chunk.startswith("Calling"):
-                    resp = {
-                        "type": "status",
-                        "content": chunk
-                    }
-                    yield json.dumps(resp, ensure_ascii=False, indent=indent) + "\n"
-                else:
-                    resp = {
-                        "type": "message",
-                        "content": chunk
-                    }
-                    yield json.dumps(resp, ensure_ascii=False, indent=indent) + "\n"
-            else:
-                yield chunk
-
-        elif type(chunk) is dict or type(chunk) is list:
-            resp = {
-                "type": "history",
-                "content": chunk
-            }
+            yield process_chunk(chunk, all_json, indent)
+        elif isinstance(chunk, (dict, list)):
+            resp = {"type": "history", "content": chunk}
             yield json.dumps(resp, ensure_ascii=False, indent=indent) + "\n"


+def handle_context_window_error(client, prompt, messages, model):
+    """
+    Handles context window errors by truncating the messages list.
+    :param client: OpenAI client instance to be used
+    :param prompt: The prompt to be sent to the OpenAI model
+    :param messages: List of messages as input for context
+    :param model: The OpenAI model to be called
+    :returns: The result from the OpenAI model call
+    """
+    for i in range(1, len(messages) - 1):
+        messages = [messages[0], messages[-1]]
+        messages.pop(i)
+    return call_openai(client, prompt, messages, model)
+
+
+def process_finish_reason(finish_reason: str, chunk_content: str):
+    """
+    Processes the finish reason and raises exceptions for unsupported reasons.
+    :param finish_reason: The reason the process finished
+    :param chunk_content: The content related to the chunk being processed
+    :raises NotImplementedError: Tool Calls or unhandled finish reasons
+    """
+    if finish_reason == "tool_calls":
+        raise NotImplementedError('Tool Calls are not supported in this application.')
+    elif finish_reason in ["stop", "tool_calls", None]:
+        return
+    else:
+        raise NotImplementedError(f"Unhandled finish reason: {finish_reason}.")


 def tool_chain(client, prompt, messages, model: OpenAiModel = OpenAiModel.gpt35turbo):
     """
     Handles the prompt to the LLM and calls the necessary tools if the LLM decides to use one.
     :param client: The client session for OpenAI
     :param prompt: The prompt or query to the LLM
     :param messages: The message history without the current prompt.
     :param tool_choice: Whether to use tools. See documentation of ToolChoice for further information.
     :param model: The LLM model to use
     :return: The LLM's answer as stream.
     """
 
     tool_call_counter = 0
     finish_reason = None
     content = ""

     while finish_reason != "stop" and tool_call_counter < LOOP_THRESHOLD:
         try:
             response = call_openai(client, prompt, messages, model)
         except ContextWindowFullError:
-            for i in range(1, len(messages) - 1):
-                messages = [messages[0], messages[-1]]
-                messages.pop(i)
-
-            response = call_openai(client, prompt, messages, model)
+            response = handle_context_window_error(client, prompt, messages, model)

         for chunk in response:
             finish_reason = chunk.finish_reason
             yield chunk.delta.content
             content += str(chunk.delta.content)
 
-        if finish_reason == "tool_calls":
-            raise NotImplementedError('Tool Calls are not supported in this application.')
-
-        if finish_reason in ["stop", "tool_calls", None]:
-            continue
-        else:
-            raise NotImplementedError(f"Unhandled finish reason: {finish_reason}.")
+        process_finish_reason(finish_reason, chunk.delta.content)

     messages.append(
         {
             'role': 'assistant',
             'content': content
         }
     )
 
     yield messages
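
A usage note, not part of the diff: per the `handle_stream` docstring above, the function emits newline-delimited JSON objects of type `status` (tool-call notifications), `message` (content for the user) and `history` (the full message history, to be passed back with the next call). A minimal client-side sketch of that contract follows; the helper name `consume_stream` and the line-based input are illustrative assumptions, not code from this repository.

```python
import json


def consume_stream(lines, history=None):
    """Hypothetical consumer for the NDJSON emitted by handle_stream(all_json=True)."""
    history = history or []
    for line in lines:
        if not line.strip():
            continue
        event = json.loads(line)
        if event["type"] == "status":
            # Notification from a tool call, e.g. "Calling ..."
            print(f"[{event['content']}]")
        elif event["type"] == "message":
            # Content intended for the user
            print(event["content"], end="")
        elif event["type"] == "history":
            # Full message history; send this back with the next request
            history = event["content"]
    return history
```

Note that this one-object-per-line parsing only works with `json_pp=False`, which matches the caveat in the original docstring that pretty-printing breaks the newline separation on clients.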


