diff --git a/src/exchange/providers/openai.py b/src/exchange/providers/openai.py
index 1f3133b..6339a4b 100644
--- a/src/exchange/providers/openai.py
+++ b/src/exchange/providers/openai.py
@@ -14,6 +14,9 @@
     tools_to_openai_spec,
 )
 from exchange.tool import Tool
+from exchange.content import ToolUse, ToolResult, Text
+
+import os
 
 OPENAI_HOST = "https://api.openai.com/"
 
@@ -77,14 +80,78 @@ def complete(
         payload = {k: v for k, v in payload.items() if v}
 
         response = self._send_request(payload)
+
+        # Check for context_length_exceeded error for single, long input message
         if "error" in response.json() and len(messages) == 1:
             openai_single_message_context_length_exceeded(response.json()["error"])
 
         data = raise_for_status(response).json()
         message = openai_response_to_message(data)
+
+        # Optionally use the reasoning model to get a better answer when the
+        # draft reply does not ask for tool usage; if a tool is needed we let
+        # things continue on without extra reasoning.
+        reasoning_model = self.get_reasoning_model()
+        if reasoning_model and not self.tool_use(message):
+            # Limit the extra round-trip to non-trivial prompts/answers for now.
+            filtered_messages = self.messages_filtered(messages)
+            latest_message = filtered_messages[-1]
+            latest_completion = data["choices"][0]["message"]["content"]
+            if len(latest_message["content"]) > 50 and len(latest_completion) > 100:
+                print("---> using deep reasoning")
+                payload = dict(
+                    messages=[
+                        *filtered_messages,
+                        {"role": "user", "content": "TASK: please check the answer that follows, if it is ok then return it, otherwise rewrite it and return it:" + latest_completion},
+                    ],
+                    model=reasoning_model,
+                    **kwargs,
+                )
+                payload = {k: v for k, v in payload.items() if v}
+
+                response = self._send_request(payload)
+
+                # Check for context_length_exceeded error for single, long input message
+                if "error" in response.json() and len(messages) == 1:
+                    openai_single_message_context_length_exceeded(response.json()["error"])
+
+                data = raise_for_status(response).json()
+                # Replace the draft answer with the reasoned one.
+                message = openai_response_to_message(data)
+
         usage = self.get_usage(data)
         return message, usage
 
+    def tool_use(self, message):
+        """Return the ToolUse/ToolResult content items of ``message`` (empty list if none)."""
+        return [content for content in message.content if isinstance(content, (ToolUse, ToolResult))]
+
     @retry_httpx_request()
     def _send_request(self, payload: Any) -> httpx.Response:  # noqa: ANN401
         return self.client.post("v1/chat/completions", json=payload)
+
+    def messages_filtered(self, messages: List[Message]) -> List[Dict[str, Any]]:
+        """Flatten messages into plain user turns, for models that don't handle
+        tool call output or images directly."""
+        messages_spec = []
+        for message in messages:
+            converted = {"role": "user"}
+            output = []
+            for content in message.content:
+                if isinstance(content, Text):
+                    converted["content"] = content.text
+                elif isinstance(content, ToolResult):
+                    output.append(
+                        {
+                            "role": "user",
+                            "content": content.output,
+                        }
+                    )
+            if "content" in converted or "tool_calls" in converted:
+                output = [converted] + output
+            messages_spec.extend(output)
+        return messages_spec
+
+    def get_reasoning_model(self):
+        """Model name from OPENAI_REASONING (set to o1-mini or o1-preview
+        depending if you are doing code or not); None disables reasoning."""
+        return os.environ.get("OPENAI_REASONING", None)