Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions codex-rs/codex-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub use crate::endpoint::responses_websocket::ResponsesWebsocketConnection;
pub use crate::error::ApiError;
pub use crate::provider::Provider;
pub use crate::provider::WireApi;
pub use crate::provider::is_azure_responses_wire_base_url;
pub use crate::requests::ChatRequest;
pub use crate::requests::ChatRequestBuilder;
pub use crate::requests::ResponsesRequest;
Expand Down
80 changes: 68 additions & 12 deletions codex-rs/codex-api/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,7 @@ impl Provider {
}

pub fn is_azure_responses_endpoint(&self) -> bool {
if self.wire != WireApi::Responses {
return false;
}

if self.name.eq_ignore_ascii_case("azure") {
return true;
}

self.base_url.to_ascii_lowercase().contains("openai.azure.")
|| matches_azure_responses_base_url(&self.base_url)
is_azure_responses_wire_base_url(self.wire.clone(), &self.name, Some(&self.base_url))
}

pub fn websocket_url_for_path(&self, path: &str) -> Result<Url, url::ParseError> {
Expand All @@ -121,6 +112,23 @@ impl Provider {
}
}

pub fn is_azure_responses_wire_base_url(wire: WireApi, name: &str, base_url: Option<&str>) -> bool {
if wire != WireApi::Responses {
return false;
}

if name.eq_ignore_ascii_case("azure") {
return true;
}

let Some(base_url) = base_url else {
return false;
};

let base = base_url.to_ascii_lowercase();
base.contains("openai.azure.") || matches_azure_responses_base_url(&base)
}

fn matches_azure_responses_base_url(base_url: &str) -> bool {
const AZURE_MARKERS: [&str; 5] = [
"cognitiveservices.azure.",
Expand All @@ -129,6 +137,54 @@ fn matches_azure_responses_base_url(base_url: &str) -> bool {
"azurefd.",
"windows.net/openai",
];
let base = base_url.to_ascii_lowercase();
AZURE_MARKERS.iter().any(|marker| base.contains(marker))
AZURE_MARKERS.iter().any(|marker| base_url.contains(marker))
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn detects_azure_responses_base_urls() {
let positive_cases = [
"https://foo.openai.azure.com/openai",
"https://foo.openai.azure.us/openai/deployments/bar",
"https://foo.cognitiveservices.azure.cn/openai",
"https://foo.aoai.azure.com/openai",
"https://foo.openai.azure-api.net/openai",
"https://foo.z01.azurefd.net/",
];

for base_url in positive_cases {
assert!(
is_azure_responses_wire_base_url(WireApi::Responses, "test", Some(base_url)),
"expected {base_url} to be detected as Azure"
);
}

assert!(is_azure_responses_wire_base_url(
WireApi::Responses,
"Azure",
Some("https://example.com")
));

let negative_cases = [
"https://api.openai.com/v1",
"https://example.com/openai",
"https://myproxy.azurewebsites.net/openai",
];

for base_url in negative_cases {
assert!(
!is_azure_responses_wire_base_url(WireApi::Responses, "test", Some(base_url)),
"expected {base_url} not to be detected as Azure"
);
}

assert!(!is_azure_responses_wire_base_url(
WireApi::Chat,
"Azure",
Some("https://foo.openai.azure.com/openai")
));
}
}
1 change: 1 addition & 0 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,7 @@ impl Session {
per_turn_config.model_personality = session_configuration.personality;
per_turn_config.web_search_mode = Some(resolve_web_search_mode_for_turn(
per_turn_config.web_search_mode,
session_configuration.provider.is_azure_responses_endpoint(),
session_configuration.sandbox_policy.get(),
));
per_turn_config.features = config.features.clone();
Expand Down
16 changes: 14 additions & 2 deletions codex-rs/core/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1246,11 +1246,15 @@ fn resolve_web_search_mode(

pub(crate) fn resolve_web_search_mode_for_turn(
explicit_mode: Option<WebSearchMode>,
is_azure_responses_endpoint: bool,
sandbox_policy: &SandboxPolicy,
) -> WebSearchMode {
if let Some(mode) = explicit_mode {
return mode;
}
if is_azure_responses_endpoint {
return WebSearchMode::Disabled;
}
if matches!(sandbox_policy, SandboxPolicy::DangerFullAccess) {
WebSearchMode::Live
} else {
Expand Down Expand Up @@ -2347,14 +2351,14 @@ trust_level = "trusted"

#[test]
fn web_search_mode_for_turn_defaults_to_cached_when_unset() {
let mode = resolve_web_search_mode_for_turn(None, &SandboxPolicy::ReadOnly);
let mode = resolve_web_search_mode_for_turn(None, false, &SandboxPolicy::ReadOnly);

assert_eq!(mode, WebSearchMode::Cached);
}

#[test]
fn web_search_mode_for_turn_defaults_to_live_for_danger_full_access() {
let mode = resolve_web_search_mode_for_turn(None, &SandboxPolicy::DangerFullAccess);
let mode = resolve_web_search_mode_for_turn(None, false, &SandboxPolicy::DangerFullAccess);

assert_eq!(mode, WebSearchMode::Live);
}
Expand All @@ -2363,12 +2367,20 @@ trust_level = "trusted"
fn web_search_mode_for_turn_prefers_explicit_value() {
let mode = resolve_web_search_mode_for_turn(
Some(WebSearchMode::Cached),
false,
&SandboxPolicy::DangerFullAccess,
);

assert_eq!(mode, WebSearchMode::Cached);
}

#[test]
fn web_search_mode_for_turn_disables_for_azure_responses_endpoint() {
let mode = resolve_web_search_mode_for_turn(None, true, &SandboxPolicy::DangerFullAccess);

assert_eq!(mode, WebSearchMode::Disabled);
}

#[test]
fn profile_legacy_toggles_override_base() -> std::io::Result<()> {
let codex_home = TempDir::new()?;
Expand Down
2 changes: 2 additions & 0 deletions codex-rs/core/src/config_loader/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,7 @@ async fn codex_home_is_not_loaded_as_project_layer_from_home_dir() -> std::io::R
Some(cwd),
&[] as &[(String, TomlValue)],
LoaderOverrides::default(),
None,
)
.await?;

Expand Down Expand Up @@ -818,6 +819,7 @@ async fn codex_home_within_project_tree_is_not_double_loaded() -> std::io::Resul
Some(cwd),
&[] as &[(String, TomlValue)],
LoaderOverrides::default(),
None,
)
.await?;

Expand Down
93 changes: 10 additions & 83 deletions codex-rs/core/src/model_provider_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::auth::AuthMode;
use crate::error::EnvVarError;
use codex_api::Provider as ApiProvider;
use codex_api::WireApi as ApiWireApi;
use codex_api::is_azure_responses_wire_base_url;
use codex_api::provider::RetryConfig as ApiRetryConfig;
use http::HeaderMap;
use http::header::HeaderName;
Expand Down Expand Up @@ -170,6 +171,15 @@ impl ModelProviderInfo {
})
}

pub(crate) fn is_azure_responses_endpoint(&self) -> bool {
let wire = match self.wire_api {
WireApi::Responses => ApiWireApi::Responses,
WireApi::Chat => ApiWireApi::Chat,
};

is_azure_responses_wire_base_url(wire, &self.name, self.base_url.as_deref())
}

/// If `env_key` is Some, returns the API key for this provider if present
/// (and non-empty) in the environment. If `env_key` is required but
/// cannot be found, returns an error.
Expand Down Expand Up @@ -432,87 +442,4 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
assert_eq!(expected_provider, provider);
}

#[test]
fn detects_azure_responses_base_urls() {
let positive_cases = [
"https://foo.openai.azure.com/openai",
"https://foo.openai.azure.us/openai/deployments/bar",
"https://foo.cognitiveservices.azure.cn/openai",
"https://foo.aoai.azure.com/openai",
"https://foo.openai.azure-api.net/openai",
"https://foo.z01.azurefd.net/",
];
for base_url in positive_cases {
let provider = ModelProviderInfo {
name: "test".into(),
base_url: Some(base_url.into()),
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
supports_websockets: false,
};
let api = provider.to_api_provider(None).expect("api provider");
assert!(
api.is_azure_responses_endpoint(),
"expected {base_url} to be detected as Azure"
);
}

let named_provider = ModelProviderInfo {
name: "Azure".into(),
base_url: Some("https://example.com".into()),
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
supports_websockets: false,
};
let named_api = named_provider.to_api_provider(None).expect("api provider");
assert!(named_api.is_azure_responses_endpoint());

let negative_cases = [
"https://api.openai.com/v1",
"https://example.com/openai",
"https://myproxy.azurewebsites.net/openai",
];
for base_url in negative_cases {
let provider = ModelProviderInfo {
name: "test".into(),
base_url: Some(base_url.into()),
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
supports_websockets: false,
};
let api = provider.to_api_provider(None).expect("api provider");
assert!(
!api.is_azure_responses_endpoint(),
"expected {base_url} not to be detected as Azure"
);
}
}
}
2 changes: 1 addition & 1 deletion codex-rs/core/tests/suite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,5 @@ mod unstable_features_warning;
mod user_notification;
mod user_shell_cmd;
mod view_image;
mod web_search_cached;
mod web_search;
mod websocket_fallback;
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#![allow(clippy::unwrap_used)]

use codex_core::WireApi;
use codex_core::built_in_model_providers;
use codex_core::features::Feature;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::WebSearchMode;
Expand All @@ -25,6 +27,15 @@ fn find_web_search_tool(body: &Value) -> &Value {
.expect("tools should include a web_search tool")
}

#[allow(clippy::expect_used)]
fn has_web_search_tool(body: &Value) -> bool {
body["tools"]
.as_array()
.expect("request body should include tools array")
.iter()
.any(|tool| tool.get("type").and_then(Value::as_str) == Some("web_search"))
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn web_search_mode_cached_sets_external_web_access_false() {
skip_if_no_network!();
Expand Down Expand Up @@ -174,3 +185,45 @@ async fn web_search_mode_updates_between_turns_with_sandbox_policy() {
"danger-full-access policy should default web_search to live"
);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn web_search_mode_defaults_to_disabled_for_azure_responses() {
skip_if_no_network!();

let server = start_mock_server().await;
let sse = sse_completed("resp-1");
let resp_mock = responses::mount_sse_once(&server, sse).await;

let mut builder = test_codex()
.with_model("gpt-5-codex")
.with_config(|config| {
let base_url = config.model_provider.base_url.clone();
let mut provider = built_in_model_providers()["openai"].clone();
provider.name = "Azure".to_string();
provider.base_url = base_url;
provider.wire_api = WireApi::Responses;
config.model_provider_id = provider.name.clone();
config.model_provider = provider;
config.web_search_mode = None;
config.features.disable(Feature::WebSearchCached);
config.features.disable(Feature::WebSearchRequest);
});
let test = builder
.build(&server)
.await
.expect("create test Codex conversation");

test.submit_turn_with_policy(
"hello azure default web search",
SandboxPolicy::DangerFullAccess,
)
.await
.expect("submit turn");

let body = resp_mock.single_request().body_json();
assert_eq!(
has_web_search_tool(&body),
false,
"azure responses requests should disable web_search by default"
);
}
Loading