147 changes: 138 additions & 9 deletions crates/goose/src/providers/openai.rs
@@ -35,6 +35,9 @@ use crate::providers::utils::RequestLog;
use rmcp::model::Tool;

const OPEN_AI_PROVIDER_NAME: &str = "openai";
const OPEN_AI_DEFAULT_BASE_PATH: &str = "v1/chat/completions";
const OPEN_AI_DEFAULT_RESPONSES_PATH: &str = "v1/responses";
const OPEN_AI_DEFAULT_MODELS_PATH: &str = "v1/models";
pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o";
pub const OPEN_AI_DEFAULT_FAST_MODEL: &str = "gpt-4o-mini";
pub const OPEN_AI_KNOWN_MODELS: &[(&str, usize)] = &[
@@ -87,7 +90,7 @@ impl OpenAiProvider {

let base_path: String = config
.get_param("OPENAI_BASE_PATH")
- .unwrap_or_else(|_| "v1/chat/completions".to_string());
+ .unwrap_or_else(|_| OPEN_AI_DEFAULT_BASE_PATH.to_string());
let organization: Option<String> = config.get_param("OPENAI_ORGANIZATION").ok();
let project: Option<String> = config.get_param("OPENAI_PROJECT").ok();
let timeout_secs: u64 = config.get_param("OPENAI_TIMEOUT").unwrap_or(600);
@@ -133,7 +136,7 @@ impl OpenAiProvider {
pub fn new(api_client: ApiClient, model: ModelConfig) -> Self {
Self {
api_client,
- base_path: "v1/chat/completions".to_string(),
+ base_path: OPEN_AI_DEFAULT_BASE_PATH.to_string(),
organization: None,
project: None,
model,
@@ -207,8 +210,64 @@ impl OpenAiProvider {
})
}

- fn uses_responses_api(model_name: &str) -> bool {
-     model_name.starts_with("gpt-5-codex") || model_name.starts_with("gpt-5.1-codex")
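/// Strips trailing slashes while preserving a single leading slash, so path
/// comparisons are insensitive to how the configured base path is written.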
fn normalize_base_path(base_path: &str) -> String {
if let Some(path) = base_path.strip_prefix('/') {
format!("/{}", path.trim_end_matches('/'))
} else {
base_path.trim_end_matches('/').to_string()
}
}

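/// Returns true when the configured base path points at a chat-completions endpoint.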
fn is_chat_completions_path(base_path: &str) -> bool {
let normalized = Self::normalize_base_path(base_path).to_ascii_lowercase();
normalized.contains("chat/completions")
}

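/// Returns true when the configured base path points at a responses endpoint.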
fn is_responses_path(base_path: &str) -> bool {
let normalized = Self::normalize_base_path(base_path).to_ascii_lowercase();
normalized.ends_with("responses") || normalized.contains("/responses")
}

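/// Matches any GPT-5 Codex variant (e.g. gpt-5-codex, gpt-5.1-codex, gpt-5.2-codex).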
fn is_codex_gpt_5_model(model_name: &str) -> bool {
let normalized_model = model_name.to_ascii_lowercase();
normalized_model.starts_with("gpt-5") && normalized_model.contains("codex")
}

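/// Decides between the Responses API and chat completions: an explicitly
/// configured base path wins, otherwise Codex GPT-5 models default to the
/// Responses API.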
fn should_use_responses_api(model_name: &str, base_path: &str) -> bool {
let normalized_base_path = Self::normalize_base_path(base_path);
let has_custom_base_path = normalized_base_path != OPEN_AI_DEFAULT_BASE_PATH;

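// A custom base path is an explicit routing choice, so honor it first.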
if has_custom_base_path {
if Self::is_responses_path(&normalized_base_path) {
return true;
}
if Self::is_chat_completions_path(&normalized_base_path) {
return false;
}
}

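// No explicit routing from the base path: fall back to model-name detection.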
Self::is_codex_gpt_5_model(model_name)
}

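/// Rewrites the configured base path to point at `target` (e.g. "models" or
/// "responses"), falling back to `fallback` when the path is not recognized.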
fn map_base_path(base_path: &str, target: &str, fallback: &str) -> String {
let normalized = Self::normalize_base_path(base_path);
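// Already pointing at the target endpoint; nothing to rewrite.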
if normalized.ends_with(target) || normalized.contains(&format!("/{target}")) {
return normalized;
}

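// Swap the chat-completions segment for the target, keeping any custom prefix.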
if Self::is_chat_completions_path(&normalized) {
return normalized.replacen("chat/completions", target, 1);
}

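// Likewise swap a responses segment for the target.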
if Self::is_responses_path(&normalized) {
return normalized.replacen("responses", target, 1);
}

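// Unrecognized path: use the default, preserving absolute vs. relative form.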
if normalized.starts_with('/') {
format!("/{}", fallback.trim_start_matches('/'))
} else {
fallback.to_string()
}
}

async fn post(
@@ -230,7 +289,11 @@ impl OpenAiProvider {
) -> Result<Value, ProviderError> {
let response = self
.api_client
- .response_post(session_id, "v1/responses", payload)
+ .response_post(
+ session_id,
+ &Self::map_base_path(&self.base_path, "responses", OPEN_AI_DEFAULT_RESPONSES_PATH),
+ payload,
+ )
.await?;
handle_response_openai_compat(response).await
}
@@ -293,7 +356,7 @@ impl Provider for OpenAiProvider {
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
- if Self::uses_responses_api(&model_config.model_name) {
+ if Self::should_use_responses_api(&model_config.model_name, &self.base_path) {
let payload = create_responses_request(model_config, system, messages, tools)?;
let mut log = RequestLog::start(&self.model, &payload)?;

@@ -358,7 +421,8 @@ impl Provider for OpenAiProvider {
}

async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
- let models_path = self.base_path.replace("v1/chat/completions", "v1/models");
+ let models_path =
+ Self::map_base_path(&self.base_path, "models", OPEN_AI_DEFAULT_MODELS_PATH);
let response = self
.api_client
.request(None, &models_path)
@@ -409,7 +473,7 @@ impl Provider for OpenAiProvider {
messages: &[Message],
tools: &[Tool],
) -> Result<MessageStream, ProviderError> {
- if Self::uses_responses_api(&self.model.model_name) {
+ if Self::should_use_responses_api(&self.model.model_name, &self.base_path) {
let mut payload = create_responses_request(&self.model, system, messages, tools)?;
payload["stream"] = serde_json::Value::Bool(true);

@@ -420,7 +484,15 @@ impl Provider for OpenAiProvider {
let payload_clone = payload.clone();
let resp = self
.api_client
- .response_post(Some(session_id), "v1/responses", &payload_clone)
+ .response_post(
+ Some(session_id),
+ &Self::map_base_path(
+ &self.base_path,
+ "responses",
+ OPEN_AI_DEFAULT_RESPONSES_PATH,
+ ),
+ &payload_clone,
+ )
.await?;
handle_status_openai_compat(resp).await
})
@@ -539,3 +611,60 @@ impl EmbeddingCapable for OpenAiProvider {
.collect())
}
}

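// Covers routing between chat completions and the Responses API for default,
// custom, relative, and absolute base paths.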
#[cfg(test)]
mod tests {
use super::OpenAiProvider;

#[test]
fn gpt_5_2_codex_uses_responses_when_base_path_is_default() {
assert!(OpenAiProvider::should_use_responses_api(
"gpt-5.2-codex",
"v1/chat/completions"
));
}

#[test]
fn explicit_chat_path_forces_chat_completions() {
assert!(!OpenAiProvider::should_use_responses_api(
"gpt-5.2-codex",
"openai/v1/chat/completions"
));
}

#[test]
fn custom_chat_path_maps_to_responses_path() {
let responses_path = OpenAiProvider::map_base_path(
"openai/v1/chat/completions",
"responses",
"v1/responses",
);
assert_eq!(responses_path, "openai/v1/responses");
}

#[test]
fn responses_path_maps_to_models_path() {
let models_path =
OpenAiProvider::map_base_path("openai/v1/responses", "models", "v1/models");
assert_eq!(models_path, "openai/v1/models");
}

#[test]
fn unknown_path_falls_back_to_default_models_path() {
let models_path = OpenAiProvider::map_base_path("custom/path", "models", "v1/models");
assert_eq!(models_path, "v1/models");
}

#[test]
fn absolute_chat_path_maps_to_absolute_responses_path() {
let responses_path =
OpenAiProvider::map_base_path("/v1/chat/completions", "responses", "v1/responses");
assert_eq!(responses_path, "/v1/responses");
}

#[test]
fn unknown_absolute_path_falls_back_to_absolute_models_path() {
let models_path = OpenAiProvider::map_base_path("/custom/path", "models", "v1/models");
assert_eq!(models_path, "/v1/models");
}
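
// Added illustration, not part of the original PR: per the routing rules
// above, an explicitly configured responses path should force the Responses
// API even for a non-Codex model such as gpt-4o.
#[test]
fn explicit_responses_path_forces_responses_api_for_any_model() {
assert!(OpenAiProvider::should_use_responses_api(
"gpt-4o",
"openai/v1/responses"
));
}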
}