diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index d597d5ef04b..3b3e7f6807f 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -4010,7 +4010,11 @@ pub(crate) async fn run_turn( } // Construct the input that we will send to the model. - let sampling_request_input: Vec<ResponseItem> = { sess.clone_history().await.for_prompt() }; + let sampling_request_input: Vec<ResponseItem> = { + sess.clone_history() + .await + .for_prompt(&turn_context.model_info.input_modalities) + }; let sampling_request_input_messages = sampling_request_input .iter() @@ -6923,7 +6927,9 @@ mod tests { rollout_items.push(RolloutItem::ResponseItem(assistant1.clone())); let summary1 = "summary one"; - let snapshot1 = live_history.clone().for_prompt(); + let snapshot1 = live_history + .clone() + .for_prompt(&reconstruction_turn.model_info.input_modalities); let user_messages1 = collect_user_messages(&snapshot1); let rebuilt1 = compact::build_compacted_history(initial_context.clone(), &user_messages1, summary1); @@ -6964,7 +6970,9 @@ mod tests { rollout_items.push(RolloutItem::ResponseItem(assistant2.clone())); let summary2 = "summary two"; - let snapshot2 = live_history.clone().for_prompt(); + let snapshot2 = live_history + .clone() + .for_prompt(&reconstruction_turn.model_info.input_modalities); let user_messages2 = collect_user_messages(&snapshot2); let rebuilt2 = compact::build_compacted_history(initial_context.clone(), &user_messages2, summary2); @@ -7004,7 +7012,10 @@ mod tests { ); rollout_items.push(RolloutItem::ResponseItem(assistant3)); - (rollout_items, live_history.for_prompt()) + ( + rollout_items, + live_history.for_prompt(&reconstruction_turn.model_info.input_modalities), + ) } #[tokio::test] diff --git a/codex-rs/core/src/compact.rs b/codex-rs/core/src/compact.rs index bed810e2510..ed8abfc546f 100644 --- a/codex-rs/core/src/compact.rs +++ b/codex-rs/core/src/compact.rs @@ -112,7 +112,9 @@ async fn run_compact_task_inner( loop { // Clone is required because of the 
loop - let turn_input = history.clone().for_prompt(); + let turn_input = history + .clone() + .for_prompt(&turn_context.model_info.input_modalities); let turn_input_len = turn_input.len(); let prompt = Prompt { input: turn_input, diff --git a/codex-rs/core/src/compact_remote.rs b/codex-rs/core/src/compact_remote.rs index 8c7808e3fb4..6af43dddfdb 100644 --- a/codex-rs/core/src/compact_remote.rs +++ b/codex-rs/core/src/compact_remote.rs @@ -87,7 +87,7 @@ async fn run_remote_compact_task_inner_impl( .collect(); let prompt = Prompt { - input: history.for_prompt(), + input: history.for_prompt(&turn_context.model_info.input_modalities), tools: vec![], parallel_tool_calls: false, base_instructions, diff --git a/codex-rs/core/src/context_manager/history.rs b/codex-rs/core/src/context_manager/history.rs index a5b22380577..b6b3ceae5af 100644 --- a/codex-rs/core/src/context_manager/history.rs +++ b/codex-rs/core/src/context_manager/history.rs @@ -15,6 +15,7 @@ use codex_protocol::models::FunctionCallOutputBody; use codex_protocol::models::FunctionCallOutputContentItem; use codex_protocol::models::FunctionCallOutputPayload; use codex_protocol::models::ResponseItem; +use codex_protocol::openai_models::InputModality; use codex_protocol::protocol::TokenUsage; use codex_protocol::protocol::TokenUsageInfo; use std::ops::Deref; @@ -79,9 +80,11 @@ impl ContextManager { } /// Returns the history prepared for sending to the model. This applies a proper - /// normalization and drop un-suited items. - pub(crate) fn for_prompt(mut self) -> Vec<ResponseItem> { - self.normalize_history(); + /// normalization and drops un-suited items. When `input_modalities` does not + /// include `InputModality::Image`, images are stripped from messages and tool + /// outputs. + pub(crate) fn for_prompt(mut self, input_modalities: &[InputModality]) -> Vec<ResponseItem> { + self.normalize_history(input_modalities); self.items .retain(|item| !matches!(item, ResponseItem::GhostSnapshot { .. 
})); self.items @@ -309,12 +312,16 @@ impl ContextManager { /// This function enforces a couple of invariants on the in-memory history: /// 1. every call (function/custom) has a corresponding output entry /// 2. every output has a corresponding call entry - fn normalize_history(&mut self) { + /// 3. when images are unsupported, image content is stripped from messages and tool outputs + fn normalize_history(&mut self, input_modalities: &[InputModality]) { // all function/tool calls must have a corresponding output normalize::ensure_call_outputs_present(&mut self.items); // all outputs must have a corresponding function/tool call normalize::remove_orphan_outputs(&mut self.items); + + // strip images when model does not support them + normalize::strip_images_when_unsupported(input_modalities, &mut self.items); } fn process_item(&self, item: &ResponseItem, policy: TruncationPolicy) -> ResponseItem { diff --git a/codex-rs/core/src/context_manager/history_tests.rs b/codex-rs/core/src/context_manager/history_tests.rs index 99670c05e46..28c157b1e2a 100644 --- a/codex-rs/core/src/context_manager/history_tests.rs +++ b/codex-rs/core/src/context_manager/history_tests.rs @@ -12,6 +12,8 @@ use codex_protocol::models::LocalShellExecAction; use codex_protocol::models::LocalShellStatus; use codex_protocol::models::ReasoningItemContent; use codex_protocol::models::ReasoningItemReasoningSummary; +use codex_protocol::openai_models::InputModality; +use codex_protocol::openai_models::default_input_modalities; use pretty_assertions::assert_eq; use regex_lite::Regex; @@ -240,13 +242,122 @@ fn total_token_usage_includes_all_items_after_last_model_generated_item() { ); } +#[test] +fn for_prompt_strips_images_when_model_does_not_support_images() { + let items = vec![ + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ + ContentItem::InputText { + text: "look at this".to_string(), + }, + ContentItem::InputImage { + image_url: 
"https://example.com/img.png".to_string(), + }, + ContentItem::InputText { + text: "caption".to_string(), + }, + ], + end_turn: None, + phase: None, + }, + ResponseItem::FunctionCall { + id: None, + name: "view_image".to_string(), + arguments: "{}".to_string(), + call_id: "call-1".to_string(), + }, + ResponseItem::FunctionCallOutput { + call_id: "call-1".to_string(), + output: FunctionCallOutputPayload::from_content_items(vec![ + FunctionCallOutputContentItem::InputText { + text: "image result".to_string(), + }, + FunctionCallOutputContentItem::InputImage { + image_url: "https://example.com/result.png".to_string(), + }, + ]), + }, + ]; + let history = create_history_with_items(items); + let text_only_modalities = vec![InputModality::Text]; + let stripped = history.for_prompt(&text_only_modalities); + + let expected = vec![ + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ + ContentItem::InputText { + text: "look at this".to_string(), + }, + ContentItem::InputText { + text: "image content omitted because you do not support image input" + .to_string(), + }, + ContentItem::InputText { + text: "caption".to_string(), + }, + ], + end_turn: None, + phase: None, + }, + ResponseItem::FunctionCall { + id: None, + name: "view_image".to_string(), + arguments: "{}".to_string(), + call_id: "call-1".to_string(), + }, + ResponseItem::FunctionCallOutput { + call_id: "call-1".to_string(), + output: FunctionCallOutputPayload::from_content_items(vec![ + FunctionCallOutputContentItem::InputText { + text: "image result".to_string(), + }, + FunctionCallOutputContentItem::InputText { + text: "image content omitted because you do not support image input" + .to_string(), + }, + ]), + }, + ]; + assert_eq!(stripped, expected); + + // With image support, images are preserved + let modalities = default_input_modalities(); + let with_images = create_history_with_items(vec![ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ + 
ContentItem::InputText { + text: "look".to_string(), + }, + ContentItem::InputImage { + image_url: "https://example.com/img.png".to_string(), + }, + ], + end_turn: None, + phase: None, + }]); + let preserved = with_images.for_prompt(&modalities); + assert_eq!(preserved.len(), 1); + if let ResponseItem::Message { content, .. } = &preserved[0] { + assert_eq!(content.len(), 2); + assert!(matches!(content[1], ContentItem::InputImage { .. })); + } else { + panic!("expected Message"); + } +} + #[test] fn get_history_for_prompt_drops_ghost_commits() { let items = vec![ResponseItem::GhostSnapshot { ghost_commit: GhostCommit::new("ghost-1".to_string(), None, Vec::new(), Vec::new()), }]; let history = create_history_with_items(items); - let filtered = history.for_prompt(); + let modalities = default_input_modalities(); + let filtered = history.for_prompt(&modalities); assert_eq!(filtered, vec![]); } @@ -422,10 +533,11 @@ fn drop_last_n_user_turns_preserves_prefix() { assistant_msg("a2"), ]; + let modalities = default_input_modalities(); let mut history = create_history_with_items(items); history.drop_last_n_user_turns(1); assert_eq!( - history.for_prompt(), + history.for_prompt(&modalities), vec![ assistant_msg("session prefix item"), user_msg("u1"), @@ -442,7 +554,7 @@ fn drop_last_n_user_turns_preserves_prefix() { ]); history.drop_last_n_user_turns(99); assert_eq!( - history.for_prompt(), + history.for_prompt(&modalities), vec![assistant_msg("session prefix item")] ); } @@ -465,6 +577,7 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() { assistant_msg("turn 2 assistant"), ]; + let modalities = default_input_modalities(); let mut history = create_history_with_items(items); history.drop_last_n_user_turns(1); @@ -482,7 +595,10 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() { assistant_msg("turn 1 assistant"), ]; - assert_eq!(history.for_prompt(), expected_prefix_and_first_turn); + assert_eq!( + history.for_prompt(&modalities), + 
expected_prefix_and_first_turn + ); let expected_prefix_only = vec![ user_input_text_msg("ctx"), @@ -512,7 +628,7 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() { assistant_msg("turn 2 assistant"), ]); history.drop_last_n_user_turns(2); - assert_eq!(history.for_prompt(), expected_prefix_only); + assert_eq!(history.for_prompt(&modalities), expected_prefix_only); let mut history = create_history_with_items(vec![ user_input_text_msg("ctx"), @@ -530,7 +646,7 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() { assistant_msg("turn 2 assistant"), ]); history.drop_last_n_user_turns(3); - assert_eq!(history.for_prompt(), expected_prefix_only); + assert_eq!(history.for_prompt(&modalities), expected_prefix_only); } #[test] @@ -574,8 +690,9 @@ fn normalization_retains_local_shell_outputs() { }, ]; + let modalities = default_input_modalities(); let history = create_history_with_items(items.clone()); - let normalized = history.for_prompt(); + let normalized = history.for_prompt(&modalities); assert_eq!(normalized, items); } @@ -777,7 +894,7 @@ fn normalize_adds_missing_output_for_function_call() { }]; let mut h = create_history_with_items(items); - h.normalize_history(); + h.normalize_history(&default_input_modalities()); assert_eq!( h.raw_items(), @@ -808,7 +925,7 @@ fn normalize_adds_missing_output_for_custom_tool_call() { }]; let mut h = create_history_with_items(items); - h.normalize_history(); + h.normalize_history(&default_input_modalities()); assert_eq!( h.raw_items(), @@ -845,7 +962,7 @@ fn normalize_adds_missing_output_for_local_shell_call_with_id() { }]; let mut h = create_history_with_items(items); - h.normalize_history(); + h.normalize_history(&default_input_modalities()); assert_eq!( h.raw_items(), @@ -879,7 +996,7 @@ fn normalize_removes_orphan_function_call_output() { }]; let mut h = create_history_with_items(items); - h.normalize_history(); + h.normalize_history(&default_input_modalities()); assert_eq!(h.raw_items(), 
vec![]); } @@ -893,7 +1010,7 @@ fn normalize_removes_orphan_custom_tool_call_output() { }]; let mut h = create_history_with_items(items); - h.normalize_history(); + h.normalize_history(&default_input_modalities()); assert_eq!(h.raw_items(), vec![]); } @@ -938,7 +1055,7 @@ fn normalize_mixed_inserts_and_removals() { ]; let mut h = create_history_with_items(items); - h.normalize_history(); + h.normalize_history(&default_input_modalities()); assert_eq!( h.raw_items(), @@ -993,7 +1110,7 @@ fn normalize_adds_missing_output_for_function_call_inserts_output() { call_id: "call-x".to_string(), }]; let mut h = create_history_with_items(items); - h.normalize_history(); + h.normalize_history(&default_input_modalities()); assert_eq!( h.raw_items(), vec![ @@ -1023,7 +1140,7 @@ fn normalize_adds_missing_output_for_custom_tool_call_panics_in_debug() { input: "{}".to_string(), }]; let mut h = create_history_with_items(items); - h.normalize_history(); + h.normalize_history(&default_input_modalities()); } #[cfg(debug_assertions)] @@ -1043,7 +1160,7 @@ fn normalize_adds_missing_output_for_local_shell_call_with_id_panics_in_debug() }), }]; let mut h = create_history_with_items(items); - h.normalize_history(); + h.normalize_history(&default_input_modalities()); } #[cfg(debug_assertions)] @@ -1055,7 +1172,7 @@ fn normalize_removes_orphan_function_call_output_panics_in_debug() { output: FunctionCallOutputPayload::from_text("ok".to_string()), }]; let mut h = create_history_with_items(items); - h.normalize_history(); + h.normalize_history(&default_input_modalities()); } #[cfg(debug_assertions)] @@ -1067,7 +1184,7 @@ fn normalize_removes_orphan_custom_tool_call_output_panics_in_debug() { output: "ok".to_string(), }]; let mut h = create_history_with_items(items); - h.normalize_history(); + h.normalize_history(&default_input_modalities()); } #[cfg(debug_assertions)] @@ -1106,5 +1223,5 @@ fn normalize_mixed_inserts_and_removals_panics_in_debug() { }, ]; let mut h = 
create_history_with_items(items); - h.normalize_history(); + h.normalize_history(&default_input_modalities()); } diff --git a/codex-rs/core/src/context_manager/normalize.rs b/codex-rs/core/src/context_manager/normalize.rs index 37e177900fc..a4fe9e64fd3 100644 --- a/codex-rs/core/src/context_manager/normalize.rs +++ b/codex-rs/core/src/context_manager/normalize.rs @@ -1,12 +1,18 @@ use std::collections::HashSet; +use codex_protocol::models::ContentItem; use codex_protocol::models::FunctionCallOutputBody; +use codex_protocol::models::FunctionCallOutputContentItem; use codex_protocol::models::FunctionCallOutputPayload; use codex_protocol::models::ResponseItem; +use codex_protocol::openai_models::InputModality; use crate::util::error_or_panic; use tracing::info; +const IMAGE_CONTENT_OMITTED_PLACEHOLDER: &str = + "image content omitted because you do not support image input"; + pub(crate) fn ensure_call_outputs_present(items: &mut Vec<ResponseItem>) { // Collect synthetic outputs to insert immediately after their calls. // Store the insertion position (index of call) alongside the item so @@ -211,3 +217,53 @@ where items.remove(pos); } } + +/// Strip image content from messages and tool outputs when the model does not support images. +/// When `input_modalities` contains `InputModality::Image`, no stripping is performed. +pub(crate) fn strip_images_when_unsupported( + input_modalities: &[InputModality], + items: &mut [ResponseItem], +) { + let supports_images = input_modalities.contains(&InputModality::Image); + if supports_images { + return; + } + + for item in items.iter_mut() { + match item { + ResponseItem::Message { content, .. } => { + let mut normalized_content = Vec::with_capacity(content.len()); + for content_item in content.iter() { + match content_item { + ContentItem::InputImage { .. 
} => { + normalized_content.push(ContentItem::InputText { + text: IMAGE_CONTENT_OMITTED_PLACEHOLDER.to_string(), + }); + } + _ => normalized_content.push(content_item.clone()), + } + } + *content = normalized_content; + } + ResponseItem::FunctionCallOutput { output, .. } => { + if let Some(content_items) = output.content_items_mut() { + let mut normalized_content_items = Vec::with_capacity(content_items.len()); + for content_item in content_items.iter() { + match content_item { + FunctionCallOutputContentItem::InputImage { .. } => { + normalized_content_items.push( + FunctionCallOutputContentItem::InputText { + text: IMAGE_CONTENT_OMITTED_PLACEHOLDER.to_string(), + }, + ); + } + _ => normalized_content_items.push(content_item.clone()), + } + } + *content_items = normalized_content_items; + } + } + _ => {} + } + } +} diff --git a/codex-rs/core/tests/suite/model_switching.rs b/codex-rs/core/tests/suite/model_switching.rs index 73e957bdc48..b76dcb0467b 100644 --- a/codex-rs/core/tests/suite/model_switching.rs +++ b/codex-rs/core/tests/suite/model_switching.rs @@ -1,12 +1,24 @@ use anyhow::Result; +use codex_core::CodexAuth; use codex_core::config::types::Personality; use codex_core::features::Feature; +use codex_core::models_manager::manager::RefreshStrategy; use codex_core::protocol::AskForApproval; use codex_core::protocol::EventMsg; use codex_core::protocol::Op; use codex_core::protocol::SandboxPolicy; use codex_protocol::config_types::ReasoningSummary; +use codex_protocol::openai_models::ConfigShellToolType; +use codex_protocol::openai_models::InputModality; +use codex_protocol::openai_models::ModelInfo; +use codex_protocol::openai_models::ModelVisibility; +use codex_protocol::openai_models::ModelsResponse; +use codex_protocol::openai_models::ReasoningEffort; +use codex_protocol::openai_models::ReasoningEffortPreset; +use codex_protocol::openai_models::TruncationPolicyConfig; +use codex_protocol::openai_models::default_input_modalities; use 
codex_protocol::user_input::UserInput; +use core_test_support::responses::mount_models_once; use core_test_support::responses::mount_sse_sequence; use core_test_support::responses::sse_completed; use core_test_support::responses::start_mock_server; @@ -14,6 +26,8 @@ use core_test_support::skip_if_no_network; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event; use pretty_assertions::assert_eq; +use serde_json::Value; +use wiremock::MockServer; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn model_change_appends_model_instructions_developer_message() -> Result<()> { @@ -190,3 +204,170 @@ async fn model_and_personality_change_only_appends_model_instructions() -> Resul Ok(()) } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn model_change_from_image_to_text_strips_prior_image_content() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = MockServer::start().await; + let image_model_slug = "test-image-model"; + let text_model_slug = "test-text-only-model"; + let image_model = ModelInfo { + slug: image_model_slug.to_string(), + display_name: "Test Image Model".to_string(), + description: Some("supports image input".to_string()), + default_reasoning_level: Some(ReasoningEffort::Medium), + supported_reasoning_levels: vec![ReasoningEffortPreset { + effort: ReasoningEffort::Medium, + description: ReasoningEffort::Medium.to_string(), + }], + shell_type: ConfigShellToolType::ShellCommand, + visibility: ModelVisibility::List, + supported_in_api: true, + input_modalities: default_input_modalities(), + priority: 1, + upgrade: None, + base_instructions: "base instructions".to_string(), + model_messages: None, + supports_reasoning_summaries: false, + support_verbosity: false, + default_verbosity: None, + apply_patch_tool_type: None, + truncation_policy: TruncationPolicyConfig::bytes(10_000), + supports_parallel_tool_calls: false, + context_window: Some(272_000), + auto_compact_token_limit: 
None, + effective_context_window_percent: 95, + experimental_supported_tools: Vec::new(), + }; + let mut text_model = image_model.clone(); + text_model.slug = text_model_slug.to_string(); + text_model.display_name = "Test Text Model".to_string(); + text_model.description = Some("text only".to_string()); + text_model.input_modalities = vec![InputModality::Text]; + mount_models_once( + &server, + ModelsResponse { + models: vec![image_model, text_model], + }, + ) + .await; + + let responses = mount_sse_sequence( + &server, + vec![sse_completed("resp-1"), sse_completed("resp-2")], + ) + .await; + + let mut builder = test_codex() + .with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing()) + .with_config(move |config| { + config.features.enable(Feature::RemoteModels); + config.model = Some(image_model_slug.to_string()); + }); + let test = builder.build(&server).await?; + let models_manager = test.thread_manager.get_models_manager(); + let _ = models_manager + .list_models(&test.config, RefreshStrategy::OnlineIfUncached) + .await; + + let image_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGNgYAAAAAMAASsJTYQAAAAASUVORK5CYII=" + .to_string(); + + test.codex + .submit(Op::UserTurn { + items: vec![ + UserInput::Image { + image_url: image_url.clone(), + }, + UserInput::Text { + text: "first turn".to_string(), + text_elements: Vec::new(), + }, + ], + final_output_json_schema: None, + cwd: test.cwd_path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::ReadOnly, + model: image_model_slug.to_string(), + effort: test.config.model_reasoning_effort, + summary: ReasoningSummary::Auto, + collaboration_mode: None, + personality: None, + }) + .await?; + wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + test.codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "second turn".to_string(), + text_elements: Vec::new(), + }], + final_output_json_schema: 
None, + cwd: test.cwd_path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::ReadOnly, + model: text_model_slug.to_string(), + effort: test.config.model_reasoning_effort, + summary: ReasoningSummary::Auto, + collaboration_mode: None, + personality: None, + }) + .await?; + wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + let requests = responses.requests(); + assert_eq!(requests.len(), 2, "expected two model requests"); + + let first_request = requests.first().expect("expected first request"); + let first_has_input_image = first_request.inputs_of_type("message").iter().any(|item| { + item.get("content") + .and_then(Value::as_array) + .is_some_and(|content| { + content + .iter() + .any(|span| span.get("type").and_then(Value::as_str) == Some("input_image")) + }) + }); + assert!( + first_has_input_image, + "first request should include the uploaded image" + ); + + let second_request = requests.last().expect("expected second request"); + let second_has_input_image = second_request.inputs_of_type("message").iter().any(|item| { + item.get("content") + .and_then(Value::as_array) + .is_some_and(|content| { + content + .iter() + .any(|span| span.get("type").and_then(Value::as_str) == Some("input_image")) + }) + }); + assert!( + !second_has_input_image, + "second request should strip unsupported image content" + ); + let second_user_texts = second_request.message_input_texts("user"); + assert!( + second_user_texts + .iter() + .any(|text| text == "image content omitted because you do not support image input"), + "second request should include the image-omitted placeholder text" + ); + assert!( + second_user_texts + .iter() + .any(|text| text == &codex_protocol::models::image_open_tag_text()), + "second request should preserve the image open tag text" + ); + assert!( + second_user_texts + .iter() + .any(|text| text == &codex_protocol::models::image_close_tag_text()), + "second request should preserve 
the image close tag text" + ); + + Ok(()) +}