Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 171 additions & 34 deletions codex-rs/core/src/chat_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ use crate::models::ResponseItem;
use crate::openai_tools::create_tools_json_for_chat_completions_api;
use crate::util::backoff;

/// Implementation for the classic Chat Completions API. This is intentionally
/// minimal: we only stream back plain assistant text.
/// Implementation for the classic Chat Completions API.
pub(crate) async fn stream_chat_completions(
prompt: &Prompt,
model: &str,
Expand All @@ -43,17 +42,67 @@ pub(crate) async fn stream_chat_completions(
messages.push(json!({"role": "system", "content": full_instructions}));

for item in &prompt.input {
if let ResponseItem::Message { role, content } = item {
let mut text = String::new();
for c in content {
match c {
ContentItem::InputText { text: t } | ContentItem::OutputText { text: t } => {
text.push_str(t);
match item {
ResponseItem::Message { role, content } => {
let mut text = String::new();
for c in content {
match c {
ContentItem::InputText { text: t }
| ContentItem::OutputText { text: t } => {
text.push_str(t);
}
_ => {}
}
_ => {}
}
messages.push(json!({"role": role, "content": text}));
}
ResponseItem::FunctionCall {
name,
arguments,
call_id,
} => {
messages.push(json!({
"role": "assistant",
"content": null,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments,
}
}]
}));
}
ResponseItem::LocalShellCall {
id,
call_id: _,
status,
action,
} => {
// Confirm with API team.
messages.push(json!({
"role": "assistant",
"content": null,
"tool_calls": [{
"id": id.clone().unwrap_or_else(|| "".to_string()),
"type": "local_shell_call",
"status": status,
"action": action,
}]
}));
}
ResponseItem::FunctionCallOutput { call_id, output } => {
messages.push(json!({
"role": "tool",
"tool_call_id": call_id,
"content": output.content,
}));
}
ResponseItem::Reasoning { .. } | ResponseItem::Other => {
// Omit these items from the conversation history.
continue;
}
messages.push(json!({"role": role, "content": text}));
}
}

Expand All @@ -68,9 +117,8 @@ pub(crate) async fn stream_chat_completions(
let base_url = provider.base_url.trim_end_matches('/');
let url = format!("{}/chat/completions", base_url);

debug!(url, "POST (chat)");
trace!(
"request payload: {}",
debug!(
"POST to {url}: {}",
serde_json::to_string_pretty(&payload).unwrap_or_default()
);

Expand Down Expand Up @@ -140,6 +188,21 @@ where

let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;

// State to accumulate a function call across streaming chunks.
// OpenAI may split the `arguments` string over multiple `delta` events
// until the chunk whose `finish_reason` is `tool_calls` is emitted. We
// keep collecting the pieces here and forward a single
// `ResponseItem::FunctionCall` once the call is complete.
#[derive(Default)]
struct FunctionCallState {
name: Option<String>,
arguments: String,
call_id: Option<String>,
active: bool,
}

let mut fn_call_state = FunctionCallState::default();

loop {
let sse = match timeout(idle_timeout, stream.next()).await {
Ok(Some(Ok(ev))) => ev,
Expand Down Expand Up @@ -179,23 +242,89 @@ where
Ok(v) => v,
Err(_) => continue,
};
trace!("chat_completions received SSE chunk: {chunk:?}");

let choice_opt = chunk.get("choices").and_then(|c| c.get(0));

if let Some(choice) = choice_opt {
// Handle assistant content tokens.
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
{
let item = ResponseItem::Message {
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: content.to_string(),
}],
};

let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}

// Handle streaming function / tool calls.
if let Some(tool_calls) = choice
.get("delta")
.and_then(|d| d.get("tool_calls"))
.and_then(|tc| tc.as_array())
{
if let Some(tool_call) = tool_calls.first() {
// Mark that we have an active function call in progress.
fn_call_state.active = true;

// Extract call_id if present.
if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) {
fn_call_state.call_id.get_or_insert_with(|| id.to_string());
}

// Extract function details if present.
if let Some(function) = tool_call.get("function") {
if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
fn_call_state.name.get_or_insert_with(|| name.to_string());
}

if let Some(args_fragment) =
function.get("arguments").and_then(|a| a.as_str())
{
fn_call_state.arguments.push_str(args_fragment);
}
}
}
}

// Emit end-of-turn when finish_reason signals completion.
if let Some(finish_reason) = choice.get("finish_reason").and_then(|v| v.as_str()) {
match finish_reason {
"tool_calls" if fn_call_state.active => {
// Build the FunctionCall response item.
let item = ResponseItem::FunctionCall {
name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()),
arguments: fn_call_state.arguments.clone(),
call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new),
};

// Emit it downstream.
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
"stop" => {
// Regular turn without tool-call.
}
_ => {}
}

let content_opt = chunk
.get("choices")
.and_then(|c| c.get(0))
.and_then(|c| c.get("delta"))
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str());

if let Some(content) = content_opt {
let item = ResponseItem::Message {
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: content.to_string(),
}],
};

let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
// Emit Completed regardless of reason so the agent can advance.
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
}))
.await;

// Prepare for potential next turn (should not happen in same stream).
// fn_call_state = FunctionCallState::default();

return; // End processing for this SSE stream.
}
}
}
}
Expand Down Expand Up @@ -242,20 +371,28 @@ where
Poll::Ready(None) => return Poll::Ready(None),
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
// Accumulate *assistant* text but do not emit yet.
if let crate::models::ResponseItem::Message { role, content } = &item {
if role == "assistant" {
// If this is an incremental assistant message chunk, accumulate but
// do NOT emit yet. Forward any other item (e.g. FunctionCall) right
// away so downstream consumers see it.

let is_assistant_delta = matches!(&item, crate::models::ResponseItem::Message { role, .. } if role == "assistant");

if is_assistant_delta {
if let crate::models::ResponseItem::Message { content, .. } = &item {
if let Some(text) = content.iter().find_map(|c| match c {
crate::models::ContentItem::OutputText { text } => Some(text),
_ => None,
}) {
this.cumulative.push_str(text);
}
}

// Swallow partial assistant chunk; keep polling.
continue;
}

// Swallow partial event; keep polling.
continue;
// Not an assistant message – forward immediately.
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
}
Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id }))) => {
if !this.cumulative.is_empty() {
Expand Down
Loading