Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 34 additions & 26 deletions codex-rs/app-server/src/bespoke_event_handling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::codex_message_processor::read_summary_from_rollout;
use crate::codex_message_processor::summary_to_thread;
use crate::error_code::INTERNAL_ERROR_CODE;
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
use crate::outgoing_message::OutgoingMessageSender;
use crate::outgoing_message::ThreadScopedOutgoingMessageSender;
use crate::thread_state::ThreadState;
use crate::thread_state::TurnSummary;
use codex_app_server_protocol::AccountRateLimitsUpdatedNotification;
Expand Down Expand Up @@ -107,7 +107,7 @@ pub(crate) async fn apply_bespoke_event_handling(
event: Event,
conversation_id: ThreadId,
conversation: Arc<CodexThread>,
outgoing: Arc<OutgoingMessageSender>,
outgoing: ThreadScopedOutgoingMessageSender,
thread_state: Arc<tokio::sync::Mutex<ThreadState>>,
api_version: ApiVersion,
fallback_model_provider: String,
Expand Down Expand Up @@ -850,7 +850,7 @@ pub(crate) async fn apply_bespoke_event_handling(
conversation_id,
&event_turn_id,
raw_response_item_event.item,
outgoing.as_ref(),
&outgoing,
)
.await;
}
Expand Down Expand Up @@ -899,7 +899,7 @@ pub(crate) async fn apply_bespoke_event_handling(
changes,
status,
event_turn_id.clone(),
outgoing.as_ref(),
&outgoing,
&thread_state,
)
.await;
Expand Down Expand Up @@ -1142,7 +1142,7 @@ pub(crate) async fn apply_bespoke_event_handling(
&event_turn_id,
turn_diff_event,
api_version,
outgoing.as_ref(),
&outgoing,
)
.await;
}
Expand All @@ -1152,7 +1152,7 @@ pub(crate) async fn apply_bespoke_event_handling(
&event_turn_id,
plan_update_event,
api_version,
outgoing.as_ref(),
&outgoing,
)
.await;
}
Expand All @@ -1166,7 +1166,7 @@ async fn handle_turn_diff(
event_turn_id: &str,
turn_diff_event: TurnDiffEvent,
api_version: ApiVersion,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
) {
if let ApiVersion::V2 = api_version {
let notification = TurnDiffUpdatedNotification {
Expand All @@ -1185,7 +1185,7 @@ async fn handle_turn_plan_update(
event_turn_id: &str,
plan_update_event: UpdatePlanArgs,
api_version: ApiVersion,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
) {
// `update_plan` is a todo/checklist tool; it is not related to plan-mode updates
if let ApiVersion::V2 = api_version {
Expand All @@ -1210,7 +1210,7 @@ async fn emit_turn_completed_with_status(
event_turn_id: String,
status: TurnStatus,
error: Option<TurnError>,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
) {
let notification = TurnCompletedNotification {
thread_id: conversation_id.to_string(),
Expand All @@ -1232,7 +1232,7 @@ async fn complete_file_change_item(
changes: Vec<FileUpdateChange>,
status: PatchApplyStatus,
turn_id: String,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
thread_state: &Arc<Mutex<ThreadState>>,
) {
let mut state = thread_state.lock().await;
Expand Down Expand Up @@ -1264,7 +1264,7 @@ async fn complete_command_execution_item(
process_id: Option<String>,
command_actions: Vec<V2ParsedCommand>,
status: CommandExecutionStatus,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
) {
let item = ThreadItem::CommandExecution {
id: item_id,
Expand Down Expand Up @@ -1292,7 +1292,7 @@ async fn maybe_emit_raw_response_item_completed(
conversation_id: ThreadId,
turn_id: &str,
item: codex_protocol::models::ResponseItem,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
) {
let ApiVersion::V2 = api_version else {
return;
Expand All @@ -1319,7 +1319,7 @@ async fn find_and_remove_turn_summary(
async fn handle_turn_complete(
conversation_id: ThreadId,
event_turn_id: String,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
thread_state: &Arc<Mutex<ThreadState>>,
) {
let turn_summary = find_and_remove_turn_summary(conversation_id, thread_state).await;
Expand All @@ -1335,7 +1335,7 @@ async fn handle_turn_complete(
async fn handle_turn_interrupted(
conversation_id: ThreadId,
event_turn_id: String,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
thread_state: &Arc<Mutex<ThreadState>>,
) {
find_and_remove_turn_summary(conversation_id, thread_state).await;
Expand All @@ -1354,7 +1354,7 @@ async fn handle_thread_rollback_failed(
_conversation_id: ThreadId,
message: String,
thread_state: &Arc<Mutex<ThreadState>>,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
) {
let pending_rollback = thread_state.lock().await.pending_rollbacks.take();

Expand All @@ -1376,7 +1376,7 @@ async fn handle_token_count_event(
conversation_id: ThreadId,
turn_id: String,
token_count_event: TokenCountEvent,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
) {
let TokenCountEvent { info, rate_limits } = token_count_event;
if let Some(token_usage) = info.map(ThreadTokenUsage::from) {
Expand Down Expand Up @@ -1633,7 +1633,7 @@ async fn on_file_change_request_approval_response(
changes: Vec<FileUpdateChange>,
receiver: oneshot::Receiver<JsonValue>,
codex: Arc<CodexThread>,
outgoing: Arc<OutgoingMessageSender>,
outgoing: ThreadScopedOutgoingMessageSender,
thread_state: Arc<Mutex<ThreadState>>,
) {
let response = receiver.await;
Expand Down Expand Up @@ -1666,7 +1666,7 @@ async fn on_file_change_request_approval_response(
changes,
status,
event_turn_id.clone(),
outgoing.as_ref(),
&outgoing,
&thread_state,
)
.await;
Expand All @@ -1693,7 +1693,7 @@ async fn on_command_execution_request_approval_response(
command_actions: Vec<V2ParsedCommand>,
receiver: oneshot::Receiver<JsonValue>,
conversation: Arc<CodexThread>,
outgoing: Arc<OutgoingMessageSender>,
outgoing: ThreadScopedOutgoingMessageSender,
) {
let response = receiver.await;
let (decision, completion_status) = match response {
Expand Down Expand Up @@ -1748,7 +1748,7 @@ async fn on_command_execution_request_approval_response(
None,
command_actions.clone(),
status,
outgoing.as_ref(),
&outgoing,
)
.await;
}
Expand Down Expand Up @@ -1876,6 +1876,7 @@ async fn construct_mcp_tool_call_end_notification(
mod tests {
use super::*;
use crate::CHANNEL_CAPACITY;
use crate::outgoing_message::ConnectionId;
use crate::outgoing_message::OutgoingEnvelope;
use crate::outgoing_message::OutgoingMessage;
use crate::outgoing_message::OutgoingMessageSender;
Expand Down Expand Up @@ -1914,9 +1915,7 @@ mod tests {
.ok_or_else(|| anyhow!("should send one message"))?;
match envelope {
OutgoingEnvelope::Broadcast { message } => Ok(message),
OutgoingEnvelope::ToConnection { connection_id, .. } => {
bail!("unexpected targeted message for connection {connection_id:?}")
}
OutgoingEnvelope::ToConnection { message, .. } => Ok(message),
}
}

Expand Down Expand Up @@ -2011,6 +2010,7 @@ mod tests {
let event_turn_id = "complete1".to_string();
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
let thread_state = new_thread_state();

handle_turn_complete(
Expand Down Expand Up @@ -2051,6 +2051,7 @@ mod tests {
.await;
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);

handle_turn_interrupted(
conversation_id,
Expand Down Expand Up @@ -2090,6 +2091,7 @@ mod tests {
.await;
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);

handle_turn_complete(
conversation_id,
Expand Down Expand Up @@ -2122,7 +2124,8 @@ mod tests {
#[tokio::test]
async fn test_handle_turn_plan_update_emits_notification_for_v2() -> Result<()> {
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = OutgoingMessageSender::new(tx);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
let update = UpdatePlanArgs {
explanation: Some("need plan".to_string()),
plan: vec![
Expand Down Expand Up @@ -2172,6 +2175,7 @@ mod tests {
let turn_id = "turn-123".to_string();
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);

let info = TokenUsageInfo {
total_token_usage: TokenUsage {
Expand Down Expand Up @@ -2255,6 +2259,7 @@ mod tests {
let turn_id = "turn-456".to_string();
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);

handle_token_count_event(
conversation_id,
Expand Down Expand Up @@ -2321,6 +2326,7 @@ mod tests {

let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);

// Turn 1 on conversation A
let a_turn1 = "a_turn1".to_string();
Expand Down Expand Up @@ -2542,7 +2548,8 @@ mod tests {
#[tokio::test]
async fn test_handle_turn_diff_emits_v2_notification() -> Result<()> {
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = OutgoingMessageSender::new(tx);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
let unified_diff = "--- a\n+++ b\n".to_string();
let conversation_id = ThreadId::new();

Expand Down Expand Up @@ -2575,7 +2582,8 @@ mod tests {
#[tokio::test]
async fn test_handle_turn_diff_is_noop_for_v1() -> Result<()> {
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = OutgoingMessageSender::new(tx);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
let conversation_id = ThreadId::new();

handle_turn_diff(
Expand Down
Loading
Loading