Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 65 additions & 2 deletions src/exchange/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
tools_to_openai_spec,
)
from exchange.tool import Tool
from exchange.content import ToolUse, ToolResult, Text


OPENAI_HOST = "https://api.openai.com/"

Expand Down Expand Up @@ -64,7 +66,7 @@ def complete(
messages: List[Message],
tools: Tuple[Tool],
**kwargs: Dict[str, Any],
) -> Tuple[Message, Usage]:
) -> Tuple[Message, Usage]:
payload = dict(
messages=[
{"role": "system", "content": system},
Expand All @@ -77,16 +79,77 @@ def complete(
payload = {k: v for k, v in payload.items() if v}
response = self._send_request(payload)



# Check for context_length_exceeded error for single, long input message
if "error" in response.json() and len(messages) == 1:
openai_single_message_context_length_exceeded(response.json()["error"])

data = raise_for_status(response).json()

message = openai_response_to_message(data)

# optionally use the reasoning model to get a better answer if tool usage isn't required.
resoning_model = self.get_reasoning_model()
if resoning_model and len(self.tool_use(message)) == 0: # if a tool is needed we let things continue on without extra reasoning.
# will limit its invocation to non trivial things for now
filtered_messages = self.messages_filtered(messages)
latest_message = filtered_messages[-1]
latest_completion = data["choices"][0]["message"]["content"]
if len(latest_message['content']) > 50 and len(latest_completion) > 100:
print("---> using deep reasoning")
payload = dict(
messages=[
*filtered_messages,
{"role": "user", "content": "TASK: please check the answer that follows, if it is ok then return it, otherwise rewrite it and return it:" + latest_completion},
],
model=resoning_model,
**kwargs,
)
payload = {k: v for k, v in payload.items() if v}
response = self._send_request(payload)

# Check for context_length_exceeded error for single, long input message
if "error" in response.json() and len(messages) == 1:
openai_single_message_context_length_exceeded(response.json()["error"])

data = raise_for_status(response).json()
message = openai_response_to_message(data)

usage = self.get_usage(data)
return message, usage

def tool_use(self, message):
    """Return the tool-related content items of *message*.

    Collects every content entry that is a ToolUse or ToolResult; an empty
    list therefore means the model's reply requires no tool invocation.
    """
    # isinstance accepts a tuple of types — one check instead of an `or` chain.
    return [content for content in message.content if isinstance(content, (ToolUse, ToolResult))]

@retry_httpx_request()
def _send_request(self, payload: Any) -> httpx.Response: # noqa: ANN401
    """POST the completion payload to the chat-completions endpoint.

    Decorated with retry_httpx_request, so transient HTTP failures are
    retried; returns the raw httpx.Response (status not checked here —
    callers inspect/raise on it themselves).
    """
    # Relative URL: resolves against the base_url configured on self.client.
    return self.client.post("v1/chat/completions", json=payload)


def messages_filtered(self, messages: List[Message]) -> List[Dict[str, Any]]:
"""This is for models that don't handle tool call output or images directly"""
messages_spec = []
for message in messages:
converted = {"role": "user"}
output = []
for content in message.content:
if isinstance(content, Text):
converted["content"] = content.text
elif isinstance(content, ToolResult):
output.append(
{
"role": "user",
"content": content.output,
}
)

if "content" in converted or "tool_calls" in converted:
output = [converted] + output
messages_spec.extend(output)
return messages_spec

def get_reasoning_model(self):
    """Return the reasoning-model name from the OPENAI_REASONING env var.

    Set the variable to e.g. "o1-mini" (code tasks) or "o1-preview"
    (general tasks). Returns None when unset, which disables the extra
    reasoning pass.
    """
    # dict.get already defaults to None; the explicit default was redundant.
    return os.environ.get("OPENAI_REASONING")