Skip to content

Commit ca2c197

Browse files
committed
fix: chat completions API now also passes tools along
1 parent e207f20 commit ca2c197

File tree

5 files changed

+413
-160
lines changed

codex-rs/core/src/chat_completions.rs

Lines changed: 202 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -25,10 +25,10 @@ use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
2525
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
2626
use crate::models::ContentItem;
2727
use crate::models::ResponseItem;
28+
use crate::openai_tools::create_tools_json;
2829
use crate::util::backoff;
2930

30-
/// Implementation for the classic Chat Completions API. This is intentionally
31-
/// minimal: we only stream back plain assistant text.
31+
/// Implementation for the classic Chat Completions API.
3232
pub(crate) async fn stream_chat_completions(
3333
prompt: &Prompt,
3434
model: &str,
@@ -42,31 +42,111 @@ pub(crate) async fn stream_chat_completions(
4242
messages.push(json!({"role": "system", "content": full_instructions}));
4343

4444
for item in &prompt.input {
45-
if let ResponseItem::Message { role, content } = item {
46-
let mut text = String::new();
47-
for c in content {
48-
match c {
49-
ContentItem::InputText { text: t } | ContentItem::OutputText { text: t } => {
50-
text.push_str(t);
45+
match item {
46+
ResponseItem::Message { role, content } => {
47+
let mut text = String::new();
48+
for c in content {
49+
match c {
50+
ContentItem::InputText { text: t }
51+
| ContentItem::OutputText { text: t } => {
52+
text.push_str(t);
53+
}
54+
_ => {}
5155
}
52-
_ => {}
5356
}
57+
messages.push(json!({"role": role, "content": text}));
58+
}
59+
ResponseItem::FunctionCall {
60+
name,
61+
arguments,
62+
call_id,
63+
} => {
64+
messages.push(json!({
65+
"role": "assistant",
66+
"content": null,
67+
"tool_calls": [{
68+
"id": call_id,
69+
"type": "function",
70+
"function": {
71+
"name": name,
72+
"arguments": arguments,
73+
}
74+
}]
75+
}));
76+
}
77+
ResponseItem::LocalShellCall {
78+
id,
79+
call_id: _,
80+
status,
81+
action,
82+
} => {
83+
// Confirm with API team.
84+
messages.push(json!({
85+
"role": "assistant",
86+
"content": null,
87+
"tool_calls": [{
88+
"id": id.clone().unwrap_or_else(|| "".to_string()),
89+
"type": "local_shell_call",
90+
"status": status,
91+
"action": action,
92+
}]
93+
}));
94+
}
95+
ResponseItem::FunctionCallOutput { call_id, output } => {
96+
messages.push(json!({
97+
"role": "tool",
98+
"tool_call_id": call_id,
99+
"content": output.content,
100+
}));
101+
}
102+
ResponseItem::Reasoning { .. } | ResponseItem::Other => {
103+
// Omit these items from the conversation history.
104+
continue;
54105
}
55-
messages.push(json!({"role": role, "content": text}));
56106
}
57107
}
58108

109+
let tools_json = create_tools_json(prompt, model)?;
110+
// create_tools_json() returns JSON values that are compatible with
111+
// Function Calling in the Responses API:
112+
// https://platform.openai.com/docs/guides/function-calling?api-mode=responses
113+
// So we must rewrite "tools" to match the chat completions tool call format:
114+
// https://platform.openai.com/docs/guides/function-calling?api-mode=chat
115+
let tools_json = tools_json
116+
.into_iter()
117+
.filter_map(|mut tool| {
118+
if tool.get("type") != Some(&serde_json::Value::String("function".to_string())) {
119+
return None;
120+
}
121+
122+
if let Some(map) = tool.as_object_mut() {
123+
// Remove "type" field as it is not needed in chat completions.
124+
map.remove("type");
125+
Some(json!({
126+
"type": "function",
127+
"function": map,
128+
}))
129+
} else {
130+
None
131+
}
132+
})
133+
.collect::<Vec<serde_json::Value>>();
134+
59135
let payload = json!({
60136
"model": model,
61137
"messages": messages,
62-
"stream": true
138+
"stream": true,
139+
"tools": tools_json,
63140
});
64141

65142
let base_url = provider.base_url.trim_end_matches('/');
66143
let url = format!("{}/chat/completions", base_url);
67144

68145
debug!(url, "POST (chat)");
69-
trace!("request payload: {}", payload);
146+
trace!(
147+
"request payload: {}",
148+
serde_json::to_string_pretty(&payload).unwrap_or_default()
149+
);
70150

71151
let api_key = provider.api_key()?;
72152
let mut attempt = 0;
@@ -134,6 +214,21 @@ where
134214

135215
let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;
136216

217+
// State to accumulate a function call across streaming chunks.
218+
// OpenAI may split the `arguments` string over multiple `delta` events
219+
// until the chunk whose `finish_reason` is `tool_calls` is emitted. We
220+
// keep collecting the pieces here and forward a single
221+
// `ResponseItem::FunctionCall` once the call is complete.
222+
#[derive(Default)]
223+
struct FunctionCallState {
224+
name: Option<String>,
225+
arguments: String,
226+
call_id: Option<String>,
227+
active: bool,
228+
}
229+
230+
let mut fn_call_state = FunctionCallState::default();
231+
137232
loop {
138233
let sse = match timeout(idle_timeout, stream.next()).await {
139234
Ok(Some(Ok(ev))) => ev,
@@ -173,23 +268,89 @@ where
173268
Ok(v) => v,
174269
Err(_) => continue,
175270
};
271+
trace!("chat_completions received SSE chunk: {chunk:?}");
272+
273+
let choice_opt = chunk.get("choices").and_then(|c| c.get(0));
274+
275+
if let Some(choice) = choice_opt {
276+
// Handle assistant content tokens.
277+
if let Some(content) = choice
278+
.get("delta")
279+
.and_then(|d| d.get("content"))
280+
.and_then(|c| c.as_str())
281+
{
282+
let item = ResponseItem::Message {
283+
role: "assistant".to_string(),
284+
content: vec![ContentItem::OutputText {
285+
text: content.to_string(),
286+
}],
287+
};
288+
289+
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
290+
}
176291

177-
let content_opt = chunk
178-
.get("choices")
179-
.and_then(|c| c.get(0))
180-
.and_then(|c| c.get("delta"))
181-
.and_then(|d| d.get("content"))
182-
.and_then(|c| c.as_str());
183-
184-
if let Some(content) = content_opt {
185-
let item = ResponseItem::Message {
186-
role: "assistant".to_string(),
187-
content: vec![ContentItem::OutputText {
188-
text: content.to_string(),
189-
}],
190-
};
191-
192-
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
292+
// Handle streaming function / tool calls.
293+
if let Some(tool_calls) = choice
294+
.get("delta")
295+
.and_then(|d| d.get("tool_calls"))
296+
.and_then(|tc| tc.as_array())
297+
{
298+
if let Some(tool_call) = tool_calls.first() {
299+
// Mark that we have an active function call in progress.
300+
fn_call_state.active = true;
301+
302+
// Extract call_id if present.
303+
if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) {
304+
fn_call_state.call_id.get_or_insert_with(|| id.to_string());
305+
}
306+
307+
// Extract function details if present.
308+
if let Some(function) = tool_call.get("function") {
309+
if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
310+
fn_call_state.name.get_or_insert_with(|| name.to_string());
311+
}
312+
313+
if let Some(args_fragment) =
314+
function.get("arguments").and_then(|a| a.as_str())
315+
{
316+
fn_call_state.arguments.push_str(args_fragment);
317+
}
318+
}
319+
}
320+
}
321+
322+
// Emit end-of-turn when finish_reason signals completion.
323+
if let Some(finish_reason) = choice.get("finish_reason").and_then(|v| v.as_str()) {
324+
match finish_reason {
325+
"tool_calls" if fn_call_state.active => {
326+
// Build the FunctionCall response item.
327+
let item = ResponseItem::FunctionCall {
328+
name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()),
329+
arguments: fn_call_state.arguments.clone(),
330+
call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new),
331+
};
332+
333+
// Emit it downstream.
334+
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
335+
}
336+
"stop" => {
337+
// Regular turn without tool-call.
338+
}
339+
_ => {}
340+
}
341+
342+
// Emit Completed regardless of reason so the agent can advance.
343+
let _ = tx_event
344+
.send(Ok(ResponseEvent::Completed {
345+
response_id: String::new(),
346+
}))
347+
.await;
348+
349+
// Prepare for potential next turn (should not happen in same stream).
350+
// fn_call_state = FunctionCallState::default();
351+
352+
return; // End processing for this SSE stream.
353+
}
193354
}
194355
}
195356
}
@@ -236,20 +397,28 @@ where
236397
Poll::Ready(None) => return Poll::Ready(None),
237398
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
238399
Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
239-
// Accumulate *assistant* text but do not emit yet.
240-
if let crate::models::ResponseItem::Message { role, content } = &item {
241-
if role == "assistant" {
400+
// If this is an incremental assistant message chunk, accumulate but
401+
// do NOT emit yet. Forward any other item (e.g. FunctionCall) right
402+
// away so downstream consumers see it.
403+
404+
let is_assistant_delta = matches!(&item, crate::models::ResponseItem::Message { role, .. } if role == "assistant");
405+
406+
if is_assistant_delta {
407+
if let crate::models::ResponseItem::Message { content, .. } = &item {
242408
if let Some(text) = content.iter().find_map(|c| match c {
243409
crate::models::ContentItem::OutputText { text } => Some(text),
244410
_ => None,
245411
}) {
246412
this.cumulative.push_str(text);
247413
}
248414
}
415+
416+
// Swallow partial assistant chunk; keep polling.
417+
continue;
249418
}
250419

251-
// Swallow partial event; keep polling.
252-
continue;
420+
// Not an assistant message – forward immediately.
421+
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
253422
}
254423
Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id }))) => {
255424
if !this.cumulative.is_empty() {

0 commit comments

Comments (0)