Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/goose-server/src/openapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ impl<'__s> ToSchema<'__s> for AnnotatedSchema {
super::routes::config_management::get_extensions,
super::routes::config_management::read_all_config,
super::routes::config_management::providers,
super::routes::config_management::get_provider_models,
super::routes::config_management::upsert_permissions,
super::routes::config_management::create_custom_provider,
super::routes::config_management::remove_custom_provider,
Expand Down
89 changes: 86 additions & 3 deletions crates/goose-server/src/routes/config_management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::utils::verify_secret_key;
use crate::routes::utils::check_provider_configured;
use crate::state::AppState;
use axum::{
extract::State,
extract::{Path, State},
routing::{delete, get, post},
Json, Router,
};
Expand Down Expand Up @@ -386,6 +386,45 @@ pub async fn providers(
Ok(Json(providers_response))
}

#[utoipa::path(
get,
path = "/config/providers/{name}/models",
params(
("name" = String, Path, description = "Provider name (e.g., openai)")
),
responses(
(status = 200, description = "Models fetched successfully", body = [String]),
(status = 400, description = "Unknown provider or provider not configured"),
(status = 500, description = "Internal server error")
)
)]
pub async fn get_provider_models(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Path(name): Path<String>,
) -> Result<Json<Vec<String>>, StatusCode> {
verify_secret_key(&headers, &state)?;

let all = get_providers();
let Some(metadata) = all.into_iter().find(|m| m.name == name) else {
return Err(StatusCode::BAD_REQUEST);
};
if !check_provider_configured(&metadata) {
return Err(StatusCode::BAD_REQUEST);
}

let model_config =
ModelConfig::new(&metadata.default_model).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let provider = goose::providers::create(&name, model_config)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

match provider.fetch_supported_models().await {
Ok(Some(models)) => Ok(Json(models)),
Ok(None) => Ok(Json(Vec::new())),
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
}
}

#[derive(Serialize, ToSchema)]
pub struct PricingData {
pub provider: String,
Expand Down Expand Up @@ -771,6 +810,7 @@ pub fn routes(state: Arc<AppState>) -> Router {
.route("/config/extensions", post(add_extension))
.route("/config/extensions/{name}", delete(remove_extension))
.route("/config/providers", get(providers))
.route("/config/providers/{name}/models", get(get_provider_models))
.route("/config/pricing", post(get_pricing))
.route("/config/init", post(init_config))
.route("/config/backup", post(backup_config))
Expand All @@ -790,8 +830,7 @@ pub fn routes(state: Arc<AppState>) -> Router {
mod tests {
use super::*;

#[tokio::test]
async fn test_read_model_limits() {
async fn create_test_state() -> Arc<AppState> {
let test_state = AppState::new(
Arc::new(goose::agents::Agent::default()),
"test".to_string(),
Expand All @@ -805,6 +844,12 @@ mod tests {
.await
.unwrap();
test_state.set_scheduler(sched).await;
test_state
}

#[tokio::test]
async fn test_read_model_limits() {
let test_state = create_test_state().await;
let mut headers = HeaderMap::new();
headers.insert("X-Secret-Key", "test".parse().unwrap());

Expand All @@ -829,4 +874,42 @@ mod tests {
assert!(gpt4_limit.is_some());
assert_eq!(gpt4_limit.unwrap().context_limit, 128_000);
}

#[tokio::test]
async fn test_get_provider_models_unknown_provider() {
let test_state = create_test_state().await;
let mut headers = HeaderMap::new();
headers.insert("X-Secret-Key", "test".parse().unwrap());

let result = get_provider_models(
State(test_state),
headers,
Path("unknown_provider".to_string()),
)
.await;

assert!(result.is_err());
assert_eq!(result.unwrap_err(), StatusCode::BAD_REQUEST);
}

#[tokio::test]
async fn test_get_provider_models_openai_configured() {
std::env::set_var("OPENAI_API_KEY", "test-key");

let test_state = create_test_state().await;
let mut headers = HeaderMap::new();
headers.insert("X-Secret-Key", "test".parse().unwrap());

let result =
get_provider_models(State(test_state), headers, Path("openai".to_string())).await;

// The response should be INTERNAL_SERVER_ERROR since the API key is invalid
assert!(
result.is_err(),
"Expected error response from OpenAI provider with invalid key"
);
assert_eq!(result.unwrap_err(), StatusCode::INTERNAL_SERVER_ERROR);

std::env::remove_var("OPENAI_API_KEY");
}
}
40 changes: 40 additions & 0 deletions ui/desktop/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,46 @@
}
}
},
"/config/providers/{name}/models": {
"get": {
"tags": [
"super::routes::config_management"
],
"operationId": "get_provider_models",
"parameters": [
{
"name": "name",
"in": "path",
"description": "Provider name (e.g., openai)",
"required": true,
"schema": {
"type": "string"
}
}
],
"responses": {
"200": {
"description": "Models fetched successfully",
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"type": "string"
}
}
}
}
},
"400": {
"description": "Unknown provider or provider not configured"
},
"500": {
"description": "Internal server error"
}
}
}
},
"/config/read": {
"post": {
"tags": [
Expand Down
9 changes: 8 additions & 1 deletion ui/desktop/src/api/sdk.gen.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// This file is auto-generated by @hey-api/openapi-ts

import type { Options as ClientOptions, TDataShape, Client } from './client';
import type { AddSubRecipesData, AddSubRecipesResponses, AddSubRecipesErrors, ExtendPromptData, ExtendPromptResponses, ExtendPromptErrors, UpdateSessionConfigData, UpdateSessionConfigResponses, UpdateSessionConfigErrors, GetToolsData, GetToolsResponses, GetToolsErrors, UpdateAgentProviderData, UpdateAgentProviderResponses, UpdateAgentProviderErrors, UpdateRouterToolSelectorData, UpdateRouterToolSelectorResponses, UpdateRouterToolSelectorErrors, ReadAllConfigData, ReadAllConfigResponses, BackupConfigData, BackupConfigResponses, BackupConfigErrors, CreateCustomProviderData, CreateCustomProviderResponses, CreateCustomProviderErrors, RemoveCustomProviderData, RemoveCustomProviderResponses, RemoveCustomProviderErrors, GetExtensionsData, GetExtensionsResponses, GetExtensionsErrors, AddExtensionData, AddExtensionResponses, AddExtensionErrors, RemoveExtensionData, RemoveExtensionResponses, RemoveExtensionErrors, InitConfigData, InitConfigResponses, InitConfigErrors, UpsertPermissionsData, UpsertPermissionsResponses, UpsertPermissionsErrors, ProvidersData, ProvidersResponses, ReadConfigData, ReadConfigResponses, ReadConfigErrors, RecoverConfigData, RecoverConfigResponses, RecoverConfigErrors, RemoveConfigData, RemoveConfigResponses, RemoveConfigErrors, UpsertConfigData, UpsertConfigResponses, UpsertConfigErrors, ValidateConfigData, ValidateConfigResponses, ValidateConfigErrors, ConfirmPermissionData, ConfirmPermissionResponses, ConfirmPermissionErrors, ManageContextData, ManageContextResponses, ManageContextErrors, CreateRecipeData, CreateRecipeResponses, CreateRecipeErrors, DecodeRecipeData, DecodeRecipeResponses, DecodeRecipeErrors, EncodeRecipeData, EncodeRecipeResponses, EncodeRecipeErrors, ScanRecipeData, ScanRecipeResponses, CreateScheduleData, CreateScheduleResponses, CreateScheduleErrors, DeleteScheduleData, DeleteScheduleResponses, DeleteScheduleErrors, ListSchedulesData, ListSchedulesResponses, ListSchedulesErrors, UpdateScheduleData, UpdateScheduleResponses, UpdateScheduleErrors, InspectRunningJobData, InspectRunningJobResponses, InspectRunningJobErrors, KillRunningJobData, KillRunningJobResponses, PauseScheduleData, PauseScheduleResponses, PauseScheduleErrors, RunNowHandlerData, RunNowHandlerResponses, RunNowHandlerErrors, SessionsHandlerData, SessionsHandlerResponses, SessionsHandlerErrors, UnpauseScheduleData, UnpauseScheduleResponses, UnpauseScheduleErrors, ListSessionsData, ListSessionsResponses, ListSessionsErrors, GetSessionHistoryData, GetSessionHistoryResponses, GetSessionHistoryErrors } from './types.gen';
import type { AddSubRecipesData, AddSubRecipesResponses, AddSubRecipesErrors, ExtendPromptData, ExtendPromptResponses, ExtendPromptErrors, UpdateSessionConfigData, UpdateSessionConfigResponses, UpdateSessionConfigErrors, GetToolsData, GetToolsResponses, GetToolsErrors, UpdateAgentProviderData, UpdateAgentProviderResponses, UpdateAgentProviderErrors, UpdateRouterToolSelectorData, UpdateRouterToolSelectorResponses, UpdateRouterToolSelectorErrors, ReadAllConfigData, ReadAllConfigResponses, BackupConfigData, BackupConfigResponses, BackupConfigErrors, CreateCustomProviderData, CreateCustomProviderResponses, CreateCustomProviderErrors, RemoveCustomProviderData, RemoveCustomProviderResponses, RemoveCustomProviderErrors, GetExtensionsData, GetExtensionsResponses, GetExtensionsErrors, AddExtensionData, AddExtensionResponses, AddExtensionErrors, RemoveExtensionData, RemoveExtensionResponses, RemoveExtensionErrors, InitConfigData, InitConfigResponses, InitConfigErrors, UpsertPermissionsData, UpsertPermissionsResponses, UpsertPermissionsErrors, ProvidersData, ProvidersResponses, GetProviderModelsData, GetProviderModelsResponses, GetProviderModelsErrors, ReadConfigData, ReadConfigResponses, ReadConfigErrors, RecoverConfigData, RecoverConfigResponses, RecoverConfigErrors, RemoveConfigData, RemoveConfigResponses, RemoveConfigErrors, UpsertConfigData, UpsertConfigResponses, UpsertConfigErrors, ValidateConfigData, ValidateConfigResponses, ValidateConfigErrors, ConfirmPermissionData, ConfirmPermissionResponses, ConfirmPermissionErrors, ManageContextData, ManageContextResponses, ManageContextErrors, CreateRecipeData, CreateRecipeResponses, CreateRecipeErrors, DecodeRecipeData, DecodeRecipeResponses, DecodeRecipeErrors, EncodeRecipeData, EncodeRecipeResponses, EncodeRecipeErrors, ScanRecipeData, ScanRecipeResponses, CreateScheduleData, CreateScheduleResponses, CreateScheduleErrors, DeleteScheduleData, DeleteScheduleResponses, DeleteScheduleErrors, ListSchedulesData, ListSchedulesResponses, ListSchedulesErrors, UpdateScheduleData, UpdateScheduleResponses, UpdateScheduleErrors, InspectRunningJobData, InspectRunningJobResponses, InspectRunningJobErrors, KillRunningJobData, KillRunningJobResponses, PauseScheduleData, PauseScheduleResponses, PauseScheduleErrors, RunNowHandlerData, RunNowHandlerResponses, RunNowHandlerErrors, SessionsHandlerData, SessionsHandlerResponses, SessionsHandlerErrors, UnpauseScheduleData, UnpauseScheduleResponses, UnpauseScheduleErrors, ListSessionsData, ListSessionsResponses, ListSessionsErrors, GetSessionHistoryData, GetSessionHistoryResponses, GetSessionHistoryErrors } from './types.gen';
import { client as _heyApiClient } from './client.gen';

export type Options<TData extends TDataShape = TDataShape, ThrowOnError extends boolean = boolean> = ClientOptions<TData, ThrowOnError> & {
Expand Down Expand Up @@ -158,6 +158,13 @@ export const providers = <ThrowOnError extends boolean = false>(options?: Option
});
};

export const getProviderModels = <ThrowOnError extends boolean = false>(options: Options<GetProviderModelsData, ThrowOnError>) => {
return (options.client ?? _heyApiClient).get<GetProviderModelsResponses, GetProviderModelsErrors, ThrowOnError>({
url: '/config/providers/{name}/models',
...options
});
};

export const readConfig = <ThrowOnError extends boolean = false>(options: Options<ReadConfigData, ThrowOnError>) => {
return (options.client ?? _heyApiClient).post<ReadConfigResponses, ReadConfigErrors, ThrowOnError>({
url: '/config/read',
Expand Down
32 changes: 32 additions & 0 deletions ui/desktop/src/api/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1257,6 +1257,38 @@ export type ProvidersResponses = {

export type ProvidersResponse2 = ProvidersResponses[keyof ProvidersResponses];

export type GetProviderModelsData = {
body?: never;
path: {
/**
* Provider name (e.g., openai)
*/
name: string;
};
query?: never;
url: '/config/providers/{name}/models';
};

export type GetProviderModelsErrors = {
/**
* Unknown provider or provider not configured
*/
400: unknown;
/**
* Internal server error
*/
500: unknown;
};

export type GetProviderModelsResponses = {
/**
* Models fetched successfully
*/
200: Array<string>;
};

export type GetProviderModelsResponse = GetProviderModelsResponses[keyof GetProviderModelsResponses];

export type ReadConfigData = {
body: ConfigKeyQuery;
path?: never;
Expand Down
19 changes: 19 additions & 0 deletions ui/desktop/src/components/ConfigContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
addExtension as apiAddExtension,
removeExtension as apiRemoveExtension,
providers,
getProviderModels as apiGetProviderModels,
} from '../api';
import type {
ConfigResponse,
Expand Down Expand Up @@ -39,6 +40,7 @@ interface ConfigContextType {
removeExtension: (name: string) => Promise<void>;
getProviders: (b: boolean) => Promise<ProviderDetails[]>;
getExtensions: (b: boolean) => Promise<FixedExtensionEntry[]>;
getProviderModels: (providerName: string) => Promise<string[]>;
disableAllExtensions: () => Promise<void>;
enableBotExtensions: (extensions: ExtensionConfig[]) => Promise<void>;
}
Expand Down Expand Up @@ -185,6 +187,21 @@ export const ConfigProvider: React.FC<ConfigProviderProps> = ({ children }) => {
[providersList]
);

const getProviderModels = useCallback(async (providerName: string): Promise<string[]> => {
try {
const response = await apiGetProviderModels({
path: { name: providerName },
headers: {
'X-Secret-Key': await window.electron.getSecretKey(),
},
});
return response.data || [];
} catch (error) {
console.error(`Failed to fetch models for provider ${providerName}:`, error);
return [];
}
}, []);

useEffect(() => {
// Load all configuration data and providers on mount
(async () => {
Expand Down Expand Up @@ -242,6 +259,7 @@ export const ConfigProvider: React.FC<ConfigProviderProps> = ({ children }) => {
toggleExtension,
getProviders,
getExtensions,
getProviderModels,
disableAllExtensions,
enableBotExtensions,
};
Expand All @@ -257,6 +275,7 @@ export const ConfigProvider: React.FC<ConfigProviderProps> = ({ children }) => {
toggleExtension,
getProviders,
getExtensions,
getProviderModels,
reloadConfig,
]);

Expand Down
Loading
Loading