diff --git a/codex-rs/rmcp-client/src/perform_oauth_login.rs b/codex-rs/rmcp-client/src/perform_oauth_login.rs index 64cf979ec74..09b746837e1 100644 --- a/codex-rs/rmcp-client/src/perform_oauth_login.rs +++ b/codex-rs/rmcp-client/src/perform_oauth_login.rs @@ -103,21 +103,32 @@ fn spawn_callback_server(server: Arc, tx: oneshot::Sender<(String, Strin tokio::task::spawn_blocking(move || { while let Ok(request) = server.recv() { let path = request.url().to_string(); - if let Some(OauthCallbackResult { code, state }) = parse_oauth_callback(&path) { - let response = - Response::from_string("Authentication complete. You may close this window."); - if let Err(err) = request.respond(response) { - eprintln!("Failed to respond to OAuth callback: {err}"); + match parse_oauth_callback(&path) { + CallbackOutcome::Success(OauthCallbackResult { code, state }) => { + let response = Response::from_string( + "Authentication complete. You may close this window.", + ); + if let Err(err) = request.respond(response) { + eprintln!("Failed to respond to OAuth callback: {err}"); + } + if let Err(err) = tx.send((code, state)) { + eprintln!("Failed to send OAuth callback: {err:?}"); + } + break; } - if let Err(err) = tx.send((code, state)) { - eprintln!("Failed to send OAuth callback: {err:?}"); + CallbackOutcome::Error(description) => { + let response = Response::from_string(format!("OAuth error: {description}")) + .with_status_code(400); + if let Err(err) = request.respond(response) { + eprintln!("Failed to respond to OAuth callback: {err}"); + } } - break; - } else { - let response = - Response::from_string("Invalid OAuth callback").with_status_code(400); - if let Err(err) = request.respond(response) { - eprintln!("Failed to respond to OAuth callback: {err}"); + CallbackOutcome::Invalid => { + let response = + Response::from_string("Invalid OAuth callback").with_status_code(400); + if let Err(err) = request.respond(response) { + eprintln!("Failed to respond to OAuth callback: {err}"); + } } } } @@ -129,29 +140,49 @@ struct OauthCallbackResult { state: String, } -fn parse_oauth_callback(path: &str) -> Option { - let (route, query) = path.split_once('?')?; +enum CallbackOutcome { + Success(OauthCallbackResult), + Error(String), + Invalid, +} + +fn parse_oauth_callback(path: &str) -> CallbackOutcome { + let Some((route, query)) = path.split_once('?') else { + return CallbackOutcome::Invalid; + }; if route != "/callback" { - return None; + return CallbackOutcome::Invalid; } let mut code = None; let mut state = None; + let mut error_description = None; for pair in query.split('&') { - let (key, value) = pair.split_once('=')?; - let decoded = decode(value).ok()?.into_owned(); + let Some((key, value)) = pair.split_once('=') else { + continue; + }; + let Ok(decoded) = decode(value) else { + continue; + }; + let decoded = decoded.into_owned(); match key { "code" => code = Some(decoded), "state" => state = Some(decoded), + "error_description" => error_description = Some(decoded), _ => {} } } - Some(OauthCallbackResult { - code: code?, - state: state?, - }) + if let (Some(code), Some(state)) = (code, state) { + return CallbackOutcome::Success(OauthCallbackResult { code, state }); + } + + if let Some(description) = error_description { + return CallbackOutcome::Error(description); + } + + CallbackOutcome::Invalid } pub struct OauthLoginHandle {