Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/goose-server/src/openapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,7 @@ derive_utoipa!(Icon as IconSchema);
super::routes::recipe::ListRecipeResponse,
super::routes::recipe::DeleteRecipeRequest,
super::routes::recipe::SaveRecipeRequest,
super::routes::recipe::SaveRecipeResponse,
super::routes::errors::ErrorResponse,
super::routes::recipe::ParseRecipeRequest,
super::routes::recipe::ParseRecipeResponse,
Expand All @@ -480,7 +481,6 @@ derive_utoipa!(Icon as IconSchema);
super::routes::agent::UpdateRouterToolSelectorRequest,
super::routes::agent::StartAgentRequest,
super::routes::agent::ResumeAgentRequest,
super::routes::agent::ErrorResponse,
super::routes::setup::SetupResponse,
))
)]
Expand Down
97 changes: 80 additions & 17 deletions crates/goose-server/src/routes/agent.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::routes::errors::ErrorResponse;
use crate::routes::recipe_utils::{load_recipe_by_id, validate_recipe};
use crate::state::AppState;
use axum::{
extract::{Query, State},
Expand All @@ -10,6 +12,7 @@ use goose::config::PermissionManager;
use goose::model::ModelConfig;
use goose::providers::create;
use goose::recipe::{Recipe, Response};
use goose::recipe_deeplink;
use goose::session::{Session, SessionManager};
use goose::{
agents::{extension::ToolInfo, extension_manager::get_parameter_names},
Expand All @@ -20,6 +23,7 @@ use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use tracing::error;

#[derive(Deserialize, utoipa::ToSchema)]
pub struct ExtendPromptRequest {
Expand Down Expand Up @@ -70,52 +74,105 @@ pub struct UpdateRouterToolSelectorRequest {
#[derive(Deserialize, utoipa::ToSchema)]
pub struct StartAgentRequest {
working_dir: String,
#[serde(default)]
recipe: Option<Recipe>,
#[serde(default)]
recipe_id: Option<String>,
#[serde(default)]
recipe_deeplink: Option<String>,
}

#[derive(Deserialize, utoipa::ToSchema)]
pub struct ResumeAgentRequest {
session_id: String,
}

#[derive(Serialize, utoipa::ToSchema)]
pub struct ErrorResponse {
error: String,
}

#[utoipa::path(
post,
path = "/agent/start",
request_body = StartAgentRequest,
responses(
(status = 200, description = "Agent started successfully", body = Session),
(status = 400, description = "Bad request - invalid working directory"),
(status = 400, description = "Bad request", body = ErrorResponse),
(status = 401, description = "Unauthorized - invalid secret key"),
(status = 500, description = "Internal server error")
(status = 500, description = "Internal server error", body = ErrorResponse)
)
)]
async fn start_agent(
State(state): State<Arc<AppState>>,
Json(payload): Json<StartAgentRequest>,
) -> Result<Json<Session>, StatusCode> {
) -> Result<Json<Session>, ErrorResponse> {
let StartAgentRequest {
working_dir,
recipe,
recipe_id,
recipe_deeplink,
} = payload;

let resolved_recipe = if let Some(deeplink) = recipe_deeplink {
match recipe_deeplink::decode(&deeplink) {
Ok(recipe) => Some(recipe),
Err(err) => {
error!("Failed to decode recipe deeplink: {}", err);
return Err(ErrorResponse {
message: err.to_string(),
status: StatusCode::BAD_REQUEST,
});
}
}
} else if let Some(id) = recipe_id {
match load_recipe_by_id(state.as_ref(), &id).await {
Ok(recipe) => Some(recipe),
Err(err) => return Err(err),
}
} else {
recipe
};

if let Some(ref recipe) = resolved_recipe {
if let Err(err) = validate_recipe(recipe) {
return Err(ErrorResponse {
message: err.message,
status: err.status,
});
}
}

let counter = state.session_counter.fetch_add(1, Ordering::SeqCst) + 1;
let description = format!("New session {}", counter);

let mut session =
SessionManager::create_session(PathBuf::from(&payload.working_dir), description)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let mut session = SessionManager::create_session(PathBuf::from(&working_dir), description)
.await
.map_err(|err| {
error!("Failed to create session: {}", err);
ErrorResponse {
message: format!("Failed to create session: {}", err),
status: StatusCode::BAD_REQUEST,
}
})?;

if let Some(recipe) = payload.recipe {
if let Some(recipe) = resolved_recipe {
SessionManager::update_session(&session.id)
.recipe(Some(recipe))
.apply()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
.map_err(|err| {
error!("Failed to update session with recipe: {}", err);
ErrorResponse {
message: format!("Failed to update session with recipe: {}", err),
status: StatusCode::INTERNAL_SERVER_ERROR,
}
})?;

session = SessionManager::get_session(&session.id, false)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
.map_err(|err| {
error!("Failed to get updated session: {}", err);
ErrorResponse {
message: format!("Failed to get updated session: {}", err),
status: StatusCode::INTERNAL_SERVER_ERROR,
}
})?;
}

Ok(Json(session))
Expand All @@ -134,10 +191,16 @@ async fn start_agent(
)]
async fn resume_agent(
Json(payload): Json<ResumeAgentRequest>,
) -> Result<Json<Session>, StatusCode> {
) -> Result<Json<Session>, ErrorResponse> {
let session = SessionManager::get_session(&payload.session_id, true)
.await
.map_err(|_| StatusCode::NOT_FOUND)?;
.map_err(|err| {
error!("Failed to resume session {}: {}", payload.session_id, err);
ErrorResponse {
message: format!("Failed to resume session: {}", err),
status: StatusCode::NOT_FOUND,
}
})?;

Ok(Json(session))
}
Expand Down
86 changes: 31 additions & 55 deletions crates/goose-server/src/routes/recipe.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use std::sync::Arc;

use axum::extract::rejection::JsonRejection;
Expand Down Expand Up @@ -38,7 +37,10 @@ fn clean_data_error(err: &axum::extract::rejection::JsonDataError) -> String {
}

use crate::routes::errors::ErrorResponse;
use crate::routes::recipe_utils::get_all_recipes_manifests;
use crate::routes::recipe_utils::{
get_all_recipes_manifests, get_recipe_file_path_by_id, short_id_from_path, validate_recipe,
RecipeValidationError,
};
use crate::state::AppState;

#[derive(Debug, Deserialize, ToSchema)]
Expand Down Expand Up @@ -97,6 +99,11 @@ pub struct SaveRecipeRequest {
recipe: Recipe,
id: Option<String>,
}

#[derive(Debug, Serialize, ToSchema)]
pub struct SaveRecipeResponse {
id: String,
}
#[derive(Debug, Deserialize, ToSchema)]
pub struct ParseRecipeRequest {
pub content: String,
Expand Down Expand Up @@ -231,7 +238,10 @@ async fn decode_recipe(
Json(request): Json<DecodeRecipeRequest>,
) -> Result<Json<DecodeRecipeResponse>, StatusCode> {
match recipe_deeplink::decode(&request.deeplink) {
Ok(recipe) => Ok(Json(DecodeRecipeResponse { recipe })),
Ok(recipe) => match validate_recipe(&recipe) {
Ok(_) => Ok(Json(DecodeRecipeResponse { recipe })),
Err(RecipeValidationError { status, .. }) => Err(status),
},
Err(err) => {
tracing::error!("Failed to decode deeplink: {}", err);
Err(StatusCode::BAD_REQUEST)
Expand Down Expand Up @@ -309,7 +319,7 @@ async fn delete_recipe(
State(state): State<Arc<AppState>>,
Json(request): Json<DeleteRecipeRequest>,
) -> StatusCode {
let file_path = match get_recipe_file_path_by_id(state.clone(), &request.id).await {
let file_path = match get_recipe_file_path_by_id(state.as_ref(), &request.id).await {
Ok(path) => path,
Err(err) => return err.status,
};
Expand All @@ -326,27 +336,31 @@ async fn delete_recipe(
path = "/recipes/save",
request_body = SaveRecipeRequest,
responses(
(status = 204, description = "Recipe saved to file successfully"),
(status = 204, description = "Recipe saved to file successfully", body = SaveRecipeResponse),
(status = 401, description = "Unauthorized - Invalid or missing API key"),
(status = 401, description = "Unauthorized", body = ErrorResponse),
(status = 404, description = "Not found", body = ErrorResponse),
(status = 500, description = "Internal server error", body = ErrorResponse)
),
tag = "Recipe Management"
)]
async fn save_recipe(
State(state): State<Arc<AppState>>,
payload: Result<Json<Value>, JsonRejection>,
) -> Result<StatusCode, ErrorResponse> {
) -> Result<Json<SaveRecipeResponse>, ErrorResponse> {
let Json(raw_json) = payload.map_err(json_rejection_to_error_response)?;
let request = deserialize_save_recipe_request(raw_json)?;
validate_recipe(&request.recipe)?;
ensure_recipe_valid(&request.recipe)?;

let file_path = match request.id.as_ref() {
Some(id) => Some(get_recipe_file_path_by_id(state.clone(), id).await?),
Some(id) => Some(get_recipe_file_path_by_id(state.as_ref(), id).await?),
None => None,
};

match local_recipes::save_recipe_to_file(request.recipe, file_path) {
Ok(_) => Ok(StatusCode::NO_CONTENT),
match local_recipes::save_recipe_to_file(request.recipe, file_path.clone()) {
Ok(save_file_path) => Ok(Json(SaveRecipeResponse {
id: short_id_from_path(&save_file_path.display().to_string()),
})),
Err(e) => Err(ErrorResponse {
message: e.to_string(),
status: StatusCode::INTERNAL_SERVER_ERROR,
Expand All @@ -361,16 +375,13 @@ fn json_rejection_to_error_response(rejection: JsonRejection) -> ErrorResponse {
}
}

fn validate_recipe(recipe: &Recipe) -> Result<(), ErrorResponse> {
let recipe_json = serde_json::to_string(recipe).map_err(|err| ErrorResponse {
message: err.to_string(),
status: StatusCode::BAD_REQUEST,
})?;

validate_recipe_template_from_content(&recipe_json, None).map_err(|err| ErrorResponse {
message: err.to_string(),
status: StatusCode::BAD_REQUEST,
})?;
fn ensure_recipe_valid(recipe: &Recipe) -> Result<(), ErrorResponse> {
if let Err(err) = validate_recipe(recipe) {
return Err(ErrorResponse {
message: err.message,
status: err.status,
});
}

Ok(())
}
Expand Down Expand Up @@ -398,41 +409,6 @@ fn deserialize_save_recipe_request(value: Value) -> Result<SaveRecipeRequest, Er
})
}

async fn get_recipe_file_path_by_id(
state: Arc<AppState>,
id: &str,
) -> Result<PathBuf, ErrorResponse> {
let cached_path = {
let map = state.recipe_file_hash_map.lock().await;
map.get(id).cloned()
};

if let Some(path) = cached_path {
return Ok(path);
}

let recipe_manifest_with_paths = get_all_recipes_manifests().unwrap_or_default();
let mut recipe_file_hash_map = HashMap::new();
let mut resolved_path: Option<PathBuf> = None;

for recipe_manifest_with_path in &recipe_manifest_with_paths {
if recipe_manifest_with_path.id == id {
resolved_path = Some(recipe_manifest_with_path.file_path.clone());
}
recipe_file_hash_map.insert(
recipe_manifest_with_path.id.clone(),
recipe_manifest_with_path.file_path.clone(),
);
}

state.set_recipe_file_hash_map(recipe_file_hash_map).await;

resolved_path.ok_or_else(|| ErrorResponse {
message: format!("Recipe not found: {}", id),
status: StatusCode::NOT_FOUND,
})
}

#[utoipa::path(
post,
path = "/recipes/parse",
Expand Down
Loading
Loading