Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions crates/goose-server/src/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};

pub async fn check_token(
State(state): State<String>,
request: Request,
next: Next,
) -> Result<Response, StatusCode> {
let secret_key = request
.headers()
.get("X-Secret-Key")
.and_then(|value| value.to_str().ok());

match secret_key {
Some(key) if key == state => Ok(next.run(request).await),
_ => Err(StatusCode::UNAUTHORIZED),
}
}
11 changes: 9 additions & 2 deletions crates/goose-server/src/commands/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ use std::sync::Arc;
use crate::configuration;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should really name this differently :) I was like, why is this in agent.rs

use crate::state;
use anyhow::Result;
use axum::middleware;
use etcetera::{choose_app_strategy, AppStrategy};
use goose::agents::Agent;
use goose::config::APP_STRATEGY;
use goose::scheduler_factory::SchedulerFactory;
use goose_server::auth::check_token;
use tower_http::cors::{Any, CorsLayer};
use tracing::info;

Expand All @@ -33,7 +35,7 @@ pub async fn run() -> Result<()> {
let new_agent = Agent::new();
let agent_ref = Arc::new(new_agent);

let app_state = state::AppState::new(agent_ref.clone(), secret_key.clone());
let app_state = state::AppState::new(agent_ref.clone());

let schedule_file_path = choose_app_strategy(APP_STRATEGY.clone())?
.data_dir()
Expand All @@ -50,7 +52,12 @@ pub async fn run() -> Result<()> {
.allow_methods(Any)
.allow_headers(Any);

let app = crate::routes::configure(app_state).layer(cors);
let app = crate::routes::configure(app_state)
.layer(middleware::from_fn_with_state(
secret_key.clone(),
check_token,
))
.layer(cors);

let listener = tokio::net::TcpListener::bind(settings.socket_addr()).await?;
info!("listening on {}", listener.local_addr()?);
Expand Down
1 change: 1 addition & 0 deletions crates/goose-server/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod auth;
pub mod openapi;
pub mod routes;
pub mod state;
Expand Down
1 change: 1 addition & 0 deletions crates/goose-server/src/openapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ impl<'__s> ToSchema<'__s> for AnnotatedSchema {
#[derive(OpenApi)]
#[openapi(
paths(
super::routes::health::status,
super::routes::config_management::backup_config,
super::routes::config_management::recover_config,
super::routes::config_management::validate_config,
Expand Down
36 changes: 1 addition & 35 deletions crates/goose-server/src/routes/agent.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use super::utils::verify_secret_key;
use crate::state::AppState;
use axum::response::IntoResponse;
use axum::{
extract::{Query, State},
http::{HeaderMap, StatusCode},
http::StatusCode,
routing::{get, post},
Json, Router,
};
Expand Down Expand Up @@ -115,11 +114,8 @@ pub struct ErrorResponse {
)]
async fn start_agent(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(payload): Json<StartAgentRequest>,
) -> Result<Json<StartAgentResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;

state.reset().await;

let session_id = session::generate_session_id();
Expand Down Expand Up @@ -168,12 +164,8 @@ async fn start_agent(
)
)]
async fn resume_agent(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(payload): Json<ResumeAgentRequest>,
) -> Result<Json<StartAgentResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;

let session_path =
match session::get_path(session::Identifier::Name(payload.session_id.clone())) {
Ok(path) => path,
Expand Down Expand Up @@ -209,11 +201,8 @@ async fn resume_agent(
)]
async fn add_sub_recipes(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(payload): Json<AddSubRecipesRequest>,
) -> Result<Json<AddSubRecipesResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;

let agent = state.get_agent().await;
agent.add_sub_recipes(payload.sub_recipes.clone()).await;
Ok(Json(AddSubRecipesResponse { success: true }))
Expand All @@ -231,11 +220,8 @@ async fn add_sub_recipes(
)]
async fn extend_prompt(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(payload): Json<ExtendPromptRequest>,
) -> Result<Json<ExtendPromptResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;

let agent = state.get_agent().await;
agent.extend_system_prompt(payload.extension.clone()).await;
Ok(Json(ExtendPromptResponse { success: true }))
Expand All @@ -257,11 +243,8 @@ async fn extend_prompt(
)]
async fn get_tools(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Query(query): Query<GetToolsQuery>,
) -> Result<Json<Vec<ToolInfo>>, StatusCode> {
verify_secret_key(&headers, &state)?;

let config = Config::global();
let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
let agent = state.get_agent().await;
Expand Down Expand Up @@ -314,11 +297,8 @@ async fn get_tools(
)]
async fn update_agent_provider(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(payload): Json<UpdateProviderRequest>,
) -> Result<StatusCode, impl IntoResponse> {
verify_secret_key(&headers, &state).map_err(|e| (e, String::new()))?;

let agent = state.get_agent().await;
let config = Config::global();
let model = match payload
Expand Down Expand Up @@ -364,15 +344,8 @@ async fn update_agent_provider(
)]
async fn update_router_tool_selector(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(_payload): Json<UpdateRouterToolSelectorRequest>,
) -> Result<Json<String>, Json<ErrorResponse>> {
verify_secret_key(&headers, &state).map_err(|_| {
Json(ErrorResponse {
error: "Unauthorized - Invalid or missing API key".to_string(),
})
})?;

let agent = state.get_agent().await;
agent
.update_router_tool_selector(None, Some(true))
Expand Down Expand Up @@ -402,15 +375,8 @@ async fn update_router_tool_selector(
)]
async fn update_session_config(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(payload): Json<SessionConfigRequest>,
) -> Result<Json<String>, Json<ErrorResponse>> {
verify_secret_key(&headers, &state).map_err(|_| {
Json(ErrorResponse {
error: "Unauthorized - Invalid or missing API key".to_string(),
})
})?;

let agent = state.get_agent().await;
if let Some(response) = payload.response {
agent.add_final_output_tool(response).await;
Expand Down
39 changes: 6 additions & 33 deletions crates/goose-server/src/routes/audio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
///
/// This module provides endpoints for audio transcription using OpenAI's Whisper API.
/// The OpenAI API key must be configured in the backend for this to work.
use super::utils::verify_secret_key;
use crate::state::AppState;
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
http::StatusCode,
routing::{get, post},
Json, Router,
};
Expand Down Expand Up @@ -209,12 +207,8 @@ async fn send_openai_request(
/// - 502: Bad Gateway (OpenAI API error)
/// - 503: Service Unavailable (network error)
async fn transcribe_handler(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(request): Json<TranscribeRequest>,
) -> Result<Json<TranscribeResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;

let (audio_bytes, file_extension) = validate_audio_input(&request.audio, &request.mime_type)?;
let (api_key, openai_host) = get_openai_config()?;

Expand All @@ -237,12 +231,8 @@ async fn transcribe_handler(
/// Uses ElevenLabs' speech-to-text endpoint for transcription.
/// Requires an ElevenLabs API key with speech-to-text access.
async fn transcribe_elevenlabs_handler(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(request): Json<TranscribeElevenLabsRequest>,
) -> Result<Json<TranscribeResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;

let (audio_bytes, file_extension) = validate_audio_input(&request.audio, &request.mime_type)?;

// Get the ElevenLabs API key from config (after input validation)
Expand Down Expand Up @@ -369,12 +359,7 @@ async fn transcribe_elevenlabs_handler(
/// Check if dictation providers are configured
///
/// Returns configuration status for dictation providers
async fn check_dictation_config(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
) -> Result<Json<serde_json::Value>, StatusCode> {
verify_secret_key(&headers, &state)?;

async fn check_dictation_config() -> Result<Json<serde_json::Value>, StatusCode> {
let config = goose::config::Config::global();

// Check if ElevenLabs API key is configured
Expand Down Expand Up @@ -410,10 +395,7 @@ mod tests {

#[tokio::test]
async fn test_transcribe_endpoint_requires_auth() {
let state = AppState::new(
Arc::new(goose::agents::Agent::new()),
"test-secret".to_string(),
);
let state = AppState::new(Arc::new(goose::agents::Agent::new()));
let app = routes(state);

// Test without auth header
Expand All @@ -436,10 +418,7 @@ mod tests {

#[tokio::test]
async fn test_transcribe_endpoint_validates_size() {
let state = AppState::new(
Arc::new(goose::agents::Agent::new()),
"test-secret".to_string(),
);
let state = AppState::new(Arc::new(goose::agents::Agent::new()));
let app = routes(state);

// Create a large base64 string (simulating > 25MB audio)
Expand All @@ -465,10 +444,7 @@ mod tests {

#[tokio::test]
async fn test_transcribe_endpoint_validates_mime_type() {
let state = AppState::new(
Arc::new(goose::agents::Agent::new()),
"test-secret".to_string(),
);
let state = AppState::new(Arc::new(goose::agents::Agent::new()));
let app = routes(state);

let request = Request::builder()
Expand All @@ -494,10 +470,7 @@ mod tests {

#[tokio::test]
async fn test_transcribe_endpoint_handles_invalid_base64() {
let state = AppState::new(
Arc::new(goose::agents::Agent::new()),
"test-secret".to_string(),
);
let state = AppState::new(Arc::new(goose::agents::Agent::new()));
let app = routes(state);

let request = Request::builder()
Expand Down
Loading
Loading